def __init__(
    self,
    model_dir,
    tpu,
    tpu_job_name=None,
    tpu_zone=None,
    gcp_project=None,
    tpu_topology="v2-8",
    model_parallelism=8,
    batch_size=("sequences_per_batch", 1),
    sequence_length=None,
    model_type="bitransformer",
    layout_rules="ensemble:ensemble,batch:batch,d_ff:model,heads:model,"
                 "vocab:model,experts:batch",
    mesh_shape=None,
    mesh_devices=None,
    autostack=True,
    learning_rate_schedule=None,
    keep_checkpoint_max=None,
    save_checkpoints_steps=5000,
    optimizer=None,
    predict_fn=None,
    variable_filter=None,
    ensemble_inputs=None,
    iterations_per_loop=100):
  """Constructor for MtfModel class.

  Args:
    model_dir: str, directory to save the model.
    tpu: str, the TPU address to use.
    tpu_job_name: str, name of the TPU worker binary.
    tpu_zone: str, GCE zone where the Cloud TPU is located.
    gcp_project: str, project name for the Cloud TPU-enabled project.
    tpu_topology: str, e.g. "2x2" or "v2-8".
    model_parallelism: integer, the number of cores per model replica.
    batch_size: an integer or a (method, value) pair to pass to
      compute_batch_size(). Note that this is the global batch size and not
      the per-shard batch size.
    sequence_length: an integer or a dict from feature-key to integer, the
      (packed) sequence length, e.g. {"inputs": 512, "targets": 128}.
    model_type: str, a model type from mesh tf models.
    layout_rules: an input to mtf.convert_to_layout_rules().
    mesh_shape: an mtf.Shape or string (e.g., "model:2,batch:4") specifying
      how the data/model should be parallelized. If None (default), the mesh
      shape is constructed from the supplied `tpu_topology` and
      `model_parallelism` arguments.
    mesh_devices: a list of strings, the device names to use for each mesh
      slice. Only required for GPU.
    autostack: boolean, whether to internally combine variables.
    learning_rate_schedule: an optional function taking the scalar argument
      `step` and the numeric argument `total_train_steps` and returning the
      scalar learning rate.
    keep_checkpoint_max: an integer, maximum number of checkpoints to keep.
    save_checkpoints_steps: an integer, steps per checkpoint.
    optimizer: a class extending optimize.Optimizer, required for training.
    predict_fn: an optional function that can be used to override the default
      transformer prediction behavior. Must return a tensor of shape
      [batch_dim, length_dim] that will be the prediction for each example.
      Must accept the following arguments:
        - model: a Unitransformer or Bitransformer
        - features: a dict representing an example. Every value will be an
          mtf.Tensor with shape [batch_dim, length_dim].
        - variable_dtype: an mtf.VariableDType
    variable_filter: a str, a variable will only be trained if its name
      matches this regex. If None (default), train all trainable variables.
    ensemble_inputs: an integer, see `train_model` docstring for details.
    iterations_per_loop: integer, steps per train loop.
  """
  mesh_shape = mesh_shape or (
      utils.tpu_mesh_shape(tpu_topology, model_parallelism) if tpu else "")

  sequence_length = sequence_length or {"inputs": 512, "targets": 512}
  if isinstance(sequence_length, int):
    sequence_length = {"inputs": sequence_length,
                       "targets": sequence_length}

  self._learning_rate_schedule = (
      learning_rate_schedule or
      learning_rate_schedules.learning_rate_schedule_noam)
  self._optimizer = optimizer or optimize.AdafactorOptimizer

  self._sequence_length = sequence_length
  self._model_dir = model_dir
  self._model_type = model_type
  self._ensemble_inputs = ensemble_inputs

  self._layout_rules = mtf.convert_to_layout_rules(layout_rules)
  self._mesh_shape = mtf.convert_to_shape(mesh_shape)
  self._mesh_devices = mesh_devices

  self._autostack = autostack
  self._keep_checkpoint_max = keep_checkpoint_max
  self._save_checkpoints_steps = save_checkpoints_steps
  self._predict_fn = predict_fn
  self._variable_filter = variable_filter
  self._iterations_per_loop = iterations_per_loop

  self._cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
      tpu, zone=tpu_zone, project=gcp_project) if tpu else None
  self._tpu = tpu
  self._tpu_job_name = tpu_job_name
  self._estimator = None

  # Must be called after _sequence_length, _mesh_shape, and _layout_rules
  # are set.
  self.batch_size = batch_size
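For illustration, a minimal sketch of how this constructor might be invoked (the class is MtfModel per the docstring; the model directory and TPU address below are placeholders, and the (method, value) batch_size form is the one forwarded to compute_batch_size()):

model = MtfModel(
    model_dir="gs://my-bucket/models/base",   # placeholder path
    tpu="grpc://10.0.0.2:8470",               # placeholder TPU address
    tpu_topology="v2-8",
    model_parallelism=8,
    batch_size=("tokens_per_batch", 65536),   # global, not per-shard
    sequence_length={"inputs": 512, "targets": 128})

Passing tpu=None instead runs without a TPU; in that case mesh_shape (e.g. "model:1,batch:1") and, on GPU, mesh_devices (e.g. ["gpu:0"]) would typically be supplied explicitly, since there is no TPU topology from which to derive them.

A second variant of this constructor follows; it additionally takes the vocabulary and an init_checkpoint, and constructs its estimator eagerly rather than deferring that work.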
def __init__(
    self,
    model_dir,
    tpu,
    tpu_job_name=None,
    tpu_zone=None,
    gcp_project=None,
    tpu_topology="2x2",
    model_parallelism=8,
    batch_size=("tokens_per_batch", 1024),
    sequence_length=None,
    vocabulary=None,
    model_type="bitransformer",
    layout_rules="ensemble:ensemble,batch:batch,d_ff:model,heads:model,"
                 "vocab:model,experts:batch",
    autostack=True,
    learning_rate_schedule=None,
    keep_checkpoint_max=None,
    save_checkpoints_steps=5000,
    optimizer=None,
    predict_fn=None,
    variable_filter=None,
    ensemble_inputs=None,
    iterations_per_loop=100,
    init_checkpoint=None):
  """Constructor for MtfModel class.

  Args:
    model_dir: string, directory to save the model.
    tpu: string, the TPU address to use.
    tpu_job_name: string, name of the TPU worker binary.
    tpu_zone: string, GCE zone where the Cloud TPU is located.
    gcp_project: string, project name for the Cloud TPU-enabled project.
    tpu_topology: string, e.g. "2x2".
    model_parallelism: integer, the number of cores per model replica.
    batch_size: an integer or a (method, value) pair to pass to
      compute_batch_size(). Note that this is the global batch size and not
      the per-shard batch size.
    sequence_length: an integer or a dict from feature-key to integer, the
      (packed) sequence length, e.g. {"inputs": 512, "targets": 128}.
    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
      targets_vocabulary) tuple.
    model_type: string, a model type from mesh tf models.
    layout_rules: an input to mtf.convert_to_layout_rules().
    autostack: boolean, whether to internally combine variables.
    learning_rate_schedule: an optional function taking the scalar argument
      `step` and the numeric argument `total_train_steps` and returning the
      scalar learning rate.
    keep_checkpoint_max: an integer, maximum number of checkpoints to keep.
    save_checkpoints_steps: an integer, steps per checkpoint.
    optimizer: a class extending optimize.Optimizer, required for training.
    predict_fn: an optional function that can be used to override the default
      transformer prediction behavior. Must return a tensor of shape
      [batch_dim, length_dim] that will be the prediction for each example.
      Must accept the following arguments:
        - model: a Unitransformer or Bitransformer
        - features: a dict representing an example. Every value will be an
          mtf.Tensor with shape [batch_dim, length_dim].
        - variable_dtype: an mtf.VariableDType
    variable_filter: a string, a variable will only be trained if its name
      matches this regex. If None (default), train all trainable variables.
    ensemble_inputs: an integer, see `train_model` docstring for details.
    iterations_per_loop: integer, steps per train loop.
    init_checkpoint: a string; if not None, read in variables from this
      checkpoint path when initializing variables.
  """
  mesh_shape = utils.tpu_mesh_shape(tpu_topology, model_parallelism)
  vocabulary = vocabulary or SentencePieceVocabulary()
  sequence_length = sequence_length or {"inputs": 512, "targets": 512}
  if isinstance(sequence_length, int):
    sequence_length = {"inputs": sequence_length,
                       "targets": sequence_length}
  if not isinstance(batch_size, int):
    self._batch_size = utils.compute_batch_size(
        sequence_length, mesh_shape, layout_rules, batch_size)
  else:
    self._batch_size = batch_size
  learning_rate_schedule = (
      learning_rate_schedule or
      learning_rate_schedules.learning_rate_schedule_noam)
  optimizer = optimizer or optimize.AdafactorOptimizer
  self._sequence_length = sequence_length
  self._vocabulary = vocabulary
  self._model_dir = model_dir
  self._init_checkpoint = init_checkpoint
  self._model_type = model_type
  self._ensemble_inputs = ensemble_inputs
  cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
      tpu if tpu else "", zone=tpu_zone, project=gcp_project)
  self._estimator = utils.get_estimator(
      model_type=model_type,
      input_vocab_size=utils.inputs_vocabulary(vocabulary).vocab_size,
      output_vocab_size=utils.targets_vocabulary(vocabulary).vocab_size,
      layout_rules=mtf.convert_to_layout_rules(layout_rules),
      mesh_shape=mtf.convert_to_shape(mesh_shape),
      model_dir=model_dir,
      batch_size=self._batch_size,
      sequence_length=sequence_length,
      autostack=autostack,
      learning_rate_schedule=learning_rate_schedule,
      keep_checkpoint_max=keep_checkpoint_max,
      save_checkpoints_steps=save_checkpoints_steps,
      optimizer=optimizer,
      predict_fn=predict_fn,
      variable_filter=variable_filter,
      ensemble_inputs=ensemble_inputs,
      use_tpu=tpu,
      tpu_job_name=tpu_job_name,
      iterations_per_loop=iterations_per_loop,
      cluster=cluster,
      init_checkpoint=init_checkpoint)