Example #1
0
 def estimator(self, vocabulary, init_checkpoint=None):
     """Build an estimator from the configuration stored on this object.

     Args:
       vocabulary: passed to utils.inputs_vocabulary() and
         utils.targets_vocabulary(); the `vocab_size` of each result is
         forwarded as the input / output vocabulary size.
       init_checkpoint: forwarded unchanged to utils.get_estimator()
         (None by default).

     Returns:
       The value returned by utils.get_estimator() for these settings.
     """
     build_estimator = utils.get_estimator

     # Everything except the vocabulary sizes and the checkpoint comes
     # straight from attributes of this instance.
     settings = {
         "model_type": self._model_type,
         "input_vocab_size": utils.inputs_vocabulary(vocabulary).vocab_size,
         "output_vocab_size": utils.targets_vocabulary(vocabulary).vocab_size,
         "layout_rules": self._layout_rules,
         "mesh_shape": self._mesh_shape,
         "model_dir": self._model_dir,
         "batch_size": self.batch_size,
         "sequence_length": self._sequence_length,
         "autostack": self._autostack,
         "learning_rate_schedule": self._learning_rate_schedule,
         "keep_checkpoint_max": self._keep_checkpoint_max,
         "save_checkpoints_steps": self._save_checkpoints_steps,
         "optimizer": self._optimizer,
         "predict_fn": self._predict_fn,
         "variable_filter": self._variable_filter,
         "ensemble_inputs": self._ensemble_inputs,
         "use_tpu": self._tpu,
         "tpu_job_name": self._tpu_job_name,
         "iterations_per_loop": self._iterations_per_loop,
         "cluster": self._cluster,
         "init_checkpoint": init_checkpoint,
     }
     return build_estimator(**settings)
Example #2
0
    def __init__(
            self,
            model_dir,
            tpu,
            tpu_job_name,
            tpu_zone,
            gcp_project,
            batch_size=("tokens_per_batch", 1024),
            sequence_length=None,
            vocabulary=None,
            model_type="bitransformer",
            layout_rules="ensemble:ensemble,batch:batch,d_ff:model,heads:model,vocab:model,experts:batch",
            mesh_shape=None,
            autostack=True,
            learning_rate_schedule=None,
            keep_checkpoint_max=None,
            save_checkpoints_steps=5000,
            optimizer=None,
            predict_fn=None,
            variable_filter=None,
            ensemble_inputs=None,
            iterations_per_loop=100,
            init_checkpoint=None):
        """Constructor for MtfModel class.

        Fills in defaults for the optional arguments, resolves the batch
        size, creates a TPU cluster resolver and builds the estimator, which
        is stored on the instance.

        Args:
          model_dir: a string, directory to save the model.
          tpu: string, the Cloud TPU to use for training. A falsy value is
            handed to the cluster resolver as the empty string.
          tpu_job_name: string, name of the TPU worker binary.
          tpu_zone: string, GCE zone where the Cloud TPU is located.
          gcp_project: string, project name for the Cloud TPU-enabled
            project.
          batch_size: an integer, used as is, or a (method, value) pair to
            pass to compute_batch_size(). Note that this is the global batch
            size and not the per-shard batch size.
          sequence_length: an integer or a dict from feature-key to integer
            the (packed) sequence length, e.g. {"inputs": 512,
            "targets": 128}. An integer is applied to both "inputs" and
            "targets"; a falsy value selects
            {"inputs": 512, "targets": 512}.
          vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
            targets_vocabulary) tuple. A falsy value selects
            SentencePieceVocabulary().
          model_type: string, a model type from mesh tf models.
          layout_rules: an input to mtf.convert_to_layout_rules().
          mesh_shape: an input to mtf.convert_to_shape(). A falsy value
            selects an empty list.
          autostack: boolean, internally combine variables.
          learning_rate_schedule: an optional function taking the scalar name
            argument `step` and the numeric argument `total_train_steps` and
            return the scalar learning rate. A falsy value selects
            learning_rate_schedules.learning_rate_schedule_noam.
          keep_checkpoint_max: an integer, maximum number of checkpoints to
            keep.
          save_checkpoints_steps: an integer, steps per checkpoint.
          optimizer: a class extending optimize.Optimizer, required for
            training. A falsy value selects optimize.AdafactorOptimizer.
          predict_fn: an optional function that can be used to override the
            default transformer prediction behavior. Must return a tensor of
            shape [batch_dim, length_dim] that will be the prediction for
            each example. Must accept the following arguments:
              - model: a Unitransformer or Bitransformer
              - features: a dict representing an example. Every value will
                be an mtf.Tensor with shape [batch_dim, length_dim].
              - variable_dtype: an mtf.VariableDType
          variable_filter: a string, a variable will only be trained if its
            name matches this regex. If None (default), train all trainable
            variables.
          ensemble_inputs: an integer, see `train_model` docstring for
            details.
          iterations_per_loop: integer, steps per train loop.
          init_checkpoint: a string, if not None then read in variables from
            this checkpoint path when initializing variables.
        """
        # Substitute defaults for optional arguments that were left unset.
        if not mesh_shape:
            mesh_shape = []
        if not vocabulary:
            vocabulary = SentencePieceVocabulary()

        # Normalize the sequence length to a per-feature dict.
        if not sequence_length:
            sequence_length = {"inputs": 512, "targets": 512}
        elif isinstance(sequence_length, int):
            sequence_length = dict.fromkeys(
                ("inputs", "targets"), sequence_length)

        # An integer batch size is final; anything else is a specification
        # that compute_batch_size() resolves.
        if isinstance(batch_size, int):
            self._batch_size = batch_size
        else:
            self._batch_size = utils.compute_batch_size(
                sequence_length, mesh_shape, layout_rules, batch_size)

        if not learning_rate_schedule:
            learning_rate_schedule = (
                learning_rate_schedules.learning_rate_schedule_noam)
        if not optimizer:
            optimizer = optimize.AdafactorOptimizer

        self._sequence_length = sequence_length
        self._vocabulary = vocabulary
        self._model_dir = model_dir
        self._init_checkpoint = init_checkpoint
        self._model_type = model_type
        self._ensemble_inputs = ensemble_inputs

        tpu_cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu or "", zone=tpu_zone, project=gcp_project)

        # Derive the estimator inputs that need conversion before the call.
        build_estimator = utils.get_estimator
        inputs_size = utils.inputs_vocabulary(vocabulary).vocab_size
        targets_size = utils.targets_vocabulary(vocabulary).vocab_size
        converted_rules = mtf.convert_to_layout_rules(layout_rules)
        converted_shape = mtf.convert_to_shape(mesh_shape)

        self._estimator = build_estimator(
            model_type=model_type,
            input_vocab_size=inputs_size,
            output_vocab_size=targets_size,
            layout_rules=converted_rules,
            mesh_shape=converted_shape,
            model_dir=model_dir,
            batch_size=self._batch_size,
            sequence_length=sequence_length,
            autostack=autostack,
            learning_rate_schedule=learning_rate_schedule,
            keep_checkpoint_max=keep_checkpoint_max,
            save_checkpoints_steps=save_checkpoints_steps,
            optimizer=optimizer,
            predict_fn=predict_fn,
            variable_filter=variable_filter,
            ensemble_inputs=ensemble_inputs,
            use_tpu=tpu,
            tpu_job_name=tpu_job_name,
            iterations_per_loop=iterations_per_loop,
            cluster=tpu_cluster,
            init_checkpoint=init_checkpoint)