import tensorflow as tf

# Assumed imports (TensorFlow Model Garden); the exact module paths for
# `models` and `metrics` vary across releases, so adjust them to your
# installed version.
from official.nlp.modeling import models
from official.legacy.transformer import metrics


def _create_model(params, is_train):
    """Creates a Transformer model."""

    encdec_kwargs = dict(num_layers=params["num_hidden_layers"],
                         num_attention_heads=params["num_heads"],
                         intermediate_size=params["filter_size"],
                         activation="relu",
                         dropout_rate=params["relu_dropout"],
                         attention_dropout_rate=params["attention_dropout"],
                         use_bias=False,
                         norm_first=True,
                         norm_epsilon=1e-6,
                         intermediate_dropout=params["relu_dropout"])
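    # Build the encoder and decoder stacks from the same layer hyperparameters.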
    encoder_layer = models.TransformerEncoder(**encdec_kwargs)
    decoder_layer = models.TransformerDecoder(**encdec_kwargs)

    model_kwargs = dict(vocab_size=params["vocab_size"],
                        embedding_width=params["hidden_size"],
                        dropout_rate=params["layer_postprocess_dropout"],
                        padded_decode=params["padded_decode"],
                        decode_max_length=params["decode_max_length"],
                        dtype=params["dtype"],
                        extra_decode_length=params["extra_decode_length"],
                        beam_size=params["beam_size"],
                        alpha=params["alpha"],
                        encoder_layer=encoder_layer,
                        decoder_layer=decoder_layer,
                        name="transformer_v2")

    if is_train:
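        # Training graph: consume (inputs, targets), emit per-token logits, and
        # attach the label-smoothed cross-entropy as a model loss.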
        inputs = tf.keras.layers.Input((None, ), dtype="int64", name="inputs")
        targets = tf.keras.layers.Input((None, ),
                                        dtype="int64",
                                        name="targets")
        internal_model = models.Seq2SeqTransformer(**model_kwargs)
        logits = internal_model(dict(inputs=inputs, targets=targets),
                                training=is_train)
        vocab_size = params["vocab_size"]
        label_smoothing = params["label_smoothing"]
        if params["enable_metrics_in_training"]:
            logits = metrics.MetricLayer(vocab_size)([logits, targets])
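        # The identity Lambda names the output "logits" and casts it to
        # float32, which keeps the loss numerically stable when a mixed
        # precision policy is used.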
        logits = tf.keras.layers.Lambda(lambda x: x,
                                        name="logits",
                                        dtype=tf.float32)(logits)
        model = tf.keras.Model([inputs, targets], logits)
        loss = metrics.transformer_loss(logits, targets, label_smoothing,
                                        vocab_size)
        model.add_loss(loss)
        return model

    batch_size = (params["decode_batch_size"]
                  if params["padded_decode"] else None)
    inputs = tf.keras.layers.Input((None, ),
                                   batch_size=batch_size,
                                   dtype="int64",
                                   name="inputs")
    internal_model = models.Seq2SeqTransformer(**model_kwargs)
    ret = internal_model(dict(inputs=inputs), training=is_train)
    outputs, scores = ret["outputs"], ret["scores"]
    return tf.keras.Model(inputs, [outputs, scores])
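As a rough usage sketch, the dictionary below only illustrates the keys that `_create_model` reads; the values are placeholders for this example, not the defaults shipped with the library.

params = {
    "vocab_size": 32000,
    "hidden_size": 512,
    "num_hidden_layers": 6,
    "num_heads": 8,
    "filter_size": 2048,
    "relu_dropout": 0.1,
    "attention_dropout": 0.1,
    "layer_postprocess_dropout": 0.1,
    "label_smoothing": 0.1,
    "enable_metrics_in_training": False,
    "dtype": tf.float32,
    "padded_decode": False,
    "decode_batch_size": 32,
    "decode_max_length": 97,
    "extra_decode_length": 50,
    "beam_size": 4,
    "alpha": 0.6,
}

train_model = _create_model(params, is_train=True)   # loss attached via add_loss
infer_model = _create_model(params, is_train=False)  # returns (outputs, scores)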
Example #2
    def build_model(self, params) -> tf.keras.Model:
        """Creates model architecture."""
        model_cfg = params or self.task_config.model
        encoder_kwargs = model_cfg.encoder.as_dict()
        encoder_layer = models.TransformerEncoder(**encoder_kwargs)
        decoder_kwargs = model_cfg.decoder.as_dict()
        decoder_layer = models.TransformerDecoder(**decoder_kwargs)

        return models.Seq2SeqTransformer(
            vocab_size=self._vocab_size,
            embedding_width=model_cfg.embedding_width,
            dropout_rate=model_cfg.dropout_rate,
            padded_decode=model_cfg.padded_decode,
            decode_max_length=model_cfg.decode_max_length,
            beam_size=model_cfg.beam_size,
            alpha=model_cfg.alpha,
            encoder_layer=encoder_layer,
            decoder_layer=decoder_layer)
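Unlike the first example, build_model returns the bare Seq2SeqTransformer rather than a functional tf.keras.Model with a loss attached, so the task drives it directly. A minimal sketch of both call modes, where model is the returned object and src_ids / tgt_ids are placeholder int64 token-id tensors:

# Training-style forward pass: per-token logits over the target vocabulary.
logits = model(dict(inputs=src_ids, targets=tgt_ids), training=True)

# Inference: beam-search decode from the source ids alone.
decoded = model(dict(inputs=src_ids), training=False)
token_ids, scores = decoded["outputs"], decoded["scores"]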