import tensorflow as tf

# `Transformer` and `metrics` are assumed to be provided by the surrounding
# package (the Keras Transformer model and its loss/metric helpers).


def create_model(params, is_train):
  """Creates transformer model."""
  with tf.name_scope("model"):
    if is_train:
      inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
      targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
      internal_model = Transformer(params, name="transformer_v2")
      logits = internal_model([inputs, targets], training=is_train)
      vocab_size = params["vocab_size"]
      label_smoothing = params["label_smoothing"]
      if params["enable_metrics_in_training"]:
        logits = metrics.MetricLayer(vocab_size)([logits, targets])
      # Identity layer that pins the final logits to float32 even when a
      # mixed-precision policy is active.
      logits = tf.keras.layers.Lambda(
          lambda x: x, name="logits", dtype=tf.float32)(logits)
      model = tf.keras.Model([inputs, targets], logits)
      # TODO(reedwm): Can we do this loss in float16 instead of float32?
      loss = metrics.transformer_loss(logits, targets, label_smoothing,
                                      vocab_size)
      model.add_loss(loss)
      return model

    else:
      inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
      internal_model = Transformer(params, name="transformer_v2")
      ret = internal_model([inputs], training=is_train)
      outputs, scores = ret["outputs"], ret["scores"]
      return tf.keras.Model(inputs, [outputs, scores])
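
# Usage sketch (not part of the original file): how the two graphs returned by
# `create_model` are typically driven. The optimizer choice and dataset shape
# are assumptions for illustration; `train_dataset` is assumed to yield the
# (inputs, targets) structure the training model expects as its inputs.
def _example_create_model_usage(params, train_dataset, source_batch):
  """Illustrative only: fit the training graph, then decode with the other."""
  # Training graph: the loss is already attached via `model.add_loss`, so
  # `compile` only needs an optimizer and `fit` receives no separate labels.
  train_model = create_model(params, is_train=True)
  train_model.compile(tf.keras.optimizers.Adam())
  train_model.fit(train_dataset, epochs=1)

  # Inference graph: takes only source token ids and returns decoded ids
  # ("outputs") together with their beam scores ("scores").
  predict_model = create_model(params, is_train=False)
  outputs, scores = predict_model.predict(source_batch)
  return outputs, scores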
def _step_fn(inputs):
  """Per-replica step function."""
  inputs, targets = inputs
  with tf.GradientTape() as tape:
    logits = model([inputs, targets], training=True)
    loss = metrics.transformer_loss(logits, targets,
                                    params["label_smoothing"],
                                    params["vocab_size"])
    # Scales the loss, which results in using the average loss across all
    # of the replicas for backprop.
    scaled_loss = loss / self.distribution_strategy.num_replicas_in_sync

  # De-dupes variables due to keras tracking issues.
  tvars = list({id(v): v for v in model.trainable_variables}.values())
  grads = tape.gradient(scaled_loss, tvars)
  opt.apply_gradients(zip(grads, tvars))
  # For reporting, the metric takes the mean of losses.
  train_loss_metric.update_state(loss)
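
# Driver sketch (assumed, not part of the original excerpt): `_step_fn` runs
# once per replica, so it is typically dispatched through the distribution
# strategy from an outer tf.function. `iterator` is assumed to be an iterator
# over a distributed dataset yielding (inputs, targets) batches, `steps` a
# scalar tensor giving how many steps to trace into one loop, and `self` the
# same surrounding task object that owns `distribution_strategy` in `_step_fn`.
@tf.function
def _example_train_steps(iterator, steps):
  """Illustrative only: run `steps` distributed training steps."""
  for _ in tf.range(steps):
    # Each replica receives its shard of the batch; `apply_gradients` inside
    # `_step_fn` aggregates gradients across replicas, and the loss scaling
    # there keeps the effective loss an average over replicas.
    self.distribution_strategy.run(_step_fn, args=(next(iterator),))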