Example #1
    def export(self,
               export_dir=None,
               checkpoint_step=-1,
               beam_size=1,
               temperature=1.0,
               keep_top_k=-1,
               vocabulary=None,
               eval_with_score=False):
        """Exports a TensorFlow SavedModel.

    Args:
      export_dir: str, a directory in which to export SavedModels. Will use
        `model_dir` if unspecified.
      checkpoint_step: int, checkpoint to export. If -1 (default), use the
        latest checkpoint from the pretrained model directory.
      beam_size: int, a number >= 1 specifying the number of beams to use for
        beam search.
      temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1).
        0.0 means argmax, 1.0 means sample according to the predicted
        distribution.
      keep_top_k: integer, a value between 1 and the vocabulary size. When
        sampling, only pick tokens that are among the k most likely.
      vocabulary: vocabularies.Vocabulary object to use for tokenization, or
        None to use the default SentencePieceVocabulary.
      eval_with_score: If True, compute log-likelihood scores of targets.
        If False, do inference to generate outputs.

    Returns:
      The string path to the exported directory.
    """
        if checkpoint_step == -1:
            checkpoint_step = utils.get_latest_checkpoint_from_dir(
                self._model_dir)
        _parse_operative_config(self._model_dir)
        with gin.unlock_config():
            gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
            gin.bind_parameter("Bitransformer.decode.temperature", temperature)
            gin.bind_parameter("Bitransformer.decode.sampling_keep_top_k",
                               keep_top_k)

        if vocabulary is None:
            vocabulary = utils.get_vocabulary()
        model_ckpt = "model.ckpt-" + str(checkpoint_step)
        export_dir = export_dir or self._model_dir
        estimator = self.estimator(vocabulary,
                                   disable_tpu=True,
                                   score_in_predict_mode=eval_with_score)
        return mtf_utils.export_model(estimator,
                                      export_dir,
                                      vocabulary,
                                      self._sequence_length,
                                      self._model_type,
                                      batch_size=self.batch_size,
                                      checkpoint_path=os.path.join(
                                          self._model_dir, model_ckpt),
                                      eval_with_score=eval_with_score)
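
A minimal usage sketch (not part of the original listing), assuming the method
above is defined on a model wrapper such as t5.models.MtfModel and that `model`
is an already-constructed instance; the export path below is hypothetical:

    # Hedged sketch: export the latest checkpoint as a SavedModel using
    # greedy (argmax) decoding. `model` and the path are assumptions.
    saved_model_dir = model.export(
        export_dir="/tmp/t5_export",   # hypothetical output directory
        checkpoint_step=-1,            # -1 -> latest checkpoint in model_dir
        beam_size=1,                   # single beam
        temperature=0.0,               # argmax decoding
    )
    print("SavedModel written to", saved_model_dir)
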
Example #2
    def predict(self,
                input_file,
                output_file,
                checkpoint_steps=-1,
                beam_size=1,
                temperature=1.0,
                keep_top_k=-1,
                vocabulary=None):
        """Predicts targets from the given inputs.

    Args:
      input_file: str, path to a text file containing newline-separated input
        prompts to predict from.
      output_file: str, path prefix of output file to write predictions to. Note
        the checkpoint step will be appended to the given filename.
      checkpoint_steps: int, list of ints, or None. If an int or list of ints,
        inference will be run on the checkpoint files in `model_dir` whose
        global steps are closest to the global steps provided. If None, run
        inference continuously waiting for new checkpoints. If -1, get the
        latest checkpoint from the model directory.
      beam_size: int, a number >= 1 specifying the number of beams to use for
        beam search.
      temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1).
        0.0 means argmax, 1.0 means sample according to the predicted
        distribution.
      keep_top_k: integer, a value between 1 and the vocabulary size. When
        sampling, only pick tokens that are among the k most likely.
      vocabulary: vocabularies.Vocabulary object to use for tokenization, or
        None to use the default SentencePieceVocabulary.
    """
        # TODO(sharannarang): It would be nice to have a function like
        # load_checkpoint that loads the model once so that decode_from_file
        # can be called multiple times without having to restore the checkpoint
        # weights again. This would be particularly useful in a colab demo.

        if checkpoint_steps == -1:
            checkpoint_steps = utils.get_latest_checkpoint_from_dir(
                self._model_dir)

        _parse_operative_config(self._model_dir)
        with gin.unlock_config():
            gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
            gin.bind_parameter("Bitransformer.decode.temperature", temperature)
            gin.bind_parameter("Bitransformer.decode.sampling_keep_top_k",
                               keep_top_k)
            gin.bind_parameter("utils.decode_from_file.input_filename",
                               input_file)
            gin.bind_parameter("utils.decode_from_file.output_filename",
                               output_file)

        if vocabulary is None:
            vocabulary = utils.get_vocabulary()
        mtf_utils.infer_model(self.estimator(vocabulary), vocabulary,
                              self._sequence_length, self.batch_size,
                              self._model_type, self._model_dir,
                              checkpoint_steps)
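
A similar hedged sketch (not part of the listing), again assuming `model` is an
instance of the wrapper class these methods belong to; the file paths are
hypothetical:

    # Hedged sketch: decode newline-separated prompts from a text file with
    # beam search. All paths and values below are assumptions.
    model.predict(
        input_file="/tmp/inputs.txt",    # one prompt per line
        output_file="/tmp/predictions",  # checkpoint step is appended to this prefix
        checkpoint_steps=-1,             # -1 -> latest checkpoint
        beam_size=4,
        temperature=0.0,                 # must be 0 when beam_size > 1
    )
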
Example #3
def get_vocabulary(mixture_or_task_name=None):
    """Get the appropriate value for the utils.run.vocabulary argument.

  Args:
    mixture_or_task_name: string, an identifier for a Mixture or Task in the
      appropriate registry. Must be specified via gin.

  Returns:
    Either a single t5.data.vocabularies.Vocabulary or a tuple of
    t5.data.vocabularies.Vocabulary for inputs and targets.
  """
    return model_utils.get_vocabulary(mixture_or_task_name)
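
As a small hedged aside (not in the listing), the helper above can also be
called directly with the name of a registered Task or Mixture; the task name
below is hypothetical:

    # Hedged sketch: fetch the vocabulary for a pre-registered task.
    vocab = get_vocabulary("my_task_v001")  # hypothetical task name
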
def run_eval(mixture_or_task_name: str,
             predict_or_score_fn: PredictOrScoreFnCallable,
             checkpoint_steps: Iterable[int],
             dataset_fn: Optional[Callable[
                 [t5.data.Task, Mapping[str, int], int, str, Optional[bool]],
                 tf.data.Dataset]] = None,
             summary_dir: Optional[str] = None,
             split: Optional[str] = "validation",
             sequence_length: Optional[Mapping[str, int]] = None,
             batch_size: Optional[int] = None):
    """Run evaluation on the given mixture or task.

  Args:
    mixture_or_task_name: str, the name of the Mixture or Task to evaluate
      on. Must be pre-registered in the global `TaskRegistry` or
      `MixtureRegistry`.
    predict_or_score_fn: a function that takes in the sequence length,
      checkpoint step, tasks to evaluate, an eval_dataset_fn, a dict mapping
      task names to cached examples, a dict mapping task names to datasets,
      and returns a list of outputs or a list of scores.
    checkpoint_steps: an iterator with integers for checkpoint steps to
      evaluate on.
    dataset_fn: a function that takes a task and returns the dataset
      associated with it. If None, the default mesh_eval_dataset_fn is used.
    summary_dir: str, path to write TensorBoard events file summaries for
      eval. If None, use model_dir/eval_{split}.
    split: str, the mixture/task split to evaluate on.
    sequence_length: an integer, or a dict from feature-key to integer, giving
      the sequence length to pad or truncate to,
      e.g. {"inputs": 512, "targets": 128}.
      If None, sequence length is automatically computed during eval.
    batch_size: integer, used only to check that expected padding matches the
      targets. If None, the check is skipped.
  """

    vocabulary = model_utils.get_vocabulary(mixture_or_task_name)

    tasks = t5.data.get_subtasks(
        t5.data.get_mixture_or_task(mixture_or_task_name))
    tasks = seqio.evaluation.get_valid_eval_tasks(tasks, split)

    if not tasks:
        logging.info(
            "All provided tasks have metric_fns=[] or no matching splits; "
            "eval is not possible.")
        return

    if not dataset_fn:

        def _get_task_eval_dataset(task, sequence_length, split):
            # TODO(sharannarang): Replace with more general function.
            eval_datasets = mesh_transformer.mesh_eval_dataset_fn(
                sequence_length=sequence_length,
                dataset_split=split,
                mixture_or_task_name=task.name,
            )

            return eval_datasets[0].dataset_fn()

        dataset_fn = _get_task_eval_dataset

    summary_writer = None

    cached_targets, cached_datasets, max_sequence_length = \
        seqio.evaluation.get_targets_and_examples(
            tasks=tasks,
            dataset_fn=functools.partial(
                dataset_fn, split=split, sequence_length=None))

    if summary_dir:
        model_utils.write_targets_and_examples(summary_dir, cached_targets,
                                               cached_datasets)

    if sequence_length is None:
        logging.info("Setting sequence lengths to %s", max_sequence_length)
        sequence_length = max_sequence_length
    elif (sequence_length["inputs"] < max_sequence_length["inputs"]
          or sequence_length["targets"] < max_sequence_length["targets"]):
        logging.warning(
            "Given sequence lengths are insufficient for some evaluation inputs "
            "or targets. These sequences will be truncated to fit, likely "
            "leading to sub-optimal results. Consider passing `None` for "
            "sequence_length to have them be automatically computed.\n Got: %s, "
            "\n Max Lengths:%s", sequence_length, max_sequence_length)
    elif (sequence_length["inputs"] > max_sequence_length["inputs"]
          or sequence_length["targets"] > max_sequence_length["targets"]):
        logging.warning(
            "Given sequence lengths are longer than necessary for some "
            "evaluation inputs or targets, resulting in wasted computation. "
            "Consider passing `None` for sequence_length to have them be "
            "automatically computed.\n Got: %s,\n Max Lengths: %s",
            sequence_length, max_sequence_length)

    for step in checkpoint_steps:
        logging.info("Evaluating checkpoint step: %d", step)
        outputs = predict_or_score_fn(checkpoint_step=step,
                                      vocabulary=vocabulary,
                                      tasks=tasks,
                                      datasets=cached_datasets,
                                      sequence_length=sequence_length)

        for task in tasks:
            # Extract the portion of decodes corresponding to this dataset
            dataset = cached_datasets[task.name]
            dataset_size = len(cached_targets[task.name])
            predictions = [
                task.postprocess_fn(d, example=ex) for d, ex in zip(
                    outputs[:dataset_size], tfds.as_numpy(dataset))
            ]

            # Remove the used decodes.
            del outputs[:dataset_size]

            if summary_dir:
                predictions_filename = os.path.join(
                    summary_dir, "{}_{}_predictions".format(task.name, step))
                model_utils.write_lines_to_file(predictions,
                                                predictions_filename)

            with tf.Graph().as_default():
                if summary_dir:
                    summary_writer = summary_writer or tf.summary.FileWriter(
                        summary_dir)

                for metric_fn in task.metric_fns:
                    if summary_dir:
                        summary = tf.Summary()
                    targets = cached_targets[task.name]
                    metric_result = metric_fn(targets, predictions)
                    for metric_name, metric_value in metric_result.items():
                        tag = "eval/{}/{}".format(task.name, metric_name)
                        logging.info("%s at step %d: %.3f", tag, step,
                                     metric_value)
                        if summary_dir:
                            summary.value.add(tag=tag,
                                              simple_value=metric_value)
                            summary_writer.add_summary(summary, step)  # pytype: disable=attribute-error
                if summary_dir:
                    summary_writer.flush()  # pytype: disable=attribute-error

        # Only padding should remain.
        if batch_size:
            expected_pad = -sum(len(t)
                                for t in cached_targets.values()) % batch_size
            if outputs and len(outputs) != expected_pad:
                raise ValueError("{} padded outputs, {} expected.".format(
                    len(outputs), expected_pad))
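
A hedged usage sketch for run_eval (not part of the listing): the callable,
mixture name, checkpoint steps, and paths below are all hypothetical, and a
real predict_or_score_fn would run model inference and return one output per
example in dataset order:

    # Hedged sketch: evaluate two checkpoints of a pre-registered mixture.
    def my_predict_fn(checkpoint_step, vocabulary, tasks, datasets,
                      sequence_length):
        # Stub: a real implementation decodes every cached example and returns
        # the outputs as a flat list, in the same order as the cached datasets.
        ...

    run_eval(
        mixture_or_task_name="my_mixture_v001",  # hypothetical, must be registered
        predict_or_score_fn=my_predict_fn,
        checkpoint_steps=[10000, 20000],          # hypothetical step numbers
        summary_dir="/tmp/eval_summaries",        # hypothetical path
        split="validation",
    )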
Example #4
    def score(self,
              inputs=None,
              targets=None,
              mixture_or_task_name=None,
              mixture_or_task_split=None,
              scores_file=None,
              checkpoint_steps=-1,
              vocabulary=None):
        """Computes log-likelihood of target per example in targets.

    Args:
      inputs: optional - a string (filename), or a list of strings (inputs)
      targets: optional - a string (filename), or a list of strings (targets)
      mixture_or_task_name: optional - a string, the name of the Mixture or Task
        to score on. Must be pre-registered in the global `TaskRegistry` or
        `MixtureRegistry.` Cannot be supplied in addition to `inputs` and
        `targets`.
      mixture_or_task_split: optional - a string, the split of the Mixture or
        Task to score on. Must be provided if scoring on a Mixture or Task.
      scores_file: optional - a string (filename), to write example scores to,
        one per line.
      checkpoint_steps: int, list of ints, or None. If an int or list of ints,
        inference will be run on the checkpoint files in `model_dir` whose
        global steps are closest to the global steps provided. If None, run
        inference continuously waiting for new checkpoints. If -1, get the
        latest checkpoint from the model directory.
      vocabulary: vocabularies.Vocabulary object to use for tokenization, or
        None to use the default SentencePieceVocabulary.

    Returns:
      scores: a list of floating point scores matching the dataset order.
      targets: a list of scored strings matching the dataset order.
    """
        if bool(inputs or targets) == bool(mixture_or_task_name
                                           or mixture_or_task_split):
            raise ValueError(
                "Either 'inputs' and 'targets' or "
                "'mixture_or_task_name' and 'mixture_or_task_split' must be "
                "specified, but not both.")

        if checkpoint_steps == -1:
            checkpoint_steps = utils.get_latest_checkpoint_from_dir(
                self._model_dir)

        _parse_operative_config(self._model_dir)
        with gin.unlock_config():
            gin.parse_config(self._gin_bindings)

        if vocabulary is None:
            vocabulary = utils.get_vocabulary(mixture_or_task_name)

        estimator = self.estimator(vocabulary, score_in_predict_mode=True)
        score_postprocess_fn = functools.partial(mtf_utils.save_scores,
                                                 scores_filename=scores_file)

        if mixture_or_task_name:
            score_dataset_fn = functools.partial(
                t5.models.mesh_transformer.mesh_eval_dataset_fn,
                mixture_or_task_name=mixture_or_task_name,
            )
            return mtf_utils.score_from_dataset(
                estimator=estimator,
                vocabulary=vocabulary,
                batch_size=self.batch_size,
                sequence_length=self._sequence_length,
                model_dir=self._model_dir,
                eval_checkpoint_step=checkpoint_steps,
                dataset_split=mixture_or_task_split,
                score_dataset_fn=score_dataset_fn,
                score_postprocess_fn=score_postprocess_fn)
        else:
            return mtf_utils.score_from_strings(
                estimator=estimator,
                vocabulary=vocabulary,
                model_type=self._model_type,
                batch_size=self.batch_size,
                sequence_length=self._sequence_length,
                model_dir=self._model_dir,
                eval_checkpoint_step=checkpoint_steps,
                inputs=inputs,
                targets=targets,
                score_postprocess_fn=score_postprocess_fn)
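
Finally, a hedged usage sketch for score (not part of the listing), assuming
the same `model` wrapper instance; the example strings and output path are
hypothetical:

    # Hedged sketch: score a couple of (input, target) pairs with the latest
    # checkpoint and write per-example scores to a file.
    scores, scored_targets = model.score(
        inputs=["translate English to German: Hello.",
                "summarize: The quick brown fox jumped over the lazy dog."],
        targets=["Hallo.",
                 "A fox jumped over a dog."],
        scores_file="/tmp/scores.txt",  # hypothetical path, one score per line
        checkpoint_steps=-1,            # -1 -> latest checkpoint
    )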