def __init__(
      self,
      model_dir,
      tpu,
      tpu_job_name=None,
      tpu_zone=None,
      gcp_project=None,
      tpu_topology="v2-8",
      model_parallelism=8,
      batch_size=("sequences_per_batch", 1),
      sequence_length=None,
      model_type="bitransformer",
      layout_rules="ensemble:ensemble,batch:batch,d_ff:model,heads:model,vocab:model,experts:batch",
      mesh_shape=None,
      mesh_devices=None,
      autostack=True,
      learning_rate_schedule=None,
      keep_checkpoint_max=None,
      save_checkpoints_steps=5000,
      optimizer=None,
      predict_fn=None,
      variable_filter=None,
      ensemble_inputs=None,
      iterations_per_loop=100):
    """Constructor for MtfModel class.

    Args:
      model_dir: str, directory to save the model.
      tpu: str, the TPU address to use.
      tpu_job_name: str, name of the TPU worker binary.
      tpu_zone: str, GCE zone where the Cloud TPU is located
      gcp_project: str, project name for the Cloud TPU-enabled project.
      tpu_topology: str, e.g. "2x2" or "v2-8".
      model_parallelism: integer, the number of cores per model replica.
      batch_size: An integer or a (method, value) pair to pass to
        compute_batch_size(). Note that this is the global batch size and not
        the per-shard batch size.
      sequence_length: an integer or a dict from feature-key to integer giving
        the (packed) sequence length, e.g. {"inputs": 512, "targets": 128}.
      model_type: str, a model type from mesh tf models.
      layout_rules: an input to mtf.convert_to_layout_rules()
      mesh_shape: an mtf.Shape or string (e.g., "model:2,batch:4") specifying
        how the data/model should be parallelized. If None (default), the mesh
        shape will be constructed using the supplied `tpu_topology` and
        `model_parallelism` arguments.
      mesh_devices: a list of strings, the device names to use for each mesh
        slice. Only required for GPU.
      autostack: boolean, internally combine variables.
      learning_rate_schedule: an optional function taking the scalar name
        argument `step` and the numeric argument `total_train_steps` and
        returning the scalar learning rate.
      keep_checkpoint_max: an integer, maximum number of checkpoints to keep.
      save_checkpoints_steps: an integer, steps per checkpoint.
      optimizer: a class extending optimize.Optimizer, required for training.
      predict_fn: an optional function that can be used to override the default
        transformer prediction behavior. Must return a tensor of shape
        [batch_dim, length_dim] that will be the prediction for each example.
        Must accept the following arguments:
          - model: a Unitransformer or Bitransformer
          - features: a dict representing an example. Every value will be an
            mtf.Tensor with shape [batch_dim, length_dim].
          - variable_dtype: an mtf.VariableDType
      variable_filter: a str, a variable will only be trained if its name
        matches this regex. If None (default), train all trainable variables.
      ensemble_inputs: an integer, see `train_model` docstring for details.
      iterations_per_loop: integer, steps per train loop
    """
    mesh_shape = mesh_shape or (
        utils.tpu_mesh_shape(tpu_topology, model_parallelism) if tpu else "")

    sequence_length = sequence_length or {"inputs": 512, "targets": 512}

    if isinstance(sequence_length, int):
      sequence_length = {"inputs": sequence_length,
                         "targets": sequence_length}
    self._learning_rate_schedule = (
        learning_rate_schedule or
        learning_rate_schedules.learning_rate_schedule_noam)

    self._optimizer = optimizer or optimize.AdafactorOptimizer

    self._sequence_length = sequence_length
    self._model_dir = model_dir
    self._model_type = model_type
    self._ensemble_inputs = ensemble_inputs

    self._layout_rules = mtf.convert_to_layout_rules(layout_rules)
    self._mesh_shape = mtf.convert_to_shape(mesh_shape)
    self._mesh_devices = mesh_devices

    self._autostack = autostack
    self._keep_checkpoint_max = keep_checkpoint_max
    self._save_checkpoints_steps = save_checkpoints_steps
    self._predict_fn = predict_fn
    self._variable_filter = variable_filter
    self._iterations_per_loop = iterations_per_loop

    self._cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu, zone=tpu_zone, project=gcp_project) if tpu else None
    self._tpu = tpu
    self._tpu_job_name = tpu_job_name
    self._estimator = None

    # Must be called after _sequence_length, _mesh_shape, and _layout_rules are
    # set.
    self.batch_size = batch_size
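
# ---------------------------------------------------------------------------
# Usage sketch for the constructor above (not part of the original source).
# It assumes the class is exposed as t5.models.MtfModel; model_dir is a
# placeholder, and constant_schedule merely illustrates the
# (step, total_train_steps) signature documented in the docstring.
from t5.models import MtfModel


def constant_schedule(step, total_train_steps):
  """Hypothetical schedule: ignores its arguments and returns a fixed rate."""
  del step, total_train_steps  # unused in this trivial example
  return 1e-3


model = MtfModel(
    model_dir="/tmp/t5_demo_model",                   # placeholder directory
    tpu=None,                                         # no TPU: mesh_shape defaults to the empty string
    batch_size=("tokens_per_batch", 65536),           # (method, value) pair passed to compute_batch_size()
    sequence_length={"inputs": 512, "targets": 128},  # per-feature packed lengths
    learning_rate_schedule=constant_schedule,
)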

# Example 2

    def __init__(
            self,
            model_dir,
            tpu,
            tpu_job_name=None,
            tpu_zone=None,
            gcp_project=None,
            tpu_topology="2x2",
            model_parallelism=8,
            batch_size=("tokens_per_batch", 1024),
            sequence_length=None,
            vocabulary=None,
            model_type="bitransformer",
            layout_rules="ensemble:ensemble,batch:batch,d_ff:model,heads:model,vocab:model,experts:batch",
            autostack=True,
            learning_rate_schedule=None,
            keep_checkpoint_max=None,
            save_checkpoints_steps=5000,
            optimizer=None,
            predict_fn=None,
            variable_filter=None,
            ensemble_inputs=None,
            iterations_per_loop=100,
            init_checkpoint=None):
        """Constructor for MtfModel class.

    Args:
      model_dir: string, directory to save the model.
      tpu: string, the TPU address to use.
      tpu_job_name: string, name of the TPU worker binary.
      tpu_zone: string, GCE zone where the Cloud TPU is located
      gcp_project: string, project name for the Cloud TPU-enabled project.
      tpu_topology: string, e.g. "2x2".
      model_parallelism: integer, the number of cores per model replica.
      batch_size: An integer or a (method, value) pair to pass to
        compute_batch_size(). Note that this is the global batch size and not
        the per-shard batch size.
      sequence_length: an integer or a dict from feature-key to integer giving
        the (packed) sequence length, e.g. {"inputs": 512, "targets": 128}.
      vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
        targets_vocabulary) tuple.
      model_type: string, a model type from mesh tf models.
      layout_rules: an input to mtf.convert_to_layout_rules()
      autostack: boolean, internally combine variables.
      learning_rate_schedule: an optional function taking the scalar name
        argument `step` and the numeric argument `total_train_steps` and
        returning the scalar learning rate.
      keep_checkpoint_max: an integer, maximum number of checkpoints to keep.
      save_checkpoints_steps: an integer, steps per checkpoint.
      optimizer: a class extending optimize.Optimizer, required for training.
      predict_fn: an optional function that can be used to override the default
        transformer prediction behavior. Must return a tensor of shape
        [batch_dim, length_dim] that will be the prediction for each example.
        Must accept the following arguments:
          - model: a Unitransformer or Bitransformer
          - features: a dict representing an example. Every value will be an
            mtf.Tensor with shape [batch_dim, length_dim].
          - variable_dtype: an mtf.VariableDType
      variable_filter: a string, a variable will only be trained if its name
        matches this regex. If None (default), train all trainable variables.
      ensemble_inputs: an integer, see `train_model` docstring for details.
      iterations_per_loop: integer, steps per train loop
      init_checkpoint: a string; if not None, read in variables from this
        checkpoint path when initializing variables.
    """

        mesh_shape = utils.tpu_mesh_shape(tpu_topology, model_parallelism)
        vocabulary = vocabulary or SentencePieceVocabulary()

        sequence_length = sequence_length or {"inputs": 512, "targets": 512}

        if isinstance(sequence_length, int):
            sequence_length = {
                "inputs": sequence_length,
                "targets": sequence_length
            }

        if not isinstance(batch_size, int):
            self._batch_size = utils.compute_batch_size(
                sequence_length, mesh_shape, layout_rules, batch_size)
        else:
            self._batch_size = batch_size

        learning_rate_schedule = (
            learning_rate_schedule
            or learning_rate_schedules.learning_rate_schedule_noam)

        optimizer = optimizer or optimize.AdafactorOptimizer

        self._sequence_length = sequence_length
        self._vocabulary = vocabulary
        self._model_dir = model_dir
        self._init_checkpoint = init_checkpoint
        self._model_type = model_type
        self._ensemble_inputs = ensemble_inputs

        cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu if tpu else "", zone=tpu_zone, project=gcp_project)

        self._estimator = utils.get_estimator(
            model_type=model_type,
            input_vocab_size=utils.inputs_vocabulary(vocabulary).vocab_size,
            output_vocab_size=utils.targets_vocabulary(vocabulary).vocab_size,
            layout_rules=mtf.convert_to_layout_rules(layout_rules),
            mesh_shape=mtf.convert_to_shape(mesh_shape),
            model_dir=model_dir,
            batch_size=self._batch_size,
            sequence_length=sequence_length,
            autostack=autostack,
            learning_rate_schedule=learning_rate_schedule,
            keep_checkpoint_max=keep_checkpoint_max,
            save_checkpoints_steps=save_checkpoints_steps,
            optimizer=optimizer,
            predict_fn=predict_fn,
            variable_filter=variable_filter,
            ensemble_inputs=ensemble_inputs,
            use_tpu=tpu,
            tpu_job_name=tpu_job_name,
            iterations_per_loop=iterations_per_loop,
            cluster=cluster,
            init_checkpoint=init_checkpoint)
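
# ---------------------------------------------------------------------------
# Sketch of a predict_fn matching the signature documented above:
# (model, features, variable_dtype) -> mtf.Tensor of shape
# [batch_dim, length_dim]. This is not from the original source; the
# model.decode() call is an assumption about the mesh_tensorflow
# Bitransformer API, so verify the exact arguments against the installed
# mesh_tensorflow version.
def greedy_predict_fn(model, features, variable_dtype):
  # Delegate to the model's own decoding, feeding it the encoder inputs.
  return model.decode(features["inputs"], variable_dtype=variable_dtype)

# Either constructor variant would receive it as predict_fn=greedy_predict_fn.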