def mesh_eval_dataset_fn(mixture_name,
                         sequence_length,
                         vocabulary,
                         dataset_split,
                         num_eval_examples=None,
                         use_cached=False):
    """Returns all tf.data.Datasets for evaluation on a given mixture.

  This uses the format required for utils.run's `eval_dataset_fn` argument in
  the Mesh TF transformer standalone.

  Args:
    mixture_name: string, an identifier for a mixture in the MixtureRegistry.
      Must be specified via gin.
    sequence_length: dict mapping each feature key to its maximum sequence
      length (an int).
    vocabulary: a SentencePieceVocabulary.
    dataset_split: string, which split of the dataset to load.
    num_eval_examples: maximum number of examples per task to use for continuous
      eval. If None, use all examples.
    use_cached: bool, whether to load the cached version of this dataset.

  Returns:
    A list of mesh_tensorflow.transformer.dataset.EvalDataset tuples.
  """
    if not isinstance(vocabulary, SentencePieceVocabulary):
        raise ValueError("vocabulary must be a SentencePieceVocabulary")
    mixture = MixtureRegistry.get(mixture_name)

    def _get_dataset_for_single_task(task):
        """Get a tensorflow.data.Dataset for the provided task."""
        ds = task.get_dataset(sequence_length,
                              split=dataset_split,
                              use_cached=use_cached,
                              shuffle=False)
        ds = transformer_dataset.pack_or_pad(ds,
                                             sequence_length,
                                             pack=False,
                                             feature_keys=task.output_features,
                                             ensure_eos=True)
        if num_eval_examples is not None:
            ds = ds.take(num_eval_examples)
        return ds

    outputs = []
    for task in mixture.tasks:
        if dataset_split not in task.splits:
            logging.info("Task %s has no '%s' split, skipping eval.",
                         task.name, dataset_split)
            continue

        outputs.append(
            transformer_dataset.EvalDataset(
                task.name,
                functools.partial(_get_dataset_for_single_task, task),
                task.postprocess_fn,
                task.metric_fns,
            ))

    return outputs
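
A minimal usage sketch (not part of the original source) for the function
above. The mixture name and SentencePiece model path are hypothetical, and the
import path for SentencePieceVocabulary may vary across t5 versions; this only
demonstrates the shape of the returned EvalDataset tuples.

import tensorflow_datasets as tfds
from t5.data import SentencePieceVocabulary

eval_datasets = mesh_eval_dataset_fn(
    mixture_name="my_mixture",  # hypothetical; must exist in MixtureRegistry
    sequence_length={"inputs": 512, "targets": 128},
    vocabulary=SentencePieceVocabulary("/path/to/sp.model"),  # hypothetical
    dataset_split="validation",
    num_eval_examples=64)

for eval_ds in eval_datasets:
    ds = eval_ds.dataset_fn()  # the functools.partial built above; no args
    example = next(iter(tfds.as_numpy(ds)))  # one padded, tokenized example
    print(eval_ds.name, sorted(example.keys()))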
Code example #2
def mesh_eval_dataset_fn(mixture_or_task_name,
                         sequence_length,
                         dataset_split,
                         vocabulary=None,
                         num_eval_examples=-1,
                         use_cached=False,
                         pack=False,
                         shuffle_eval_examples=False,
                         seed=None):
    """Returns all tf.data.Datasets for evaluation on a given mixture.

  This uses the format required for utils.run's `eval_dataset_fn` argument in
  the Mesh TF transformer standalone.

  Args:
    mixture_or_task_name: string, an identifier for a Mixture or Task in the
      appropriate registry. Must be specified via gin.
    sequence_length: dict mapping each feature key to its maximum sequence
      length (an int). If set to None, packing and padding will be disabled.
    dataset_split: string, which split of the dataset to load.
    vocabulary: unused argument, maintained for compatibility with other
      dataset_fns.
    num_eval_examples: maximum number of examples per task to use for continuous
      eval. If None or less than 0, use all examples.
    use_cached: bool, whether to load the cached version of this dataset.
    pack: a boolean, whether to pack examples. This is useful for perplexity
      evals but should not be used for iterative decoding.
    shuffle_eval_examples: boolean, whether to shuffle eval examples, applied
      only when num_eval_examples is not None. Intended to be able to eval on a
      different eval slice at every iteration.
    seed: tf.int64 scalar tf.Tensor (or None). Used for both the global seed
      and the shuffle seed for tf.data.

  Returns:
    A list of mesh_tensorflow.transformer.dataset.EvalDataset tuples.
  """
    del vocabulary

    mixture_or_task = t5.data.get_mixture_or_task(mixture_or_task_name)

    def _get_dataset_for_single_task(task, sequence_length):
        """Get a tensorflow.data.Dataset for the provided task."""
        if shuffle_eval_examples and seed is None:
            logging.warning("shuffle_eval_examples is True but no seed was "
                            "provided. Using a random seed.")

        ds = task.get_dataset(
            sequence_length,
            split=dataset_split,
            use_cached=use_cached,
            shuffle=shuffle_eval_examples,
            seed=seed,
        )
        eos_keys = set(k for k, f in mixture_or_task.output_features.items()
                       if f.add_eos)
        if sequence_length is None:
            logging.info(
                "Skipping packing/padding for '%s' since sequence length is None.",
                task.name)
        else:
            logging.info("%sing '%s' with sequence lengths: %s",
                         "Pack" if pack else "Padd", task.name,
                         sequence_length)
            ds = transformer_dataset.pack_or_pad(ds,
                                                 sequence_length,
                                                 pack=pack,
                                                 feature_keys=tuple(
                                                     task.output_features),
                                                 ensure_eos=eos_keys)

        if num_eval_examples is not None and num_eval_examples >= 0:
            ds = ds.take(num_eval_examples)

        return ds

    outputs = []

    for task in t5.data.get_subtasks(mixture_or_task):
        if dataset_split not in task.splits:
            logging.info("Task %s has no '%s' split, skipping eval.",
                         task.name, dataset_split)
            continue

        outputs.append(
            transformer_dataset.EvalDataset(
                task.name,
                functools.partial(_get_dataset_for_single_task,
                                  task=task,
                                  sequence_length=sequence_length),
                task.postprocess_fn,
                task.metric_fns,
            ))

    if not outputs:
        logging.warning("No %s data found for %s.", dataset_split,
                        mixture_or_task_name)

    return outputs
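
Since `mixture_or_task_name` must be supplied through gin, a configuration
sketch like the following would select the task and enable seeded shuffling.
This assumes the function is registered with @gin.configurable (not shown in
this excerpt); the task name and values are hypothetical.

import gin

gin.parse_config("""
mesh_eval_dataset_fn.mixture_or_task_name = 'my_task'
mesh_eval_dataset_fn.num_eval_examples = 128
mesh_eval_dataset_fn.shuffle_eval_examples = True
mesh_eval_dataset_fn.seed = 42
""")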
Code example #3
def mesh_inference_dataset_fn(mixture_or_task_name,
                              sequence_length,
                              dataset_split,
                              shuffle=False,
                              seed=None,
                              vocabulary=None,
                              num_inference_examples=-1,
                              use_cached=False,
                              priming_sequence_length=None):
    """Returns all tf.data.Datasets for LM inference on a given mixture.

  For Tasks without inputs (such as language modeling), the first
  `priming_sequence_length` tokens in the target are used as the "inputs" for
  inference.

  Args:
    mixture_or_task_name: string, an identifier for a Mixture or Task in the
      appropriate registry. Must be specified via gin.
    sequence_length: dict mapping each feature key to its maximum sequence
      length (an int). If set to None, packing and padding will be disabled.
    dataset_split: string, which split of the dataset to load. NOTE, this
      function does NOT receive the split specified in utils.run. It needs to be
      specified separately.
    shuffle: Whether or not to shuffle dataset.
    seed: tf.int64 scalar tf.Tensor (or None). Used as shuffle seed for tf.data.
    vocabulary: unused argument, maintained for compatibility with other
      dataset_fns.
    num_inference_examples: maximum number of examples per task to do inference
      on. If None or less than 0, use all examples.
    use_cached: bool, whether to load the cached version of this dataset.
    priming_sequence_length: If the Task only has "targets", select the first
      this many tokens from each target sequence to use as "inputs". This is
      useful for decoder-only language models where you would like to use a
      portion of the targets as a priming sequence for generation.

  Returns:
    A list of mesh_tensorflow.transformer.dataset.EvalDataset tuples.
  """
    del vocabulary
    mixture_or_task = t5.data.get_mixture_or_task(mixture_or_task_name)

    def _split_targets_for_primed_inference(ex):
        ex["inputs"] = ex["targets"][:priming_sequence_length]
        ex["targets"] = ex["targets"][priming_sequence_length:]
        ex["inputs"] = tf.pad(
            ex["inputs"],
            [[0, priming_sequence_length - tf.shape(ex["inputs"])[0]]],
            "CONSTANT")
        ex["inputs"] = tf.reshape(ex["inputs"],
                                  shape=(priming_sequence_length, ))
        return ex

    def _prepare_for_unprimed_inference(ex):
        ex["inputs"] = tf.constant([], dtype=tf.int64)
        return ex

    def _get_dataset_for_single_task(task, sequence_length):
        """Get a tensorflow.data.Dataset for the provided task."""

        ds = task.get_dataset(sequence_length,
                              split=dataset_split,
                              use_cached=use_cached,
                              shuffle=shuffle,
                              seed=seed)
        if "inputs" not in ds.element_spec:
            if not priming_sequence_length or priming_sequence_length <= 0:
                logging.warning(
                    "Priming sequence length not specified so priming "
                    "with the empty string.")
                ds = ds.map(_prepare_for_unprimed_inference)
            else:
                logging.info(
                    "Using the first %d tokens of each target as input.",
                    priming_sequence_length)
                ds = ds.map(_split_targets_for_primed_inference)
        elif priming_sequence_length is not None:
            raise ValueError(
                "Setting a priming sequence length only makes sense for decoder-only "
                "Tasks, which have `targets` but no `inputs`.")

        eos_keys = set(k for k, f in mixture_or_task.output_features.items()
                       if f.add_eos)

        logging.info("Padding '%s' with sequence lengths: %s", task.name,
                     sequence_length)
        ds = transformer_dataset.pack_or_pad(ds,
                                             sequence_length,
                                             pack=False,
                                             feature_keys=tuple(
                                                 task.output_features),
                                             ensure_eos=eos_keys)

        if num_inference_examples is not None and num_inference_examples >= 0:
            ds = ds.take(num_inference_examples)

        return ds

    outputs = []

    for task in t5.data.get_subtasks(mixture_or_task):
        if dataset_split not in task.splits:
            logging.info("Task %s has no '%s' split, skipping inference.",
                         task.name, dataset_split)
            continue

        outputs.append(
            transformer_dataset.EvalDataset(
                task.name,
                functools.partial(_get_dataset_for_single_task,
                                  task=task,
                                  sequence_length=sequence_length),
                task.postprocess_fn,
                task.metric_fns,
            ))

    if not outputs:
        logging.warning("No %s data found for %s.", dataset_split,
                        mixture_or_task_name)

    return outputs
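
A toy eager-mode check (assumed, not from the source) of the target-splitting
logic in `_split_targets_for_primed_inference` above: with
priming_sequence_length=4, the first four target tokens become the padded
"inputs" and the remainder stays in "targets".

import tensorflow as tf

priming_sequence_length = 4
targets = tf.constant([10, 11, 12, 13, 14, 15], dtype=tf.int64)

inputs = targets[:priming_sequence_length]
remainder = targets[priming_sequence_length:]
# Pad "inputs" on the right up to the priming length (a no-op here, since
# the target already has at least four tokens).
inputs = tf.pad(inputs,
                [[0, priming_sequence_length - tf.shape(inputs)[0]]],
                "CONSTANT")
print(inputs.numpy())     # [10 11 12 13]
print(remainder.numpy())  # [14 15]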
Code example #4
def mesh_eval_dataset_fn(
    mixture_or_task_name,
    sequence_length,
    vocabulary,
    dataset_split,
    num_eval_examples=None,
    use_cached=False,
    pack=False,
    shuffle_eval_examples=False,
    shuffle_buffer_size=t5.data.SHUFFLE_BUFFER_SIZE):
  """Returns all tf.data.Datasets for evaluation on a given mixture.

  This uses the format required for utils.run's `eval_dataset_fn` argument in
  the Mesh TF transformer standalone.

  Args:
    mixture_or_task_name: string, an identifier for a Mixture or Task in the
      appropriate registry. Must be specified via gin.
    sequence_length: dict mapping each feature key to its maximum sequence
      length (an int).
    vocabulary: a t5.data.vocabularies.Vocabulary.
    dataset_split: string, which split of the dataset to load.
    num_eval_examples: maximum number of examples per task to use for continuous
      eval. If None, use all examples.
    use_cached: bool, whether to load the cached version of this dataset.
    pack: a boolean, whether to pack examples. This is useful for perplexity
      evals but should not be used for iterative decoding.
    shuffle_eval_examples: boolean, whether to shuffle eval examples, applied
      only when num_eval_examples is not None. Intended to be able to eval on a
      different eval slice at every iteration.
    shuffle_buffer_size: integer, the shuffle buffer size used when shuffling
      eval examples; ideally some large multiple of `num_eval_examples` to
      ensure good mixing and random batches.

  Returns:
    A list of mesh_tensorflow.transformer.dataset.EvalDataset tuples.
  """
  valid_vocabulary(vocabulary)

  mixture_or_task = t5.data.get_mixture_or_task(mixture_or_task_name)

  def _get_dataset_for_single_task(task):
    """Get a tensorflow.data.Dataset for the provided task."""
    ds = task.get_dataset(
        sequence_length, split=dataset_split,
        use_cached=use_cached, shuffle=False
    )
    eos_keys = set(
        k for k, f in mixture_or_task.output_features.items() if f.add_eos)
    ds = transformer_dataset.pack_or_pad(
        ds,
        sequence_length,
        pack=pack,
        feature_keys=tuple(task.output_features),
        ensure_eos=eos_keys)
    ds = maybe_shuffle_and_subsample_dataset(
        ds, num_eval_examples, shuffle_eval_examples, shuffle_buffer_size)
    return ds

  outputs = []

  for task in t5.data.get_subtasks(mixture_or_task):
    if dataset_split not in task.splits:
      logging.info(
          "Task %s has no '%s' split, skipping eval.", task.name, dataset_split
      )
      continue

    outputs.append(
        transformer_dataset.EvalDataset(
            task.name,
            functools.partial(_get_dataset_for_single_task, task),
            task.postprocess_fn,
            task.metric_fns,
        )
    )

  return outputs
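
`maybe_shuffle_and_subsample_dataset` is referenced above but not included in
this excerpt. A minimal sketch consistent with the docstring (shuffling only
applies when an eval slice is taken) might look like the following; this is an
assumption, not the original helper.

def maybe_shuffle_and_subsample_dataset(ds, num_eval_examples,
                                        shuffle_eval_examples,
                                        shuffle_buffer_size):
  """Assumed helper: optionally shuffle, then take a fixed eval slice."""
  if num_eval_examples is None:
    return ds
  if shuffle_eval_examples:
    # A buffer much larger than num_eval_examples gives a well-mixed slice.
    ds = ds.shuffle(shuffle_buffer_size)
  return ds.take(num_eval_examples)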
Code example #5
def eval_model_ll(estimator,
                  vocabulary,
                  sequence_length,
                  batch_size,
                  dataset_split,
                  model_dir,
                  eval_dataset_fn,
                  eval_summary_dir,
                  eval_checkpoint_step,
                  attribute_bit=True,
                  unsupervised_attribute_transfer_metrics=True,
                  control_code_bool=False):
    """Eval a Mesh-TF model.
    Args:
      estimator: Estimator object, created with the appropriate model_fn.
      vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
        targets_vocabulary) tuple
      sequence_length: a dict from feature-key to integer giving the (packed)
        sequence length, e.g. {"inputs": 512, "targets": 128}
      batch_size: an integer, global batch size
      dataset_split: a string
      model_dir: a string, directory with the model.
      eval_dataset_fn: A function returning a list of dataset.EvalDataset tuples.
        Must be provided for mode="eval". Should accept the following arguments:
          - sequence_length: an integer or a dict from feature-key to integer
            giving the (packed) sequence length, e.g. {"inputs": 512,
            "targets": 128}
          - vocabulary: Vocabulary instance to use for encoding.
          - dataset_split: str, which dataset split to load.
        dataset.EvalDataset tuples are namedtuples with the following fields:
          - name: string, the task name
          - dataset_fn: function which returns a tf.data.Dataset of tokenized and
            padded examples. Must not require any arguments and must include the
            feature keys 'inputs' and 'targets_plaintext'.
          - postprocess_fn: function which converts plaintext targets to values
            that can be processed by a `metric_fn`.
          - list_of_metric_fns: list of metric functions with the call signature
            `metric_fn(targets, predictions)` which returns a dict mapping
            submetric names to scalar values. TensorBoard summaries and other tags
            will be written out using the submetric names.
      eval_summary_dir: str, path to write TensorBoard events file summaries for
        eval. If None, use model_dir/eval_{split}.
      eval_checkpoint_step: int, list of ints, or None. If an int or list of ints,
        evaluation or inference will be run on the checkpoint files in `model_dir`
        whose global steps are closest to the global steps provided. If None and
        mode="eval", run eval continuously waiting for new checkpoints via
        `tf.train.checkpoints_iterator`.
    """
    if eval_dataset_fn is None:
        raise ValueError("Must provide eval_dataset_fn through gin for eval.")

    eval_datasets = eval_dataset_fn(
        sequence_length=sequence_length,
        vocabulary=vocabulary,
        dataset_split=dataset_split,
    )

    valid_eval_datasets = []
    for eval_dataset in eval_datasets:
        if not eval_dataset.metric_fns:
            tf.logging.info("Skipping %s because metric_fns is empty",
                            eval_dataset.name)
            continue
        # Convert to EvalDataset tuple in case eval_dataset_fn returns raw tuples
        valid_eval_datasets.append(
            transformer_dataset.EvalDataset(*eval_dataset))
    eval_datasets = valid_eval_datasets

    if not eval_datasets:
        tf.logging.info(
            "All provided EvalDatasets have metric_fns=[]; eval is not possible."
        )
        return

    eval_summary_dir = eval_summary_dir or os.path.join(
        model_dir, "{}_eval".format(dataset_split))
    summary_writer = tf.summary.FileWriter(eval_summary_dir)

    # Pre-load in all of the targets once before entering continuous eval loop
    cached_targets = {}
    cached_examples = {}
    if attribute_bit:
        cached_attributes_origin = {}
    # Need to create a separate graph for loading in plaintext targets
    # or else TF will complain that we modified the graph
    with tf.Graph().as_default():
        for eval_dataset in eval_datasets:
            if eval_dataset.metric_fns:
                ds = eval_dataset.dataset_fn()
                # Create list of postprocessed text targets
                examples = [ex for ex in tfds.as_numpy(ds)]
                targets = [
                    eval_dataset.postprocess_fn(  # pylint:disable=g-complex-comprehension
                        tf.compat.as_text(ex["targets_plaintext"]),
                        example=ex,
                        is_target=True) for ex in examples
                ]

                if attribute_bit:
                    attributes_origin = [
                        str(ex["attribute"][0] - 1) for ex in examples
                    ]

                targets_filename = os.path.join(
                    eval_summary_dir,
                    "{}_targets".format(eval_dataset.name),
                )
                write_lines_to_file(targets, targets_filename)
                cached_targets[eval_dataset.name] = targets
                cached_examples[eval_dataset.name] = examples
                if attribute_bit:
                    cached_attributes_origin[
                        eval_dataset.name] = attributes_origin

    if attribute_bit:
        _INPUT_FEATURES_ll.append("attribute")

    if control_code_bool:  # TODO: check whether everything here is useful.
        _INPUT_FEATURES_ll.extend([
            "controlcode", "controlcode_position", "controlcode_segmentation",
            "controlcode_subsegmentation", "codeprefixedtargets",
            "codeprefixedtargets_position", "codeprefixedtargets_segmentation",
            "codeprefixedtargets_subsegmentation"
        ])

    def input_fn(params):
        """Eval input function for estimator."""
        del params
        # Concatenate all dataset inputs to only have to do one decode loop
        combined_ds = None
        for eval_dataset in eval_datasets:
            # Only include datasets for tasks that provide metric functions
            if eval_dataset.metric_fns:
                ds = eval_dataset.dataset_fn()
                # Only pass those variables which will be used for decoding
                ds = ds.map(
                    lambda x:
                    {k: v
                     for k, v in x.items() if k in _INPUT_FEATURES_ll})
                combined_ds = ds if not combined_ds else combined_ds.concatenate(
                    ds)
        combined_ds = combined_ds.batch(batch_size, drop_remainder=False)
        # Pad the final batch.
        combined_ds = transformer_dataset.trim_and_pad_dataset(
            combined_ds, length=batch_size)
        combined_ds = combined_ds.prefetch(tf.data.experimental.AUTOTUNE)
        return combined_ds

    checkpoint_paths = get_checkpoint_iterator(eval_checkpoint_step, model_dir)
    for checkpoint_path in checkpoint_paths:
        tf.logging.info("Checkpoint path %s" % checkpoint_path)
        global_step = int(get_step_from_checkpoint_path(checkpoint_path))
        if global_step == 0:
            continue
        decodes = decode(estimator, input_fn, vocabulary, checkpoint_path)
        for eval_dataset in eval_datasets:
            # Extract the portion of decodes corresponding to this dataset
            examples = cached_examples[eval_dataset.name]
            dataset_size = len(examples)
            predictions = [
                eval_dataset.postprocess_fn(tf.compat.as_text(d), example=ex)
                for d, ex in zip(decodes[:dataset_size], examples)
            ]
            # Remove the used decodes.
            del decodes[:dataset_size]

            global_step = int(get_step_from_checkpoint_path(checkpoint_path))

            predictions_filename = os.path.join(
                eval_summary_dir,
                "{}_{}_predictions".format(eval_dataset.name, global_step),
            )
            write_lines_to_file_ll(predictions, predictions_filename)

            for metric_fn in eval_dataset.metric_fns:
                summary = tf.Summary()
                targets = cached_targets[eval_dataset.name]
                if unsupervised_attribute_transfer_metrics and attribute_bit:
                    attributes_origin = cached_attributes_origin[
                        eval_dataset.name]
                    metric_result = metric_fn(
                        targets,
                        predictions,
                        attributes_origin=attributes_origin)
                else:
                    metric_result = metric_fn(targets, predictions)
                for metric_name, metric_value in metric_result.items():
                    tag = "eval/{}/{}".format(eval_dataset.name, metric_name)
                    tf.logging.info("%s at step %d: %.3f", tag, global_step,
                                    metric_value)
                    summary.value.add(tag=tag, simple_value=metric_value)
                    summary_writer.add_summary(summary, global_step)
            summary_writer.flush()

        # Only padding should remain.
        expected_pad = -sum(len(t)
                            for t in cached_targets.values()) % batch_size
        if len(decodes) != expected_pad:
            raise ValueError("{} padded decodes, {} expected.".format(
                len(decodes), expected_pad))
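
A quick arithmetic check (illustrative, not from the source) of the final
padding assertion: with 10 cached targets and a global batch size of 8, the
concatenated dataset is padded up to 16 entries, so exactly 6 padded decodes
should remain once every real example has been consumed.

batch_size = 8
total_targets = 10
expected_pad = -total_targets % batch_size  # == (8 - 10 % 8) % 8
print(expected_pad)  # 6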
Code example #6
def mesh_eval_dataset_fn(mixture_or_task_name,
                         sequence_length,
                         vocabulary,
                         dataset_split,
                         num_eval_examples=None,
                         use_cached=False,
                         pack=False):
    """Returns all tf.data.Datasets for evaluation on a given mixture.

  This uses the format required for utils.run's `eval_dataset_fn` argument in
  the Mesh TF transformer standalone.

  Args:
    mixture_or_task_name: string, an identifier for a Mixture or Task in the
      appropriate registry. Must be specified via gin.
    sequence_length: dict mapping each feature key to its maximum sequence
      length (an int).
    vocabulary: a t5.data.vocabularies.Vocabulary.
    dataset_split: string, which split of the dataset to load.
    num_eval_examples: maximum number of examples per task to use for continuous
      eval. If None, use all examples.
    use_cached: bool, whether to load the cached version of this dataset.
    pack: a boolean, whether to pack examples. This is useful for perplexity
      evals but should not be used for iterative decoding.

  Returns:
    A list of mesh_tensorflow.transformer.dataset.EvalDataset tuples.
  """
    valid_vocabulary(vocabulary)

    mixture_or_task = t5.data.get_mixture_or_task(mixture_or_task_name)

    def _get_dataset_for_single_task(task):
        """Get a tensorflow.data.Dataset for the provided task."""
        ds = task.get_dataset(sequence_length,
                              split=dataset_split,
                              use_cached=use_cached,
                              shuffle=False)
        if any(not f.add_eos for f in task.output_features.values()):
            warnings.warn(
                "pack_or_pad is being called with ensure_eos=True, but EOS is not "
                "being added to all features.")
        ds = transformer_dataset.pack_or_pad(ds,
                                             sequence_length,
                                             pack=pack,
                                             feature_keys=tuple(
                                                 task.output_features),
                                             ensure_eos=True)
        if num_eval_examples is not None:
            ds = ds.take(num_eval_examples)
        return ds

    outputs = []

    for task in t5.data.get_subtasks(mixture_or_task):
        if dataset_split not in task.splits:
            logging.info("Task %s has no '%s' split, skipping eval.",
                         task.name, dataset_split)
            continue

        outputs.append(
            transformer_dataset.EvalDataset(
                task.name,
                functools.partial(_get_dataset_for_single_task, task),
                task.postprocess_fn,
                task.metric_fns,
            ))

    return outputs
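
A toy illustration (assumed, not from the source) of the EOS consistency check
above: the warning fires whenever any output feature opts out of EOS, because
pack_or_pad(..., ensure_eos=True) appends EOS to every feature regardless.

import collections
import warnings

Feature = collections.namedtuple("Feature", ["add_eos"])
output_features = {"inputs": Feature(add_eos=True),
                   "targets": Feature(add_eos=False)}  # triggers the warning

if any(not f.add_eos for f in output_features.values()):
    warnings.warn("pack_or_pad is being called with ensure_eos=True, but EOS "
                  "is not being added to all features.")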
Code example #7
File: utils.py Project: masak1112/mesh
def run(tpu_job_name,
        tpu,
        gcp_project,
        tpu_zone,
        model_dir,
        model_type="bitransformer",
        vocabulary=gin.REQUIRED,
        train_dataset_fn=None,
        eval_dataset_fn=None,
        dataset_split="train",
        autostack=True,
        checkpoint_step=None,
        mode="train",
        iterations_per_loop=100,
        save_checkpoints_steps=1000,
        keep_checkpoint_max=10,
        eval_summary_dir=None,
        batch_size=("tokens_per_replica", 2048),
        train_steps=auto_train_steps,
        sequence_length=gin.REQUIRED,
        mesh_shape=gin.REQUIRED,
        layout_rules=gin.REQUIRED,
        learning_rate_schedule=None,
        optimizer=None,
        predict_fn=None):
  """Run training/eval/inference.

  Args:
    tpu_job_name: string, name of TPU worker binary
    tpu: string, the Cloud TPU to use for training
    gcp_project: string, project name for the Cloud TPU-enabled project
    tpu_zone: string, GCE zone where the Cloud TPU is located in
    model_dir: string, estimator model_dir
    model_type: a string - either "bitransformer", "bi_student_teacher", "lm",
      or "aligned"
    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
      targets_vocabulary) tuple.
    train_dataset_fn: A function returning a tf.data.Dataset. Must be provided
      for mode="train". Should accept the following arguments:
        - batch_size: int, number of entries in each batch.
        - sequence_length: int, length of each packed or padded sequence.
        - vocabulary: Vocabulary instance to use for encoding.
        - dataset_split: str, which dataset split to load.
    eval_dataset_fn: A function returning a list of dataset.EvalDataset tuples.
      Must be provided for mode="eval". Should accept the following arguments:
        - batch_size: int, number of entries in each batch.
        - sequence_length: int, length of each packed or padded sequence.
        - vocabulary: Vocabulary instance to use for encoding.
        - dataset_split: str, which dataset split to load.
      dataset.EvalDataset tuples are namedtuples with the following fields:
        - name: string, the task name
        - dataset_fn: function which returns a tf.data.Dataset of tokenized and
          padded examples. Must not require any arguments and must include the
          feature keys 'inputs' and 'targets_plaintext'.
        - postprocess_fn: function which converts model outputs to evaluable strings
        - list_of_metric_fns: list of metric functions with the call signature
          `metric_fn(targets, predictions)` which return either a scalar value
          or a dict mapping submetric names to scalar values. TensorBoard
          summaries and other tags will be written out using
          `metric_fn.__name__`.
        - dataset_size: number of entries in the dataset.
        - padded_dataset_size: number of entries in the dataset after padding.
    dataset_split: a string
    autostack: boolean, internally combine variables
    checkpoint_step: int, list of ints, or None. Only used when mode="eval" or
      mode="infer". If an int or list of ints, evaluation or inference will be
      run on the checkpoint files in `model_dir` whose global steps are closest
      to the global steps provided. If None and mode="eval", run eval
      continuously waiting for new checkpoints via
      `tf.contrib.training.checkpoints_iterator`.
    mode: string, train/eval/infer
    iterations_per_loop: integer, steps per train loop
    save_checkpoints_steps: integer, steps per checkpoint
    keep_checkpoint_max: an integer, keep up to this many checkpoints
    eval_summary_dir: str, path to write TensorBoard events file summaries for
      eval. If None, use model_dir/eval_{split}.
    batch_size: An integer or a (method, value) pair to pass to
      compute_batch_size(). Note that this is the global batch size and not the
      per-shard batch size.
    train_steps: An integer or a function with the same signature as
      auto_train_steps().  Total number of training steps.
    sequence_length: an integer
    mesh_shape: an input to mtf.convert_to_shape()
    layout_rules: an input to mtf.convert_to_layout_rules()
    learning_rate_schedule: an optional function taking the scalar named
      argument `step` and the numeric argument `total_train_steps`, returning
      the scalar learning rate
    optimizer: a class extending optimize.Optimizer, required for training
    predict_fn: an optional function that can be used to override the default
      transformer prediction behavior. Must return a tensor of shape [batch_dim,
      length_dim] that will be the prediction for each example. Must accept the
      following arguments:
        - model: a Unitransformer or Bitransformer
        - features: a dict representing an example. Every value will be an
          mtf.Tensor with shape [batch_dim, length_dim].
        - variable_dtype: an mtf.VariableDType
  """
  if not isinstance(batch_size, int):
    batch_size = compute_batch_size(
        sequence_length, mesh_shape, layout_rules, batch_size)

  if not isinstance(train_steps, int):
    train_steps = train_steps(batch_size, sequence_length)

  if callable(learning_rate_schedule):
    learning_rate_schedule = functools.partial(
        learning_rate_schedule, total_train_steps=train_steps)

  tf.logging.info("model_type=%s", model_type)
  tf.logging.info("mode=%s", mode)
  tf.logging.info("sequence_length=%s", sequence_length)
  tf.logging.info("batch_size=%s", batch_size)
  tf.logging.info("train_steps=%s", train_steps)
  tf.logging.info("mesh_shape=%s", mesh_shape)
  tf.logging.info("layout_rules=%s", layout_rules)

  if mode == "train" and dataset_split != "train":
    raise ValueError("mode==\"train\" requires dataset_split==\"train\"")

  mesh_shape = mtf.convert_to_shape(mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(layout_rules)

  cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
      tpu if (tpu) else "", zone=tpu_zone, project=gcp_project)

  tf.logging.info(
      "Building TPUConfig with tpu_job_name={}".format(tpu_job_name)
  )
  my_tpu_config = tpu_config.TPUConfig(
      tpu_job_name=tpu_job_name,
      iterations_per_loop=iterations_per_loop,
      num_cores_per_replica=1,
      per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST,
  )

  run_config = tpu_config.RunConfig(
      cluster=cluster,
      model_dir=model_dir,
      tpu_config=my_tpu_config,
      # We use a saver hook, so disable checkpoints here to prevent double
      # saving.
      save_checkpoints_steps=None,
      save_checkpoints_secs=None)

  transformer_model = build_model(
      model_type=model_type,
      input_vocab_size=inputs_vocabulary(vocabulary).vocab_size,
      output_vocab_size=targets_vocabulary(vocabulary).vocab_size,
      layout_rules=layout_rules,
      mesh_shape=mesh_shape)

  model_fn = tpu_estimator_model_fn(
      model_type=model_type,
      transformer_model=transformer_model,
      model_dir=model_dir,
      use_tpu=tpu,
      mesh_shape=mesh_shape,
      layout_rules=layout_rules,
      batch_size=batch_size,
      sequence_length=sequence_length,
      autostack=autostack,
      learning_rate_schedule=learning_rate_schedule,
      keep_checkpoint_max=keep_checkpoint_max,
      save_checkpoints_steps=save_checkpoints_steps,
      optimizer=optimizer,
      predict_fn=predict_fn)

  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      config=run_config,
      train_batch_size=batch_size,
      eval_batch_size=batch_size,
      predict_batch_size=batch_size,
      use_tpu=tpu,
      export_to_tpu=False,
      params={})

  if mode == "train":
    if train_dataset_fn is None:
      raise ValueError("Must provide train_dataset_fn through gin for train.")
    def input_fn(params):
      del params
      dataset = train_dataset_fn(batch_size=batch_size,
                                 sequence_length=sequence_length,
                                 vocabulary=vocabulary,
                                 dataset_split=dataset_split)
      return dataset

    estimator.train(input_fn=input_fn, max_steps=train_steps)

  elif mode == "eval":
    if eval_dataset_fn is None:
      raise ValueError("Must provide eval_dataset_fn through gin for eval.")

    eval_datasets = eval_dataset_fn(
        batch_size=batch_size,
        sequence_length=sequence_length,
        vocabulary=vocabulary,
        dataset_split=dataset_split,
    )

    # Pre-load in all of the targets once before entering continuous eval loop
    cached_targets = {}
    # Need to create a separate graph for loading in plaintext targets
    # or else TF will complain that we modified the graph
    with tf.Graph().as_default():
      for eval_dataset in eval_datasets:
        eval_dataset = transformer_dataset.EvalDataset(*eval_dataset)
        # Only cache targets for tasks that provide metric functions
        if eval_dataset.metric_fns:
          ds = eval_dataset.dataset_fn()
          # De-batch the dataset
          ds = ds.flat_map(tf.data.Dataset.from_tensor_slices)
          ds = tfds.as_numpy(ds)
          targets = [
              eval_dataset.postprocess_fn(d["targets_plaintext"]) for d in ds
          ]
          targets = targets[:eval_dataset.dataset_size]
          cached_targets[eval_dataset.name] = targets

    for checkpoint_path in get_checkpoint_iterator(checkpoint_step, model_dir):
      for eval_dataset in eval_datasets:
        eval_dataset = transformer_dataset.EvalDataset(*eval_dataset)
        if not eval_dataset.metric_fns:
          tf.logging.info(
              "Skipping %s because metric_fns is empty", eval_dataset.name
          )
          continue
        metric_names = [metric.__name__ for metric in eval_dataset.metric_fns]
        tf.logging.info(
            "Evaluating %s on metrics %s", eval_dataset.name, metric_names
        )
        tf.logging.info("on split %s", dataset_split)

        def input_fn(params):
          del params
          ds = eval_dataset.dataset_fn()
          # Only pass those variables which will be used for decoding
          ds = ds.map(
              lambda x: {k: v for k, v in x.items() if k in _INPUT_FEATURES}
          )
          return ds

        decodes = decode(
            estimator,
            input_fn,
            eval_dataset.dataset_size,
            eval_dataset.padded_dataset_size,
            batch_size,
            vocabulary,
            checkpoint_path=checkpoint_path,
        )
        predictions = [eval_dataset.postprocess_fn(d) for d in decodes]
        # TODO(craffel): Log predictions and targets.

        eval_summary_dir = eval_summary_dir or os.path.join(
            model_dir, "{}_eval".format(dataset_split)
        )
        summary_writer = tf.summary.FileWriter(eval_summary_dir)
        global_step = int(get_step_from_checkpoint_path(checkpoint_path))
        for metric_fn in eval_dataset.metric_fns:
          summary = tf.Summary()
          tag = "eval/{}/{}/{}".format(
              eval_dataset.name, dataset_split, metric_fn.__name__
          )
          targets = cached_targets[eval_dataset.name]
          metric_result = metric_fn(targets, predictions)
          if isinstance(metric_result, dict):
            tags = ["{}.{}".format(tag, key) for key in metric_result]
            metric_values = metric_result.values()
          else:
            tags, metric_values = [tag], [metric_result]
          for tag, metric_value in zip(tags, metric_values):
            tf.logging.info(
                "%s at step %d: %.3f", tag, global_step, metric_value
            )
            summary.value.add(tag=tag, simple_value=metric_value)
            summary_writer.add_summary(summary, global_step)
        summary_writer.flush()

  elif mode == "infer":
    for checkpoint_path in get_checkpoint_iterator(checkpoint_step, model_dir):
      decode_from_file(
          estimator,
          vocabulary=vocabulary,
          model_type=model_type,
          batch_size=batch_size,
          sequence_length=sequence_length,
          checkpoint_path=checkpoint_path)
  else:
    raise ValueError(
        "unknown mode %s - must be train/eval/infer" % mode)