def export(self, export_dir=None, checkpoint_step=-1, beam_size=1,
           temperature=1.0, keep_top_k=-1, vocabulary=None,
           eval_with_score=False):
  """Exports a TensorFlow SavedModel.

  Args:
    export_dir: str, a directory in which to export SavedModels. Will use
      `model_dir` if unspecified.
    checkpoint_step: int, checkpoint to export. If -1 (default), use the
      latest checkpoint from the pretrained model directory.
    beam_size: int, a number >= 1 specifying the number of beams to use for
      beam search.
    temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1).
      0.0 means argmax, 1.0 means sample according to predicted distribution.
    keep_top_k: integer, a value between 1 and the vocabulary size. When
      sampling, only pick tokens that are in the k most likely.
    vocabulary: vocabularies.Vocabulary object to use for tokenization, or
      None to use the default SentencePieceVocabulary.
    eval_with_score: If True, compute log-likelihood scores of targets. If
      False, do inference to generate outputs.

  Returns:
    The string path to the exported directory.
  """
  if checkpoint_step == -1:
    checkpoint_step = utils.get_latest_checkpoint_from_dir(self._model_dir)

  _parse_operative_config(self._model_dir)
  with gin.unlock_config():
    gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
    gin.bind_parameter("Bitransformer.decode.temperature", temperature)
    gin.bind_parameter("Bitransformer.decode.sampling_keep_top_k", keep_top_k)

  if vocabulary is None:
    vocabulary = utils.get_vocabulary()
  model_ckpt = "model.ckpt-" + str(checkpoint_step)
  export_dir = export_dir or self._model_dir
  estimator = self.estimator(
      vocabulary, disable_tpu=True, score_in_predict_mode=eval_with_score)
  return mtf_utils.export_model(
      estimator, export_dir, vocabulary, self._sequence_length,
      self._model_type, batch_size=self.batch_size,
      checkpoint_path=os.path.join(self._model_dir, model_ckpt),
      eval_with_score=eval_with_score)
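
# A minimal usage sketch for `export` (not part of the original module).
# The `MtfModel` constructor arguments and file paths below are illustrative
# assumptions; substitute values for your own trained model.
def _example_export_usage():
  import t5

  model = t5.models.MtfModel(
      model_dir="/path/to/model_dir",  # hypothetical model directory
      tpu=None,  # run on CPU/GPU
      model_parallelism=1,
      batch_size=8,
      sequence_length={"inputs": 128, "targets": 32})
  # Export the latest checkpoint as a SavedModel for greedy (argmax)
  # decoding; returns the path to the exported directory.
  return model.export(
      export_dir="/path/to/exports",  # hypothetical export directory
      checkpoint_step=-1,  # use the latest checkpoint
      beam_size=1,
      temperature=0.0)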
def predict(self, input_file, output_file, checkpoint_steps=-1, beam_size=1,
            temperature=1.0, keep_top_k=-1, vocabulary=None):
  """Predicts targets from the given inputs.

  Args:
    input_file: str, path to a text file containing newline-separated input
      prompts to predict from.
    output_file: str, path prefix of output file to write predictions to.
      Note the checkpoint step will be appended to the given filename.
    checkpoint_steps: int, list of ints, or None. If an int or list of ints,
      inference will be run on the checkpoint files in `model_dir` whose
      global steps are closest to the global steps provided. If None, run
      inference continuously waiting for new checkpoints. If -1, get the
      latest checkpoint from the model directory.
    beam_size: int, a number >= 1 specifying the number of beams to use for
      beam search.
    temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1).
      0.0 means argmax, 1.0 means sample according to predicted distribution.
    keep_top_k: integer, a value between 1 and the vocabulary size. When
      sampling, only pick tokens that are in the k most likely.
    vocabulary: vocabularies.Vocabulary object to use for tokenization, or
      None to use the default SentencePieceVocabulary.
  """
  # TODO(sharannarang): It would be nice to have a function like
  # load_checkpoint that loads the model once and then calls decode_from_file
  # multiple times without having to restore the checkpoint weights again.
  # This would be particularly useful in colab demos.
  if checkpoint_steps == -1:
    checkpoint_steps = utils.get_latest_checkpoint_from_dir(self._model_dir)

  _parse_operative_config(self._model_dir)
  with gin.unlock_config():
    gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
    gin.bind_parameter("Bitransformer.decode.temperature", temperature)
    gin.bind_parameter("Bitransformer.decode.sampling_keep_top_k", keep_top_k)
    gin.bind_parameter("utils.decode_from_file.input_filename", input_file)
    gin.bind_parameter("utils.decode_from_file.output_filename", output_file)

  if vocabulary is None:
    vocabulary = utils.get_vocabulary()
  mtf_utils.infer_model(
      self.estimator(vocabulary), vocabulary, self._sequence_length,
      self.batch_size, self._model_type, self._model_dir, checkpoint_steps)
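
# A minimal usage sketch for `predict` (not part of the original module),
# assuming `model` is an already-constructed `MtfModel` as in the export
# example above. The file paths are hypothetical; per the docstring, the
# checkpoint step is appended to the output filename, so the file actually
# written here would be e.g. "/path/to/predictions-1000000".
def _example_predict_usage(model):
  model.predict(
      input_file="/path/to/inputs.txt",  # newline-separated prompts
      output_file="/path/to/predictions",
      checkpoint_steps=-1,  # latest checkpoint
      beam_size=4,  # beam search requires temperature=0
      temperature=0.0)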
def get_vocabulary(mixture_or_task_name=None):
  """Get the appropriate value for the utils.run.vocabulary argument.

  Args:
    mixture_or_task_name: string, an identifier for a Mixture or Task in the
      appropriate registry. Must be specified via gin.

  Returns:
    Either a single t5.data.vocabularies.Vocabulary or a tuple of
    t5.data.vocabularies.Vocabulary for inputs and targets.
  """
  return model_utils.get_vocabulary(mixture_or_task_name)
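
# A minimal sketch of round-trip tokenization with the returned vocabulary
# (an assumption about typical use, not part of the original module). The
# task name is a hypothetical example; any registered Mixture/Task works.
def _example_vocabulary_usage():
  vocab = get_vocabulary("glue_mrpc_v002")  # hypothetical registered task
  if isinstance(vocab, tuple):
    vocab = vocab[0]  # separate inputs/targets vocabularies; take inputs'
  token_ids = vocab.encode("hello world")  # list of int token ids
  text = vocab.decode(token_ids)  # detokenize back to a string
  return token_ids, text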
def run_eval(mixture_or_task_name: str,
             predict_or_score_fn: PredictOrScoreFnCallable,
             checkpoint_steps: Iterable[int],
             dataset_fn: Optional[Callable[
                 [t5.data.Task, Mapping[str, int], int, str, Optional[bool]],
                 tf.data.Dataset]] = None,
             summary_dir: Optional[str] = None,
             split: Optional[str] = "validation",
             sequence_length: Optional[Mapping[str, int]] = None,
             batch_size: Optional[int] = None):
  """Run evaluation on the given mixture or task.

  Args:
    mixture_or_task_name: str, the name of the Mixture or Task to evaluate
      on. Must be pre-registered in the global `TaskRegistry` or
      `MixtureRegistry`.
    predict_or_score_fn: function, this function takes in the sequence
      length, checkpoint step, tasks to evaluate, an eval_dataset_fn, a dict
      mapping task names to cached examples, and a dict mapping task names to
      datasets, and returns a list of outputs or a list of scores.
    checkpoint_steps: an iterable of integer checkpoint steps to evaluate on.
    dataset_fn: function, this function takes a task and returns the dataset
      associated with it. If None, the default mesh_eval_dataset_fn is used.
    summary_dir: str, path to write TensorBoard events file summaries for
      eval. If None, use model_dir/eval_{split}.
    split: str, the mixture/task split to evaluate on.
    sequence_length: an integer or a dict from feature-key to integer, the
      sequence length to pad or truncate to, e.g.
      {"inputs": 512, "targets": 128}. If None, sequence length is
      automatically computed during eval.
    batch_size: integer, used only to check that expected padding matches the
      targets. If None, the check is skipped.
  """
  vocabulary = model_utils.get_vocabulary(mixture_or_task_name)

  tasks = t5.data.get_subtasks(
      t5.data.get_mixture_or_task(mixture_or_task_name))
  tasks = seqio.evaluation.get_valid_eval_tasks(tasks, split)

  if not tasks:
    logging.info(
        "All provided tasks have metric_fns=[] or no matching splits; "
        "eval is not possible.")
    return

  if not dataset_fn:
    def _get_task_eval_dataset(task, sequence_length, split):
      # TODO(sharannarang): Replace with more general function.
      eval_datasets = mesh_transformer.mesh_eval_dataset_fn(
          sequence_length=sequence_length,
          dataset_split=split,
          mixture_or_task_name=task.name,
      )
      return eval_datasets[0].dataset_fn()

    dataset_fn = _get_task_eval_dataset

  summary_writer = None

  cached_targets, cached_datasets, max_sequence_length = \
      seqio.evaluation.get_targets_and_examples(
          tasks=tasks,
          dataset_fn=functools.partial(
              dataset_fn, split=split, sequence_length=None))

  if summary_dir:
    model_utils.write_targets_and_examples(
        summary_dir, cached_targets, cached_datasets)

  if sequence_length is None:
    logging.info("Setting sequence lengths to %s", max_sequence_length)
    sequence_length = max_sequence_length
  elif (sequence_length["inputs"] < max_sequence_length["inputs"] or
        sequence_length["targets"] < max_sequence_length["targets"]):
    logging.warning(
        "Given sequence lengths are insufficient for some evaluation inputs "
        "or targets. These sequences will be truncated to fit, likely "
        "leading to sub-optimal results. Consider passing `None` for "
        "sequence_length to have them be automatically computed.\n Got: %s,"
        "\n Max Lengths: %s", sequence_length, max_sequence_length)
  elif (sequence_length["inputs"] > max_sequence_length["inputs"] or
        sequence_length["targets"] > max_sequence_length["targets"]):
    logging.warning(
        "Given sequence lengths are longer than necessary for some "
        "evaluation inputs or targets, resulting in wasted computation. "
        "Consider passing `None` for sequence_length to have them be "
        "automatically computed.\n Got: %s,\n Max Lengths: %s",
        sequence_length, max_sequence_length)

  for step in checkpoint_steps:
    logging.info("Evaluating checkpoint step: %d", step)
    outputs = predict_or_score_fn(
        checkpoint_step=step,
        vocabulary=vocabulary,
        tasks=tasks,
        datasets=cached_datasets,
        sequence_length=sequence_length)

    for task in tasks:
      # Extract the portion of decodes corresponding to this dataset.
      dataset = cached_datasets[task.name]
      dataset_size = len(cached_targets[task.name])
      predictions = [
          task.postprocess_fn(d, example=ex)
          for d, ex in zip(outputs[:dataset_size], tfds.as_numpy(dataset))
      ]

      # Remove the used decodes.
      del outputs[:dataset_size]

      if summary_dir:
        predictions_filename = os.path.join(
            summary_dir, "{}_{}_predictions".format(task.name, step))
        model_utils.write_lines_to_file(predictions, predictions_filename)

      with tf.Graph().as_default():
        if summary_dir:
          summary_writer = summary_writer or tf.summary.FileWriter(
              summary_dir)

        for metric_fn in task.metric_fns:
          if summary_dir:
            summary = tf.Summary()
          targets = cached_targets[task.name]
          metric_result = metric_fn(targets, predictions)
          for metric_name, metric_value in metric_result.items():
            tag = "eval/{}/{}".format(task.name, metric_name)
            logging.info("%s at step %d: %.3f", tag, step, metric_value)
            if summary_dir:
              summary.value.add(tag=tag, simple_value=metric_value)
              summary_writer.add_summary(summary, step)  # pytype: disable=attribute-error

        if summary_dir:
          summary_writer.flush()  # pytype: disable=attribute-error

    # Only padding should remain.
    if batch_size:
      expected_pad = -sum(len(t) for t in cached_targets.values()) % batch_size
      if outputs and len(outputs) != expected_pad:
        raise ValueError("{} padded outputs, {} expected.".format(
            len(outputs), expected_pad))
def score(self,
          inputs=None,
          targets=None,
          mixture_or_task_name=None,
          mixture_or_task_split=None,
          scores_file=None,
          checkpoint_steps=-1,
          vocabulary=None):
  """Computes log-likelihood of target per example in targets.

  Args:
    inputs: optional - a string (filename), or a list of strings (inputs).
    targets: optional - a string (filename), or a list of strings (targets).
    mixture_or_task_name: optional - a string, the name of the Mixture or
      Task to score on. Must be pre-registered in the global `TaskRegistry`
      or `MixtureRegistry`. Cannot be supplied in addition to `inputs` and
      `targets`.
    mixture_or_task_split: optional - a string, the split of the Mixture or
      Task to score on. Must be provided if scoring on a Mixture or Task.
    scores_file: optional - a string (filename), to write example scores to,
      one per line.
    checkpoint_steps: int, list of ints, or None. If an int or list of ints,
      inference will be run on the checkpoint files in `model_dir` whose
      global steps are closest to the global steps provided. If None, run
      inference continuously waiting for new checkpoints. If -1, get the
      latest checkpoint from the model directory.
    vocabulary: vocabularies.Vocabulary object to use for tokenization, or
      None to use the default SentencePieceVocabulary.

  Returns:
    scores: a list of floating point scores matching the dataset order.
    targets: a list of scored strings matching the dataset order.
  """
  if bool(inputs or targets) == bool(mixture_or_task_name or
                                     mixture_or_task_split):
    raise ValueError(
        "Either 'inputs' and 'targets' or "
        "'mixture_or_task_name' and 'mixture_or_task_split' must be "
        "specified, but not both.")

  if checkpoint_steps == -1:
    checkpoint_steps = utils.get_latest_checkpoint_from_dir(self._model_dir)

  _parse_operative_config(self._model_dir)
  with gin.unlock_config():
    gin.parse_config(self._gin_bindings)

  if vocabulary is None:
    vocabulary = utils.get_vocabulary(mixture_or_task_name)
  estimator = self.estimator(vocabulary, score_in_predict_mode=True)
  score_postprocess_fn = functools.partial(
      mtf_utils.save_scores, scores_filename=scores_file)

  if mixture_or_task_name:
    score_dataset_fn = functools.partial(
        t5.models.mesh_transformer.mesh_eval_dataset_fn,
        mixture_or_task_name=mixture_or_task_name,
    )
    return mtf_utils.score_from_dataset(
        estimator=estimator,
        vocabulary=vocabulary,
        batch_size=self.batch_size,
        sequence_length=self._sequence_length,
        model_dir=self._model_dir,
        eval_checkpoint_step=checkpoint_steps,
        dataset_split=mixture_or_task_split,
        score_dataset_fn=score_dataset_fn,
        score_postprocess_fn=score_postprocess_fn)
  else:
    return mtf_utils.score_from_strings(
        estimator=estimator,
        vocabulary=vocabulary,
        model_type=self._model_type,
        batch_size=self.batch_size,
        sequence_length=self._sequence_length,
        model_dir=self._model_dir,
        eval_checkpoint_step=checkpoint_steps,
        inputs=inputs,
        targets=targets,
        score_postprocess_fn=score_postprocess_fn)
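
# A minimal usage sketch for `score` (not part of the original module),
# assuming `model` is an already-constructed `MtfModel`. Inputs and targets
# are aligned lists; each returned score is the log-likelihood of the
# corresponding target given its input. The example strings and output path
# are hypothetical.
def _example_score_usage(model):
  scores, scored_targets = model.score(
      inputs=["translate English to German: That is good."],
      targets=["Das ist gut."],
      scores_file="/path/to/scores.txt",  # hypothetical output path
      checkpoint_steps=-1)  # latest checkpoint
  return scores, scored_targets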