def create_model(params, is_train):
  """Creates transformer model."""
  with tf.name_scope("model"):
    if is_train:
      inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
      targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
      internal_model = Transformer(params, name="transformer_v2")
      logits = internal_model([inputs, targets], training=is_train)
      vocab_size = params["vocab_size"]
      label_smoothing = params["label_smoothing"]
      if params["enable_metrics_in_training"]:
        logits = metrics.MetricLayer(vocab_size)([logits, targets])
      logits = tf.keras.layers.Lambda(
          lambda x: x, name="logits", dtype=tf.float32)(logits)
      model = tf.keras.Model([inputs, targets], logits)
      loss = metrics.transformer_loss(logits, targets, label_smoothing,
                                      vocab_size)
      model.add_loss(loss)
      return model

    else:
      inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
      internal_model = Transformer(params, name="transformer_v2")
      ret = internal_model([inputs], training=is_train)
      outputs, scores = ret["outputs"], ret["scores"]
      return tf.keras.Model(inputs, [outputs, scores])
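For orientation, here is a minimal sketch of how this factory might be driven. It is not part of the original code: the values are placeholders, and `params` is assumed to also carry every hyperparameter the `Transformer` layer itself reads; only the keys `create_model` touches directly are spelled out.

# Hypothetical usage sketch; placeholder values, not the project's defaults.
params = {
    "vocab_size": 32000,
    "label_smoothing": 0.1,
    "enable_metrics_in_training": False,
    # ... remaining Transformer hyperparameters (hidden_size, dropout, ...) ...
}

train_model = create_model(params, is_train=True)   # loss attached via add_loss
infer_model = create_model(params, is_train=False)  # maps ids -> [outputs, scores]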
def train_step(self, inputs):
  """The logic for one training step."""
  with tf.GradientTape() as tape:
    logits, _, _ = self(inputs, mode="train", training=True)
    targets = models.remove_sos_from_seq(inputs["target_ids"],
                                         self.params.pad_token_id)
    loss = transformer_metrics.transformer_loss(logits, targets,
                                                self.params.label_smoothing,
                                                self.params.vocab_size)
    # Scales the loss, which results in using the average loss across all
    # of the replicas for backprop.
    scaled_loss = loss / self._num_replicas_in_sync

  tvars = self.trainable_variables
  grads = tape.gradient(scaled_loss, tvars)
  self.optimizer.apply_gradients(list(zip(grads, tvars)))
  if isinstance(self.optimizer, tf.keras.optimizers.experimental.Optimizer):
    learning_rate = self.optimizer.learning_rate
  else:
    learning_rate = self.optimizer._decayed_lr(var_dtype=tf.float32)
  return {
      "training_loss": loss,
      "learning_rate": learning_rate,
  }
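Because this method follows the tf.keras.Model.train_step contract (it returns a dict of scalar logs), it can be driven by model.fit or by a hand-rolled loop. The loop below is a hypothetical sketch: `model` and `train_ds` are assumed, not defined in the original code.

# Hypothetical driver loop; `model` is the tf.keras.Model subclass defining
# train_step above, and `train_ds` is a tf.data.Dataset whose elements contain
# "target_ids" plus whatever the forward pass consumes.
for step, batch in enumerate(train_ds):
  logs = model.train_step(batch)
  if step % 100 == 0:
    print(step, float(logs["training_loss"]), float(logs["learning_rate"]))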
def _create_model(params, is_train):
  """Creates transformer model."""
  encdec_kwargs = dict(
      num_layers=params["num_hidden_layers"],
      num_attention_heads=params["num_heads"],
      intermediate_size=params["filter_size"],
      activation="relu",
      dropout_rate=params["relu_dropout"],
      attention_dropout_rate=params["attention_dropout"],
      use_bias=False,
      norm_first=True,
      norm_epsilon=1e-6,
      intermediate_dropout=params["relu_dropout"])
  encoder_layer = models.TransformerEncoder(**encdec_kwargs)
  decoder_layer = models.TransformerDecoder(**encdec_kwargs)

  model_kwargs = dict(
      vocab_size=params["vocab_size"],
      embedding_width=params["hidden_size"],
      dropout_rate=params["layer_postprocess_dropout"],
      padded_decode=params["padded_decode"],
      decode_max_length=params["decode_max_length"],
      dtype=params["dtype"],
      extra_decode_length=params["extra_decode_length"],
      beam_size=params["beam_size"],
      alpha=params["alpha"],
      encoder_layer=encoder_layer,
      decoder_layer=decoder_layer,
      name="transformer_v2")

  if is_train:
    inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
    targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
    internal_model = models.Seq2SeqTransformer(**model_kwargs)
    logits = internal_model(
        dict(inputs=inputs, targets=targets), training=is_train)
    vocab_size = params["vocab_size"]
    label_smoothing = params["label_smoothing"]
    if params["enable_metrics_in_training"]:
      logits = metrics.MetricLayer(vocab_size)([logits, targets])
    logits = tf.keras.layers.Lambda(
        lambda x: x, name="logits", dtype=tf.float32)(logits)
    model = tf.keras.Model([inputs, targets], logits)
    loss = metrics.transformer_loss(logits, targets, label_smoothing,
                                    vocab_size)
    model.add_loss(loss)
    return model

  batch_size = params["decode_batch_size"] if params["padded_decode"] else None
  inputs = tf.keras.layers.Input(
      (None,), batch_size=batch_size, dtype="int64", name="inputs")
  internal_model = models.Seq2SeqTransformer(**model_kwargs)
  ret = internal_model(dict(inputs=inputs), training=is_train)
  outputs, scores = ret["outputs"], ret["scores"]
  return tf.keras.Model(inputs, [outputs, scores])
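One detail worth calling out in the non-training branch: when padded_decode is enabled (as required for static-shape backends such as TPU), the input's batch dimension cannot stay dynamic, so it is pinned to decode_batch_size. A hypothetical illustration, assuming `params` already holds the rest of the Seq2SeqTransformer hyperparameters:

# Hypothetical settings; values are placeholders.
params["padded_decode"] = True     # static shapes required for decoding
params["decode_batch_size"] = 32   # becomes the fixed input batch size
infer_model = _create_model(params, is_train=False)
# infer_model maps source token ids -> [beam-decoded output ids, beam scores].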
def _step_fn(inputs):
  """Per-replica step function."""
  inputs, targets = inputs
  with tf.GradientTape() as tape:
    logits = model([inputs, targets], training=True)
    loss = metrics.transformer_loss(logits, targets,
                                    params["label_smoothing"],
                                    params["vocab_size"])
    # Scales the loss, which results in using the average loss across all
    # of the replicas for backprop.
    scaled_loss = loss / self.distribution_strategy.num_replicas_in_sync

  # De-dupes variables due to keras tracking issues.
  tvars = list({id(v): v for v in model.trainable_variables}.values())
  grads = tape.gradient(scaled_loss, tvars)
  opt.apply_gradients(zip(grads, tvars))
  # For reporting, the metric takes the mean of losses.
  train_loss_metric.update_state(loss)
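A per-replica function like this is normally wrapped in a tf.function and dispatched through the distribution strategy. The sketch below shows one common pattern; `strategy` and the distributed `iterator` are assumed to exist and are not part of the original snippet.

# Minimal dispatch sketch; `strategy` is the tf.distribute.Strategy in use and
# `iterator` iterates a distributed dataset yielding (inputs, targets) pairs.
@tf.function
def train_steps(iterator, steps):
  for _ in tf.range(steps):
    strategy.run(_step_fn, args=(next(iterator),))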