def test_attention_wrapper_with_multiple_attention_mechanisms():
    cell = tf.keras.layers.LSTMCell(5)
    mechanisms = [wrapper.LuongAttention(units=3), wrapper.LuongAttention(units=3)]
    # We simply test that the wrapper creation raises no error.
    wrapper.AttentionWrapper(cell, mechanisms, attention_layer_size=[4, 5])
    wrapper.AttentionWrapper(
        cell,
        mechanisms,
        attention_layer=[tf.keras.layers.Dense(4), tf.keras.layers.Dense(5)],
    )
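# Hedged companion sketch (not part of the original tests): exercises one step
# of the two-mechanism wrapper above. With the default output_attention=True,
# the step output is the concatenation of the two attention layers, so its last
# dimension should be 4 + 5. The cell size, batch size, and memory shapes below
# are illustrative assumptions.
def test_attention_wrapper_with_multiple_attention_mechanisms_step():
    cell = tf.keras.layers.LSTMCell(3)
    mechanisms = [wrapper.LuongAttention(units=3), wrapper.LuongAttention(units=3)]
    for mechanism in mechanisms:
        # Each mechanism needs its memory set before the wrapper can build a state.
        mechanism.setup_memory(tf.ones([2, 4, 3]))
    attention_cell = wrapper.AttentionWrapper(
        cell, mechanisms, attention_layer_size=[4, 5]
    )
    inputs = tf.ones([2, 3])
    state = attention_cell.get_initial_state(batch_size=2, dtype=tf.float32)
    output, _ = attention_cell(inputs, state)
    assert output.shape[-1] == 4 + 5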
def testLuongScaledDType(self, dtype):
    # Test case for GitHub issue 18099
    encoder_outputs = self.encoder_outputs.astype(dtype)
    decoder_inputs = self.decoder_inputs.astype(dtype)
    attention_mechanism = wrapper.LuongAttention(
        units=self.units,
        memory=encoder_outputs,
        memory_sequence_length=self.encoder_sequence_length,
        scale=True,
        dtype=dtype,
    )
    cell = tf.keras.layers.LSTMCell(self.units, recurrent_activation="sigmoid")
    cell = wrapper.AttentionWrapper(cell, attention_mechanism)
    sampler = sampler_py.TrainingSampler()
    my_decoder = basic_decoder.BasicDecoder(cell=cell, sampler=sampler)
    final_outputs, final_state, _ = my_decoder(
        decoder_inputs,
        initial_state=cell.get_initial_state(batch_size=self.batch, dtype=dtype),
        sequence_length=self.decoder_sequence_length,
    )
    self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
    self.assertEqual(final_outputs.rnn_output.dtype, dtype)
    self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
def test_luong_scaled_dtype(dtype):
    dummy_data = DummyData2()
    # Test case for GitHub issue 18099
    encoder_outputs = dummy_data.encoder_outputs.astype(dtype)
    decoder_inputs = dummy_data.decoder_inputs.astype(dtype)
    attention_mechanism = wrapper.LuongAttention(
        units=dummy_data.units,
        memory=encoder_outputs,
        memory_sequence_length=dummy_data.encoder_sequence_length,
        scale=True,
        dtype=dtype,
    )
    cell = tf.keras.layers.LSTMCell(
        dummy_data.units, recurrent_activation="sigmoid", dtype=dtype
    )
    cell = wrapper.AttentionWrapper(cell, attention_mechanism, dtype=dtype)
    sampler = sampler_py.TrainingSampler()
    my_decoder = basic_decoder.BasicDecoder(cell=cell, sampler=sampler, dtype=dtype)
    final_outputs, final_state, _ = my_decoder(
        decoder_inputs,
        initial_state=cell.get_initial_state(
            batch_size=dummy_data.batch, dtype=dtype
        ),
        sequence_length=dummy_data.decoder_sequence_length,
    )
    assert isinstance(final_outputs, basic_decoder.BasicDecoderOutput)
    assert final_outputs.rnn_output.dtype == dtype
    assert isinstance(final_state, wrapper.AttentionWrapperState)
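# Hedged driver sketch: test_luong_scaled_dtype above takes dtype as a
# parameter; the exact dtype list used to drive it is an assumption here (a
# pytest.mark.parametrize decorator on the test itself would be the more
# idiomatic form).
def test_luong_scaled_dtype_all():
    for dtype in (np.float32, np.float64):
        test_luong_scaled_dtype(dtype)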
def test_custom_attention_layer():
    dummy_data = DummyData2()
    attention_mechanism = wrapper.LuongAttention(dummy_data.units)
    cell = tf.keras.layers.LSTMCell(dummy_data.units)
    attention_layer = tf.keras.layers.Dense(
        dummy_data.units * 2, use_bias=False, activation=tf.math.tanh
    )
    attention_wrapper = wrapper.AttentionWrapper(
        cell, attention_mechanism, attention_layer=attention_layer
    )
    with pytest.raises(ValueError):
        # Should fail because the attention mechanism has not been initialized.
        attention_wrapper.get_initial_state(
            batch_size=dummy_data.batch, dtype=tf.float32
        )
    attention_mechanism.setup_memory(
        dummy_data.encoder_outputs.astype(np.float32),
        memory_sequence_length=dummy_data.encoder_sequence_length,
    )
    initial_state = attention_wrapper.get_initial_state(
        batch_size=dummy_data.batch, dtype=tf.float32
    )
    assert initial_state.attention.shape[-1] == dummy_data.units * 2
    first_input = dummy_data.decoder_inputs[:, 0].astype(np.float32)
    output, _ = attention_wrapper(first_input, initial_state)
    assert output.shape[-1] == dummy_data.units * 2
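# Hedged companion sketch: attention_layer_size is the simpler alternative to
# passing a custom attention_layer; the wrapper then builds its own projection
# and the step output has that size. Reuses the DummyData2 shapes from above.
def test_attention_layer_size_alternative():
    dummy_data = DummyData2()
    attention_mechanism = wrapper.LuongAttention(
        dummy_data.units,
        memory=dummy_data.encoder_outputs.astype(np.float32),
        memory_sequence_length=dummy_data.encoder_sequence_length,
    )
    cell = tf.keras.layers.LSTMCell(dummy_data.units)
    attention_wrapper_cell = wrapper.AttentionWrapper(
        cell, attention_mechanism, attention_layer_size=dummy_data.units * 2
    )
    initial_state = attention_wrapper_cell.get_initial_state(
        batch_size=dummy_data.batch, dtype=tf.float32
    )
    first_input = dummy_data.decoder_inputs[:, 0].astype(np.float32)
    output, _ = attention_wrapper_cell(first_input, initial_state)
    assert output.shape[-1] == dummy_data.units * 2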
def test_masking():
    memory = tf.ones([4, 4, 5], dtype=tf.float32)
    memory_sequence_length = tf.constant([1, 2, 3, 4], dtype=tf.int32)
    query = tf.ones([4, 5], dtype=tf.float32)
    state = None
    attention = wrapper.LuongAttention(5, memory, memory_sequence_length)
    alignment, _ = attention([query, state])
    assert np.sum(np.triu(alignment, k=1)) == 0
def test_masking(self):
    memory = tf.ones([4, 4, 5], dtype=tf.float32)
    memory_sequence_length = tf.constant([1, 2, 3, 4], dtype=tf.int32)
    query = tf.ones([4, 5], dtype=tf.float32)
    state = None
    attention = wrapper.LuongAttention(5, memory, memory_sequence_length)
    alignment, _ = attention([query, state])
    self.evaluate(tf.compat.v1.global_variables_initializer())
    alignment = self.evaluate(alignment)
    self.assertEqual(np.sum(np.triu(alignment, k=1)), 0)
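# Hedged companion sketch: checks the same masking property as the two tests
# above, but against an explicit boolean mask built with tf.sequence_mask
# rather than the upper-triangular shortcut. The numerical tolerance is an
# assumption.
def test_masking_explicit_mask():
    memory = tf.ones([4, 4, 5], dtype=tf.float32)
    memory_sequence_length = tf.constant([1, 2, 3, 4], dtype=tf.int32)
    query = tf.ones([4, 5], dtype=tf.float32)
    attention = wrapper.LuongAttention(5, memory, memory_sequence_length)
    alignment, _ = attention([query, None])
    # Positions beyond each sequence's valid length should carry no probability mass.
    valid = tf.sequence_mask(memory_sequence_length, maxlen=4)
    out_of_range = tf.where(valid, tf.zeros_like(alignment), alignment)
    np.testing.assert_allclose(out_of_range.numpy(), np.zeros((4, 4)), atol=1e-6)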
def test_attention_wrapper_with_gru_cell():
    mechanism = wrapper.LuongAttention(units=3)
    cell = tf.keras.layers.GRUCell(3)
    cell = wrapper.AttentionWrapper(cell, mechanism)
    memory = tf.ones([2, 5, 3])
    inputs = tf.ones([2, 3])
    mechanism.setup_memory(memory)
    initial_state = cell.get_initial_state(inputs=inputs)
    _, state = cell(inputs, initial_state)
    tf.nest.assert_same_structure(initial_state, state)
def test_basic_decoder_with_attention_wrapper():
    units = 32
    vocab_size = 1000
    attention_mechanism = attention_wrapper.LuongAttention(units)
    cell = tf.keras.layers.LSTMCell(units)
    cell = attention_wrapper.AttentionWrapper(cell, attention_mechanism)
    output_layer = tf.keras.layers.Dense(vocab_size)
    sampler = sampler_py.TrainingSampler()
    # BasicDecoder should accept a non-initialized AttentionWrapper.
    basic_decoder.BasicDecoder(cell, sampler, output_layer=output_layer)
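# Hedged companion sketch: once the memory is set up, a full decode through the
# same BasicDecoder should produce vocab-sized logits. The batch size, memory
# length, and decoder length below are illustrative assumptions.
def test_basic_decoder_with_attention_wrapper_decode():
    units, vocab_size, batch = 32, 1000, 2
    attention_mechanism = attention_wrapper.LuongAttention(units)
    cell = tf.keras.layers.LSTMCell(units)
    cell = attention_wrapper.AttentionWrapper(cell, attention_mechanism)
    output_layer = tf.keras.layers.Dense(vocab_size)
    sampler = sampler_py.TrainingSampler()
    decoder = basic_decoder.BasicDecoder(cell, sampler, output_layer=output_layer)
    attention_mechanism.setup_memory(tf.ones([batch, 5, units]))
    decoder_inputs = tf.ones([batch, 3, units])
    final_outputs, _, _ = decoder(
        decoder_inputs,
        initial_state=cell.get_initial_state(batch_size=batch, dtype=tf.float32),
        sequence_length=tf.fill([batch], 3),
    )
    assert final_outputs.rnn_output.shape[-1] == vocab_size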
def test_attention_state_with_keras_rnn():
    # See https://github.com/tensorflow/addons/issues/1095.
    cell = tf.keras.layers.LSTMCell(8)
    mechanism = wrapper.LuongAttention(units=8, memory=tf.ones((2, 4, 8)))
    cell = wrapper.AttentionWrapper(cell=cell, attention_mechanism=mechanism)
    layer = tf.keras.layers.RNN(cell)
    _ = layer(inputs=tf.ones((2, 4, 8)))
    # Make sure the explicit initial_state also works.
    initial_state = cell.get_initial_state(batch_size=2, dtype=tf.float32)
    _ = layer(inputs=tf.ones((2, 4, 8)), initial_state=initial_state)
def test_attention_state_with_variable_length_input():
    cell = tf.keras.layers.LSTMCell(3)
    mechanism = wrapper.LuongAttention(units=3)
    cell = wrapper.AttentionWrapper(cell, mechanism)
    var_len = tf.random.uniform(shape=(), minval=2, maxval=10, dtype=tf.int32)
    data = tf.ones(shape=(var_len, var_len, 3))
    mechanism.setup_memory(data)
    layer = tf.keras.layers.RNN(cell)
    _ = layer(data)
def testCustomAttentionLayer(self):
    attention_mechanism = wrapper.LuongAttention(self.units)
    cell = tf.keras.layers.LSTMCell(self.units)
    attention_layer = tf.keras.layers.Dense(
        self.units * 2, use_bias=False, activation=tf.math.tanh
    )
    attention_wrapper = wrapper.AttentionWrapper(
        cell, attention_mechanism, attention_layer=attention_layer
    )
    with self.assertRaises(ValueError):
        # Should fail because the attention mechanism has not been initialized.
        attention_wrapper.get_initial_state(batch_size=self.batch, dtype=tf.float32)
    attention_mechanism.setup_memory(
        self.encoder_outputs.astype(np.float32),
        memory_sequence_length=self.encoder_sequence_length,
    )
    initial_state = attention_wrapper.get_initial_state(
        batch_size=self.batch, dtype=tf.float32
    )
    self.assertEqual(initial_state.attention.shape[-1], self.units * 2)
    first_input = self.decoder_inputs[:, 0].astype(np.float32)
    output, next_state = attention_wrapper(first_input, initial_state)
    self.assertEqual(output.shape[-1], self.units * 2)