# NOTE: the import paths below follow the TF Model Garden layout and are
# assumed; adjust them to the local checkout if these modules live elsewhere.
import tensorflow as tf

from official.nlp.modeling import models
from official.nlp.transformer import metrics


def _create_model(params, is_train):
  """Creates transformer model."""
  # The encoder and decoder stacks share the same hyperparameters.
  encdec_kwargs = dict(
      num_layers=params["num_hidden_layers"],
      num_attention_heads=params["num_heads"],
      intermediate_size=params["filter_size"],
      activation="relu",
      dropout_rate=params["relu_dropout"],
      attention_dropout_rate=params["attention_dropout"],
      use_bias=False,
      norm_first=True,
      norm_epsilon=1e-6,
      intermediate_dropout=params["relu_dropout"])
  encoder_layer = models.TransformerEncoder(**encdec_kwargs)
  decoder_layer = models.TransformerDecoder(**encdec_kwargs)

  model_kwargs = dict(
      vocab_size=params["vocab_size"],
      embedding_width=params["hidden_size"],
      dropout_rate=params["layer_postprocess_dropout"],
      padded_decode=params["padded_decode"],
      decode_max_length=params["decode_max_length"],
      dtype=params["dtype"],
      extra_decode_length=params["extra_decode_length"],
      beam_size=params["beam_size"],
      alpha=params["alpha"],
      encoder_layer=encoder_layer,
      decoder_layer=decoder_layer,
      name="transformer_v2")

  if is_train:
    # Training graph: consumes (inputs, targets), produces logits, and attaches
    # the label-smoothed cross-entropy to the model via add_loss().
    inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
    targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
    internal_model = models.Seq2SeqTransformer(**model_kwargs)
    logits = internal_model(
        dict(inputs=inputs, targets=targets), training=is_train)
    vocab_size = params["vocab_size"]
    label_smoothing = params["label_smoothing"]
    if params["enable_metrics_in_training"]:
      logits = metrics.MetricLayer(vocab_size)([logits, targets])
    # Identity Lambda with dtype=tf.float32 so the loss is computed in full
    # precision even under a mixed-precision policy.
    logits = tf.keras.layers.Lambda(
        lambda x: x, name="logits", dtype=tf.float32)(logits)
    model = tf.keras.Model([inputs, targets], logits)
    loss = metrics.transformer_loss(logits, targets, label_smoothing,
                                    vocab_size)
    model.add_loss(loss)
    return model

  # Inference graph: consumes inputs only and returns the decoded token ids
  # and their scores. Padded decoding requires a fixed batch size.
  batch_size = params["decode_batch_size"] if params["padded_decode"] else None
  inputs = tf.keras.layers.Input(
      (None,), batch_size=batch_size, dtype="int64", name="inputs")
  internal_model = models.Seq2SeqTransformer(**model_kwargs)
  ret = internal_model(dict(inputs=inputs), training=is_train)
  outputs, scores = ret["outputs"], ret["scores"]
  return tf.keras.Model(inputs, [outputs, scores])
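# Usage sketch (added for illustration; not part of the original module). It
# shows how _create_model is driven for training vs. decoding. Every value in
# example_params is a hypothetical placeholder; in the Model Garden the real
# dictionary comes from the predefined Transformer parameter sets.
def _example_usage():
  example_params = {
      "num_hidden_layers": 6,
      "num_heads": 8,
      "filter_size": 2048,
      "relu_dropout": 0.1,
      "attention_dropout": 0.1,
      "vocab_size": 33708,
      "hidden_size": 512,
      "layer_postprocess_dropout": 0.1,
      "padded_decode": False,
      "decode_max_length": 97,
      "dtype": tf.float32,
      "extra_decode_length": 50,
      "beam_size": 4,
      "alpha": 0.6,
      "label_smoothing": 0.1,
      "enable_metrics_in_training": True,
      "decode_batch_size": 32,
  }
  # Training graph: feed (inputs, targets); the loss is already attached.
  train_model = _create_model(example_params, is_train=True)
  # Inference graph: feed inputs only; returns (decoded ids, scores).
  predict_model = _create_model(example_params, is_train=False)
  return train_model, predict_model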
def build_model(self, params) -> tf.keras.Model:
  """Creates model architecture."""
  # Fall back to the task's own model config when no override is passed in.
  model_cfg = params or self.task_config.model

  # Build the encoder and decoder stacks from their respective config blocks.
  encoder_kwargs = model_cfg.encoder.as_dict()
  encoder_layer = models.TransformerEncoder(**encoder_kwargs)
  decoder_kwargs = model_cfg.decoder.as_dict()
  decoder_layer = models.TransformerDecoder(**decoder_kwargs)

  return models.Seq2SeqTransformer(
      vocab_size=self._vocab_size,
      embedding_width=model_cfg.embedding_width,
      dropout_rate=model_cfg.dropout_rate,
      padded_decode=model_cfg.padded_decode,
      decode_max_length=model_cfg.decode_max_length,
      beam_size=model_cfg.beam_size,
      alpha=model_cfg.alpha,
      encoder_layer=encoder_layer,
      decoder_layer=decoder_layer)
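# Illustrative sketch (added, not from the original task code): what the
# encoder/decoder config blocks are expected to provide. Their .as_dict()
# output must expand into the constructor kwargs of models.TransformerEncoder
# and models.TransformerDecoder, i.e. the same fields used as encdec_kwargs in
# _create_model above. The class name and default values here are assumptions.
import dataclasses


@dataclasses.dataclass
class ExampleEncDecoderConfig:
  """Hypothetical stand-in for the real encoder/decoder config class."""
  num_layers: int = 6
  num_attention_heads: int = 8
  intermediate_size: int = 2048
  activation: str = "relu"
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  use_bias: bool = False
  norm_first: bool = True
  norm_epsilon: float = 1e-6
  intermediate_dropout: float = 0.1

  def as_dict(self):
    # Expands into keyword arguments accepted by TransformerEncoder/Decoder.
    return dataclasses.asdict(self)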