def get_nhnet_layers(params: configs.NHNetConfig):
  """Builds the multi-document (NHNet) encoder and decoder layers.

  Both layers are invoked once on symbolic Keras inputs before they are
  returned (presumably so that their variables exist for checkpoint
  restoration -- confirm against the callers).

  Args:
    params: ParamsDict.

  Returns:
    two keras Layers, bert_model_layer and decoder_layer
  """
  # Symbolic inputs used only for the one-off encoder call below.
  encoder_word_ids = tf.keras.layers.Input(
      shape=(None,), name="input_ids", dtype=tf.int32)
  encoder_mask = tf.keras.layers.Input(
      shape=(None,), name="input_mask", dtype=tf.int32)
  encoder_type_ids = tf.keras.layers.Input(
      shape=(None,), name="segment_ids", dtype=tf.int32)

  bert_config = utils.get_bert_config_from_params(params)
  encoder_kwargs = {
      "vocab_size": bert_config.vocab_size,
      "hidden_size": bert_config.hidden_size,
      "num_layers": bert_config.num_hidden_layers,
      "num_attention_heads": bert_config.num_attention_heads,
      "intermediate_size": bert_config.intermediate_size,
      "activation": tf_utils.get_activation(bert_config.hidden_act),
      "dropout_rate": bert_config.hidden_dropout_prob,
      "attention_dropout_rate": bert_config.attention_probs_dropout_prob,
      "sequence_length": None,
      "max_sequence_length": bert_config.max_position_embeddings,
      "type_vocab_size": bert_config.type_vocab_size,
      "initializer": tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      "return_all_encoder_outputs": True,
      "name": "bert_encoder",
  }
  bert_model_layer = networks.TransformerEncoder(**encoder_kwargs)
  # Outputs are discarded; the call is made only for its side effect on the
  # layer.
  bert_model_layer([encoder_word_ids, encoder_mask, encoder_type_ids])

  # Symbolic inputs for the decoder call. The ids placeholder here has one
  # more unspecified dimension than the encoder's and feeds the "multi_cross"
  # attention bias.
  doc_word_ids = tf.keras.layers.Input(
      shape=(None, None), name="input_ids", dtype=tf.int32)
  encoder_states = tf.keras.layers.Input(
      (None, None, params.hidden_size), dtype=tf.float32)
  decoder_target_ids = tf.keras.layers.Input(
      shape=(None,), name="target_ids", dtype=tf.int32)
  doc_probs = tf.keras.layers.Input(
      (params.num_decoder_attn_heads, None, None), dtype=tf.float32)

  # The decoder is constructed with the encoder's embedding layer.
  # pylint: disable=protected-access
  decoder_layer = decoder.Decoder(params, bert_model_layer._embedding_layer)
  # pylint: enable=protected-access

  cross_bias_layer = decoder.AttentionBias(bias_type="multi_cross")
  cross_bias = cross_bias_layer(doc_word_ids)
  self_bias_layer = decoder.AttentionBias(bias_type="decoder_self")
  self_bias = self_bias_layer(decoder_target_ids)

  decoder_inputs = {
      "attention_bias": cross_bias,
      "self_attention_bias": self_bias,
      "target_ids": decoder_target_ids,
      "all_encoder_outputs": encoder_states,
      "doc_attention_probs": doc_probs,
  }
  decoder_layer(decoder_inputs)

  return bert_model_layer, decoder_layer
def get_bert2bert_layers(params: configs.BERT2BERTConfig):
  """Creates a Bert2Bert stem model and returns Bert encoder/decoder.

  We use functional-style to create stem model because we need to make all
  layers built to restore variables in a customized way. The layers are called
  with placeholder inputs to make them fully built.

  Args:
    params: ParamsDict.

  Returns:
    two keras Layers, bert_model_layer and decoder_layer
  """
  input_ids = tf.keras.layers.Input(
      shape=(None,), name="input_ids", dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(None,), name="input_mask", dtype=tf.int32)
  segment_ids = tf.keras.layers.Input(
      shape=(None,), name="segment_ids", dtype=tf.int32)
  target_ids = tf.keras.layers.Input(
      shape=(None,), name="target_ids", dtype=tf.int32)

  bert_config = utils.get_bert_config_from_params(params)
  bert_model_layer = networks.TransformerEncoder(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      # Merge conflict resolved in favour of HEAD: keep the explicit
      # `sequence_length=None` so this call matches the encoder built in
      # get_nhnet_layers.
      # NOTE(review): the incoming side dropped this argument -- confirm the
      # installed TransformerEncoder still accepts it.
      sequence_length=None,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      return_all_encoder_outputs=True,
      name="bert_encoder")
  # Only the first element of the encoder's result is fed to the decoder; the
  # second is unused here.
  all_encoder_outputs, _ = bert_model_layer(
      [input_ids, input_mask, segment_ids])

  # The decoder is constructed with the encoder's embedding layer.
  # pylint: disable=protected-access
  decoder_layer = decoder.Decoder(params, bert_model_layer._embedding_layer)
  # pylint: enable=protected-access

  cross_attention_bias = decoder.AttentionBias(bias_type="single_cross")(
      input_ids)
  self_attention_bias = decoder.AttentionBias(bias_type="decoder_self")(
      target_ids)
  decoder_inputs = dict(
      attention_bias=cross_attention_bias,
      self_attention_bias=self_attention_bias,
      target_ids=target_ids,
      all_encoder_outputs=all_encoder_outputs)
  # Result discarded: the call is made only so the decoder layer gets built.
  _ = decoder_layer(decoder_inputs)

  return bert_model_layer, decoder_layer