Example #1
import tensorflow as tf

# Assumed import paths: Transformer and metrics come from the TensorFlow Model
# Garden's transformer implementation; adjust them to match your checkout.
from official.nlp.transformer import metrics
from official.nlp.transformer.transformer import Transformer

def create_model(params, is_train):
    """Creates transformer model."""
    with tf.name_scope("model"):
        if is_train:
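            # Training graph: the model consumes (inputs, targets) so the
            # training loss can be attached below via add_loss().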
            inputs = tf.keras.layers.Input((None, ),
                                           dtype="int64",
                                           name="inputs")
            targets = tf.keras.layers.Input((None, ),
                                            dtype="int64",
                                            name="targets")
            internal_model = Transformer(params, name="transformer_v2")
            logits = internal_model([inputs, targets], training=is_train)
            vocab_size = params["vocab_size"]
            label_smoothing = params["label_smoothing"]
            if params["enable_metrics_in_training"]:
                logits = metrics.MetricLayer(vocab_size)([logits, targets])
            logits = tf.keras.layers.Lambda(lambda x: x,
                                            name="logits",
                                            dtype=tf.float32)(logits)
            model = tf.keras.Model([inputs, targets], logits)
            # TODO(reedwm): Can we do this loss in float16 instead of float32?
            loss = metrics.transformer_loss(logits, targets, label_smoothing,
                                            vocab_size)
            model.add_loss(loss)
            return model

        else:
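            # Inference graph: the model decodes from inputs alone and returns
            # predicted token ids ("outputs") together with their scores.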
            inputs = tf.keras.layers.Input((None, ),
                                           dtype="int64",
                                           name="inputs")
            internal_model = Transformer(params, name="transformer_v2")
            ret = internal_model([inputs], training=is_train)
            outputs, scores = ret["outputs"], ret["scores"]
            return tf.keras.Model(inputs, [outputs, scores])
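
For orientation, a minimal usage sketch (not part of the original listing): it assumes params holds the full set of Transformer hyperparameters, of which only vocab_size, label_smoothing, and enable_metrics_in_training are read above. Since the training loss is attached with model.add_loss(), the training model can be compiled without an explicit loss argument.

# Hypothetical usage sketch; params is assumed to contain every hyperparameter
# the Transformer expects, not just the keys shown in create_model.
train_model = create_model(params, is_train=True)
# No loss is passed to compile() because create_model attached it via add_loss().
train_model.compile(optimizer=tf.keras.optimizers.Adam(1e-3))

# The inference graph is a separate Keras model returning (outputs, scores);
# its weights are typically restored from a training checkpoint before use.
infer_model = create_model(params, is_train=False)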

# Per-replica training step. This function is defined inside a custom training
# loop, where model (the training model returned by create_model), opt (the
# optimizer), params, train_loss_metric, and self.distribution_strategy are
# all in scope.
def _step_fn(inputs):
    """Per-replica step function."""
    inputs, targets = inputs
    with tf.GradientTape() as tape:
        logits = model([inputs, targets], training=True)
        loss = metrics.transformer_loss(logits, targets,
                                        params["label_smoothing"],
                                        params["vocab_size"])
        # Scales the loss, which results in using the average loss across
        # all of the replicas for backprop.
        scaled_loss = loss / self.distribution_strategy.num_replicas_in_sync

    # De-dupes variables due to keras tracking issues.
    tvars = list({id(v): v for v in model.trainable_variables}.values())
    grads = tape.gradient(scaled_loss, tvars)
    opt.apply_gradients(zip(grads, tvars))
    # For reporting, the metric takes the mean of losses.
    train_loss_metric.update_state(loss)
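
The step function above is normally dispatched once per replica by a tf.distribute strategy. The driver below is a hypothetical sketch: strategy stands in for self.distribution_strategy, train_iterator is an iterator over the distributed training dataset, and the step count is arbitrary; none of these names appear in the original listing.

# Hypothetical driver: run _step_fn on every replica for a fixed number of
# steps. Wrapping the Python loop in tf.function traces it into a single graph.
@tf.function
def train_step(iterator, steps):
    for _ in tf.range(steps):
        strategy.run(_step_fn, args=(next(iterator),))

train_step(train_iterator, tf.constant(200))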