def test_multi_doc_decoder(self):
  self._config = utils.get_test_params(cls=configs.NHNetConfig)
  seq_length = 10
  num_docs = 5
  encoder_input_ids = tf.keras.layers.Input(
      shape=(num_docs, seq_length), name="encoder_input_ids", dtype=tf.int32)
  target_ids = tf.keras.layers.Input(
      shape=(seq_length,), name="target_ids", dtype=tf.int32)
  encoder_outputs = tf.keras.layers.Input(
      shape=(num_docs, seq_length, self._config.hidden_size),
      name="all_encoder_outputs",
      dtype=tf.float32)
  embedding_lookup = layers.OnDeviceEmbedding(
      vocab_size=self._config.vocab_size,
      embedding_width=self._config.hidden_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=self._config.initializer_range),
      name="word_embeddings")
  doc_attention_probs = tf.keras.layers.Input(
      shape=(self._config.num_decoder_attn_heads, seq_length, num_docs),
      name="doc_attention_probs",
      dtype=tf.float32)
  cross_attention_bias = decoder.AttentionBias(bias_type="multi_cross")(
      encoder_input_ids)
  self_attention_bias = decoder.AttentionBias(bias_type="decoder_self")(
      target_ids)
  inputs = dict(
      attention_bias=cross_attention_bias,
      self_attention_bias=self_attention_bias,
      target_ids=target_ids,
      all_encoder_outputs=encoder_outputs,
      doc_attention_probs=doc_attention_probs)
  decoder_layer = decoder.Decoder(self._config, embedding_lookup)
  outputs = decoder_layer(inputs)
  model_inputs = dict(
      encoder_input_ids=encoder_input_ids,
      target_ids=target_ids,
      all_encoder_outputs=encoder_outputs,
      doc_attention_probs=doc_attention_probs)
  model = tf.keras.Model(inputs=model_inputs, outputs=outputs, name="test")
  self.assertLen(decoder_layer.trainable_weights, 30)
  # Forward path.
  fake_inputs = {
      "encoder_input_ids":
          np.zeros((2, num_docs, seq_length), dtype=np.int32),
      "target_ids":
          np.zeros((2, seq_length), dtype=np.int32),
      "all_encoder_outputs":
          np.zeros((2, num_docs, seq_length, 16), dtype=np.float32),
      "doc_attention_probs":
          np.zeros(
              (2, self._config.num_decoder_attn_heads, seq_length, num_docs),
              dtype=np.float32)
  }
  output_tensor = model(fake_inputs)
  self.assertEqual(output_tensor.shape, (2, seq_length, 16))
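# The test above feeds all-zero doc_attention_probs. Below is a minimal,
# purely illustrative sketch of one plausible way to build uniform
# per-document attention probabilities with the same
# (batch, num_heads, target_length, num_docs) layout; the helper name is
# hypothetical and not part of this codebase. It assumes numpy is imported
# as np, as in the surrounding tests.
def make_uniform_doc_attention_probs(batch_size, num_heads, target_length,
                                     num_docs):
  # Each target position attends to every document with equal weight, so the
  # last axis sums to one.
  return np.full((batch_size, num_heads, target_length, num_docs),
                 1.0 / num_docs,
                 dtype=np.float32)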
def get_bert2bert_layers(params: configs.BERT2BERTConfig):
  """Creates a Bert2Bert stem model and returns Bert encoder/decoder.

  We build the stem model in a functional style because all layers need to be
  built before variables can be restored in a customized way. The layers are
  therefore called with placeholder inputs so that they are fully built.

  Args:
    params: ParamsDict.

  Returns:
    Two keras Layers: bert_model_layer and decoder_layer.
  """
  input_ids = tf.keras.layers.Input(
      shape=(None,), name="input_ids", dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(None,), name="input_mask", dtype=tf.int32)
  segment_ids = tf.keras.layers.Input(
      shape=(None,), name="segment_ids", dtype=tf.int32)
  target_ids = tf.keras.layers.Input(
      shape=(None,), name="target_ids", dtype=tf.int32)
  bert_config = utils.get_bert_config_from_params(params)
  bert_model_layer = networks.BertEncoder(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      return_all_encoder_outputs=True,
      name="bert_encoder")
  all_encoder_outputs, _ = bert_model_layer(
      [input_ids, input_mask, segment_ids])
  # The decoder shares the encoder's word embedding table.
  # pylint: disable=protected-access
  decoder_layer = decoder.Decoder(params, bert_model_layer._embedding_layer)
  # pylint: enable=protected-access
  cross_attention_bias = decoder.AttentionBias(bias_type="single_cross")(
      input_ids)
  self_attention_bias = decoder.AttentionBias(bias_type="decoder_self")(
      target_ids)
  decoder_inputs = dict(
      attention_bias=cross_attention_bias,
      self_attention_bias=self_attention_bias,
      target_ids=target_ids,
      all_encoder_outputs=all_encoder_outputs)
  # Call with placeholders so the decoder's variables are created.
  _ = decoder_layer(decoder_inputs)
  return bert_model_layer, decoder_layer
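# A minimal usage sketch for get_bert2bert_layers. It assumes
# utils.get_test_params also accepts configs.BERT2BERTConfig (the tests above
# only show it with configs.NHNetConfig); the toy shapes and the demo function
# name are illustrative assumptions, not part of this codebase.
def _demo_bert2bert_layers():
  params = utils.get_test_params(cls=configs.BERT2BERTConfig)
  encoder_layer, decoder_layer = get_bert2bert_layers(params)

  batch_size, seq_length = 2, 8
  input_ids = np.ones((batch_size, seq_length), dtype=np.int32)
  input_mask = np.ones((batch_size, seq_length), dtype=np.int32)
  segment_ids = np.zeros((batch_size, seq_length), dtype=np.int32)
  target_ids = np.ones((batch_size, seq_length), dtype=np.int32)

  # Encode the source sequence, then run the decoder against all encoder
  # layer outputs, mirroring how the stem model wires the two layers.
  all_encoder_outputs, _ = encoder_layer([input_ids, input_mask, segment_ids])
  return decoder_layer(
      dict(
          attention_bias=decoder.AttentionBias(bias_type="single_cross")(
              input_ids),
          self_attention_bias=decoder.AttentionBias(bias_type="decoder_self")(
              target_ids),
          target_ids=target_ids,
          all_encoder_outputs=all_encoder_outputs))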
def get_nhnet_layers(params: configs.NHNetConfig):
  """Creates a multi-doc encoder/decoder.

  Args:
    params: ParamsDict.

  Returns:
    Two keras Layers: bert_model_layer and decoder_layer.
  """
  input_ids = tf.keras.layers.Input(
      shape=(None,), name="input_ids", dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(None,), name="input_mask", dtype=tf.int32)
  segment_ids = tf.keras.layers.Input(
      shape=(None,), name="segment_ids", dtype=tf.int32)
  bert_config = utils.get_bert_config_from_params(params)
  bert_model_layer = networks.BertEncoder(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      return_all_encoder_outputs=True,
      name="bert_encoder")
  # Build the encoder with single-doc placeholder inputs.
  bert_model_layer([input_ids, input_mask, segment_ids])

  # Multi-doc placeholders used to build the decoder: ids are
  # (batch, num_docs, seq_length) and encoder outputs are
  # (batch, num_docs, seq_length, hidden_size).
  input_ids = tf.keras.layers.Input(
      shape=(None, None), name="input_ids", dtype=tf.int32)
  all_encoder_outputs = tf.keras.layers.Input((None, None, params.hidden_size),
                                              dtype=tf.float32)
  target_ids = tf.keras.layers.Input(
      shape=(None,), name="target_ids", dtype=tf.int32)
  doc_attention_probs = tf.keras.layers.Input(
      (params.num_decoder_attn_heads, None, None), dtype=tf.float32)
  # The decoder shares the encoder's word embedding table.
  # pylint: disable=protected-access
  decoder_layer = decoder.Decoder(params, bert_model_layer._embedding_layer)
  # pylint: enable=protected-access
  cross_attention_bias = decoder.AttentionBias(bias_type="multi_cross")(
      input_ids)
  self_attention_bias = decoder.AttentionBias(bias_type="decoder_self")(
      target_ids)
  decoder_inputs = dict(
      attention_bias=cross_attention_bias,
      self_attention_bias=self_attention_bias,
      target_ids=target_ids,
      all_encoder_outputs=all_encoder_outputs,
      doc_attention_probs=doc_attention_probs)
  # Call with placeholders so the decoder's variables are created.
  _ = decoder_layer(decoder_inputs)
  return bert_model_layer, decoder_layer
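# A parallel sketch for get_nhnet_layers, again using utils.get_test_params
# with configs.NHNetConfig as in the tests. The stand-in encoder outputs and
# uniform doc_attention_probs simply mirror the shapes of the placeholders
# built above; the demo function name is hypothetical.
def _demo_nhnet_layers():
  params = utils.get_test_params(cls=configs.NHNetConfig)
  encoder_layer, decoder_layer = get_nhnet_layers(params)

  batch_size, num_docs, seq_length = 2, 3, 8
  input_ids = np.ones((batch_size, num_docs, seq_length), dtype=np.int32)
  target_ids = np.ones((batch_size, seq_length), dtype=np.int32)

  # Stand-in encoder outputs with the (batch, num_docs, seq_length, hidden)
  # layout the decoder expects; a real pipeline would fill these from
  # encoder_layer applied to each document (the encoder is not exercised in
  # this sketch).
  del encoder_layer
  encoder_outputs = np.zeros(
      (batch_size, num_docs, seq_length, params.hidden_size), dtype=np.float32)
  doc_attention_probs = np.full(
      (batch_size, params.num_decoder_attn_heads, seq_length, num_docs),
      1.0 / num_docs,
      dtype=np.float32)
  return decoder_layer(
      dict(
          attention_bias=decoder.AttentionBias(bias_type="multi_cross")(
              input_ids),
          self_attention_bias=decoder.AttentionBias(bias_type="decoder_self")(
              target_ids),
          target_ids=target_ids,
          all_encoder_outputs=encoder_outputs,
          doc_attention_probs=doc_attention_probs))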
def test_bert_decoder(self):
  seq_length = 10
  encoder_input_ids = tf.keras.layers.Input(
      shape=(seq_length,), name="encoder_input_ids", dtype=tf.int32)
  target_ids = tf.keras.layers.Input(
      shape=(seq_length,), name="target_ids", dtype=tf.int32)
  encoder_outputs = tf.keras.layers.Input(
      shape=(seq_length, self._config.hidden_size),
      name="all_encoder_outputs",
      dtype=tf.float32)
  embedding_lookup = layers.OnDeviceEmbedding(
      vocab_size=self._config.vocab_size,
      embedding_width=self._config.hidden_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=self._config.initializer_range),
      name="word_embeddings")
  cross_attention_bias = decoder.AttentionBias(bias_type="single_cross")(
      encoder_input_ids)
  self_attention_bias = decoder.AttentionBias(bias_type="decoder_self")(
      target_ids)
  inputs = dict(
      attention_bias=cross_attention_bias,
      self_attention_bias=self_attention_bias,
      target_ids=target_ids,
      all_encoder_outputs=encoder_outputs)
  decoder_layer = decoder.Decoder(self._config, embedding_lookup)
  outputs = decoder_layer(inputs)
  model_inputs = dict(
      encoder_input_ids=encoder_input_ids,
      target_ids=target_ids,
      all_encoder_outputs=encoder_outputs)
  model = tf.keras.Model(inputs=model_inputs, outputs=outputs, name="test")
  self.assertLen(decoder_layer.trainable_weights, 30)
  # Forward path.
  fake_inputs = {
      "encoder_input_ids": np.zeros((2, 10), dtype=np.int32),
      "target_ids": np.zeros((2, 10), dtype=np.int32),
      "all_encoder_outputs": np.zeros((2, 10, 16), dtype=np.float32),
  }
  output_tensor = model(fake_inputs)
  self.assertEqual(output_tensor.shape, (2, 10, 16))