def model(self):
  hparams = self._hparams
  encoder_layer_stack = layer_stack_from_hparams(hparams, "encoder_")
  decoder_layer_stack = layer_stack_from_hparams(hparams, "decoder_")
  encoder = transformer.Unitransformer(
      layer_stack=encoder_layer_stack,
      d_model=hparams.d_model,
      input_vocab_size=self._inputs_vocab_size,
      output_vocab_size=None,
      autoregressive=False,
      max_length=hparams.max_length,
      name="encoder",
      layout=hparams.layout,
      mesh_shape=hparams.mesh_shape)
  decoder = transformer.Unitransformer(
      layer_stack=decoder_layer_stack,
      d_model=hparams.d_model,
      input_vocab_size=self._targets_vocab_size,
      output_vocab_size=self._targets_vocab_size,
      autoregressive=True,
      max_length=hparams.max_length,
      label_smoothing=hparams.label_smoothing,
      shared_embedding_and_softmax_weights=(
          hparams.shared_embedding_and_softmax_weights),
      z_loss=hparams.z_loss,
      name="decoder",
      layout=hparams.layout,
      mesh_shape=hparams.mesh_shape)
  return transformer.Bitransformer(
      encoder, decoder, shared_embedding=hparams.shared_embedding)
def model(self):
  hparams = self._hparams
  if isinstance(hparams.encoder_layer_stack, transformer.LayerStack):
    encoder_layer_stack = hparams.encoder_layer_stack
  else:
    encoder_layer_stack = hparams.encoder_layer_stack(hparams, "encoder_")
  if isinstance(hparams.decoder_layer_stack, transformer.LayerStack):
    decoder_layer_stack = hparams.decoder_layer_stack
  else:
    decoder_layer_stack = hparams.decoder_layer_stack(hparams, "decoder_")
  return transformer.Bitransformer(
      encoder_layer_stack=encoder_layer_stack,
      decoder_layer_stack=decoder_layer_stack,
      encoder_d_model=hparams.d_model,
      decoder_d_model=hparams.d_model,
      input_vocab_size=self._inputs_vocab_size,
      output_vocab_size=self._targets_vocab_size,
      max_length=hparams.max_length,
      shared_embedding=hparams.shared_embedding,
      shared_embedding_and_softmax_weights=(
          hparams.shared_embedding_and_softmax_weights),
      label_smoothing=hparams.label_smoothing,
      z_loss=hparams.z_loss,
      layout=hparams.layout,
      mesh_shape=hparams.mesh_shape)
def model(self):
  hparams = self._hparams
  if isinstance(hparams.encoder_layer_stack, transformer.LayerStack):
    encoder_layer_stack = hparams.encoder_layer_stack
  else:
    encoder_layer_stack = hparams.encoder_layer_stack(hparams)
  if isinstance(hparams.decoder_layer_stack, transformer.LayerStack):
    decoder_layer_stack = hparams.decoder_layer_stack
  else:
    decoder_layer_stack = hparams.decoder_layer_stack(hparams)
  return transformer.Bitransformer(
      encoder_layer_stack=encoder_layer_stack,
      decoder_layer_stack=decoder_layer_stack,
      encoder_d_model=hparams.d_model,
      decoder_d_model=hparams.d_model,
      input_vocab_size=self._inputs_vocab_size,
      output_vocab_size=self._targets_vocab_size,
      max_length=hparams.max_length,
      shared_embedding=hparams.shared_embedding,
      label_smoothing=hparams.label_smoothing)
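# Illustrative sketch, not part of the original code: the model() variants
# above accept either a prebuilt transformer.LayerStack or a callable stored
# in hparams.encoder_layer_stack / hparams.decoder_layer_stack. The helper
# below shows what such a callable could look like; the hparams field names
# (num_decoder_layers, d_ff, d_kv, num_heads, dropout) are assumptions, and
# model_builder.simple_layer_stack is called with the same keyword arguments
# it receives in the gin-configured model() snippet further down.
def example_decoder_layer_stack(hparams, prefix="decoder_"):
  """Hypothetical callable usable as hparams.decoder_layer_stack."""
  # Some variants above call the callable with a prefix, others without; a
  # default argument keeps both call sites working.
  del prefix
  return model_builder.simple_layer_stack(
      include_encdec_attention=True,
      num_layers=hparams.num_decoder_layers,
      d_ff=hparams.d_ff,
      d_kv=hparams.d_kv,
      num_heads=hparams.num_heads,
      dropout_rate=hparams.dropout)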
def my_model_fn(features, labels, mode, params=None, config=None):
  """Estimator model function.

  Args:
    features: input features dictionary
    labels: ignored
    mode: a tf.estimator.ModeKeys
    params: optional parameter dictionary; when running under TPUEstimator it
      contains the TPU context under the key "context"
    config: an optional tf.estimator.RunConfig (unused)

  Returns:
    a TPUEstimatorSpec (or an EstimatorSpec for CPU/GPU training)
  """
  del labels, config
  use_tpu = FLAGS.tpu
  global_step = tf.train.get_global_step()
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  model = transformer.Bitransformer(
      encoder_layer_stack=layer_stack(include_encdec_attention=False),
      decoder_layer_stack=layer_stack(include_encdec_attention=True),
      encoder_d_model=FLAGS.d_model,
      decoder_d_model=FLAGS.d_model,
      input_vocab_size=transformer_dataset.padded_vocab_size(
          transformer_dataset.inputs_vocab_size(FLAGS.dataset)),
      output_vocab_size=transformer_dataset.padded_vocab_size(
          transformer_dataset.targets_vocab_size(FLAGS.dataset)),
      max_length=FLAGS.max_length,
      shared_embedding=False,
      shared_embedding_and_softmax_weights=True,
      label_smoothing=FLAGS.label_smoothing,
      layout=FLAGS.layout,
      mesh_shape=FLAGS.mesh_shape)

  inputs = import_feature(features, mesh, "inputs")

  # Data-types used for variables and activations.
  # See comments in the FLAGS.
  master_dtype = tf.as_dtype(FLAGS.master_dtype)
  if FLAGS.slice_dtype:
    slice_dtype = tf.as_dtype(FLAGS.slice_dtype)
  elif not FLAGS.tpu or FLAGS.mode == "train":
    slice_dtype = tf.float32
  else:
    slice_dtype = tf.bfloat16
  if FLAGS.activation_dtype:
    activation_dtype = tf.as_dtype(FLAGS.activation_dtype)
  else:
    activation_dtype = tf.bfloat16 if FLAGS.tpu else tf.float32
  variable_dtype = mtf.VariableDType(master_dtype=master_dtype,
                                     slice_dtype=slice_dtype,
                                     activation_dtype=activation_dtype)

  # PREDICT mode
  if mode == tf.estimator.ModeKeys.PREDICT:
    mtf_samples = model.decode(
        inputs,
        variable_dtype=variable_dtype,
        beam_size=FLAGS.beam_size,
        alpha=FLAGS.alpha,
        temperature=FLAGS.temperature)
    mtf_samples = mtf.anonymize(mtf_samples)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                            autostack=FLAGS.autostack)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    predictions = {"outputs": outputs}
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])

  targets = import_feature(features, mesh, "targets")
  anon_targets = mtf.anonymize(targets)
  logits, loss = model.call_simple(
      inputs=inputs,
      targets=targets,
      compute_loss=True,
      mode=mode,
      variable_dtype=variable_dtype,
      encoder_sequence_id=import_feature(features, mesh,
                                         "inputs_segmentation"),
      decoder_sequence_id=import_feature(features, mesh,
                                         "targets_segmentation"),
      encoder_position=import_feature(features, mesh, "inputs_position"),
      decoder_position=import_feature(features, mesh, "targets_position"))

  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf.optimize.AdafactorOptimizer()
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=FLAGS.autostack)

  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  if not use_tpu:
    tf_loss = tf.Print(
        tf_loss, [tf_loss, tf.train.get_global_step()], "step, tf_loss")
  if logits is not None and mode != tf.estimator.ModeKeys.TRAIN:
    tf_logits = lowering.export_to_tf_tensor(logits)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=10,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        FLAGS.model_dir,
        save_steps=1000,
        saver=saver,
        listeners=[saver_listener])

    if mode == tf.estimator.ModeKeys.TRAIN:
      if use_tpu:
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_hooks=[restore_hook, saver_hook])
      else:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:
      def padded_neg_log_perplexity(logits, labels):
        weights = tf.to_float(tf.not_equal(labels, 0))
        xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits)
        return {"neg_log_perplexity": tf.metrics.mean(-xent, weights)}
      labels = lowering.export_to_tf_tensor(anon_targets)
      eval_metrics = (padded_neg_log_perplexity, [tf_logits, labels])
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)
def model(input_vocab_size, output_vocab_size, text2self,
          num_layers=gin.REQUIRED,
          d_ff=gin.REQUIRED,
          d_kv=gin.REQUIRED,
          d_model=gin.REQUIRED,
          num_heads=gin.REQUIRED,
          dropout=gin.REQUIRED,
          max_length=gin.REQUIRED,
          length=gin.REQUIRED,
          label_smoothing=gin.REQUIRED,
          layout=gin.REQUIRED,
          mesh_shape=gin.REQUIRED):
  """Build a simple Transformer model.

  Args:
    input_vocab_size: an integer
    output_vocab_size: an integer
    text2self: a boolean meaning a language model (True) or encoder/decoder
      (False)
    num_layers: integer, number of transformer layers
    d_ff: integer, size of feed-forward hidden layers
    d_kv: integer, size of attention keys/values
    d_model: integer, size of hidden state
    num_heads: integer, heads per attention layer
    dropout: float, dropout rate
    max_length: maximum sequence length (checkpoints depend on this)
    length: actual sequence length - defaults to max_length
    label_smoothing: label smoothing
    layout: a string
    mesh_shape: a string

  Returns:
    a transformer.Unitransformer or transformer.Bitransformer
  """
  # Needed for Gin injection.
  del length

  def layer_stack(include_encdec_attention):
    """Create a LayerStack.

    TODO(noam): implement a way to configure custom layer stacks using
    hyperparameters (as in mtf_transformer2 in the tensor2tensor library).
    That functionality should go in transformer/model_builder.py.

    Args:
      include_encdec_attention: a boolean

    Returns:
      a transformer.LayerStack
    """
    return model_builder.simple_layer_stack(
        include_encdec_attention=include_encdec_attention,
        num_layers=num_layers,
        d_ff=d_ff,
        d_kv=d_kv,
        num_heads=num_heads,
        dropout_rate=dropout)

  if text2self:
    return transformer.Unitransformer(
        layer_stack=layer_stack(include_encdec_attention=False),
        d_model=d_model,
        input_vocab_size=input_vocab_size,
        output_vocab_size=output_vocab_size,
        autoregressive=True,
        max_length=max_length,
        shared_embedding_and_softmax_weights=True,
        label_smoothing=label_smoothing,
        layout=layout,
        mesh_shape=mesh_shape)
  else:
    return transformer.Bitransformer(
        encoder_layer_stack=layer_stack(include_encdec_attention=False),
        decoder_layer_stack=layer_stack(include_encdec_attention=True),
        encoder_d_model=d_model,
        decoder_d_model=d_model,
        input_vocab_size=input_vocab_size,
        output_vocab_size=output_vocab_size,
        max_length=max_length,
        shared_embedding=False,
        shared_embedding_and_softmax_weights=True,
        label_smoothing=label_smoothing,
        layout=layout,
        mesh_shape=mesh_shape)
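# Illustrative sketch, not part of the original code: the gin.REQUIRED
# defaults above are normally supplied by a gin config file (or via
# gin.bind_parameter). An equivalent direct call looks like the following;
# every hyperparameter value here is an example chosen for illustration, not
# a value from the original.
toy_model = model(
    input_vocab_size=32128,
    output_vocab_size=32128,
    text2self=False,
    num_layers=6,
    d_ff=2048,
    d_kv=64,
    d_model=512,
    num_heads=8,
    dropout=0.1,
    max_length=256,
    length=256,
    label_smoothing=0.1,
    layout="batch:batch;vocab:model;d_ff:model;heads:model",
    mesh_shape="batch:8")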