Ejemplo n.º 1
0
    def train(self,
              mixture_or_task_name,
              steps,
              init_checkpoint=None,
              split="train"):
        """Train on a registered Mixture or Task.

    Args:
      mixture_or_task_name: str, name of the Mixture or Task to train on;
        must already be registered in the global `TaskRegistry` or
        `MixtureRegistry`.
      steps: int, total number of training steps to run.
      init_checkpoint: optional str; when given, variables appearing in both
        the current graph and this checkpoint path are initialized from it.
      split: str, which mixture/task split to draw training data from.
    """
        # Vocabulary is resolved from the mixture/task registry by name.
        vocabulary = t5.models.mesh_transformer.get_vocabulary(
            mixture_or_task_name)
        # Bind the task name now; the trainer supplies the remaining args.
        dataset_fn = functools.partial(
            t5.models.mesh_transformer.mesh_train_dataset_fn,
            mixture_or_task_name=mixture_or_task_name,
        )
        estimator = self.estimator(vocabulary, init_checkpoint)
        mtf_utils.train_model(
            estimator,
            vocabulary,
            self._sequence_length,
            self.batch_size,
            dataset_fn,
            steps,
            self._ensemble_inputs,
            dataset_split=split)
Ejemplo n.º 2
0
    def train(self,
              mixture_or_task_name,
              steps,
              init_checkpoint=None,
              split="train"):
        """Train the model on the given Mixture or Task.

        Args:
          mixture_or_task_name: str, the name of the Mixture or Task to train
            on. Must be pre-registered in the global `TaskRegistry` or
            `MixtureRegistry`.
          steps: int, the total number of steps to train for.
          init_checkpoint: a string, if not None then read in variables from
            this checkpoint path when initializing variables. Will only
            initialize variables that appear both in the current graph and the
            checkpoint.
          split: str, the mixture/task split to train on.
        """
        vocabulary = get_mixture_or_task_ll(
            mixture_or_task_name).get_vocabulary()
        dataset_fn = functools.partial(
            mesh_train_dataset_fn_ll,
            mixture_or_task_name=mixture_or_task_name,
            batch_size=self.batch_size,
            ensemble_inputs=self._ensemble_inputs,
            group_by_attribute=self.group_by_attribute)

        # NOTE: When fine-tuning, the gin config of the pre-trained model is
        # loaded first, so gin parameters set here may carry different values
        # than those in the pre-trained config (e.g.
        # t5.data.preprocessors.unsupervised.preprocessors).

        # Grouped training needs the custom loop (train_model_ll); otherwise
        # use the stock trainer. Both take the same arguments, so pick the
        # function once instead of duplicating the call in each branch.
        train_fn = (train_model_ll if self.group_by_attribute
                    else utils.train_model)
        train_fn(self.estimator(vocabulary, init_checkpoint),
                 vocabulary,
                 self._sequence_length,
                 self.batch_size,
                 dataset_fn,
                 steps,
                 self._ensemble_inputs,
                 dataset_split=split)
Ejemplo n.º 3
0
 def train(self, mixture_or_task_name, steps):
     """Train on the named Mixture or Task for `steps` training steps."""
     # Bind the task name; the trainer fills in the remaining dataset args.
     make_dataset = functools.partial(
         mesh_train_dataset_fn,
         mixture_or_task_name=mixture_or_task_name)
     utils.train_model(self._estimator,
                       self._vocabulary,
                       self._sequence_length,
                       self._batch_size,
                       make_dataset,
                       steps,
                       self._ensemble_inputs)