Example #1
def test_attention_wrapper_with_multiple_attention_mechanisms():
    cell = tf.keras.layers.LSTMCell(5)
    mechanisms = [wrapper.LuongAttention(units=3), wrapper.LuongAttention(units=3)]
    # Simply check that constructing the wrapper raises no error.
    wrapper.AttentionWrapper(cell, mechanisms, attention_layer_size=[4, 5])
    wrapper.AttentionWrapper(
        cell,
        mechanisms,
        attention_layer=[tf.keras.layers.Dense(4), tf.keras.layers.Dense(5)],
    )
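
These snippets come from seq2seq test suites and are shown without their imports. Below is a minimal sketch of the imports under which the module-level (pytest-style) examples would run; the module paths and the aliases wrapper, attention_wrapper, basic_decoder, and sampler_py are assumptions modeled on the TensorFlow Addons test layout. The class-based examples (those using self.*) additionally assume a tf.test.TestCase subclass providing fixtures such as self.units and self.encoder_outputs, and DummyData2 is assumed to be a small fixture object from the same tests holding batch, units, and the encoder/decoder arrays.

import numpy as np
import pytest
import tensorflow as tf

# Assumed TensorFlow Addons modules; the two aliases below match how the
# examples refer to the same attention module.
from tensorflow_addons.seq2seq import attention_wrapper
from tensorflow_addons.seq2seq import attention_wrapper as wrapper
from tensorflow_addons.seq2seq import basic_decoder
from tensorflow_addons.seq2seq import sampler as sampler_py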
Example #2
    def testLuongScaledDType(self, dtype):
        # Test case for GitHub issue 18099
        encoder_outputs = self.encoder_outputs.astype(dtype)
        decoder_inputs = self.decoder_inputs.astype(dtype)
        attention_mechanism = wrapper.LuongAttention(
            units=self.units,
            memory=encoder_outputs,
            memory_sequence_length=self.encoder_sequence_length,
            scale=True,
            dtype=dtype,
        )
        cell = keras.layers.LSTMCell(self.units,
                                     recurrent_activation="sigmoid")
        cell = wrapper.AttentionWrapper(cell, attention_mechanism)

        sampler = sampler_py.TrainingSampler()
        my_decoder = basic_decoder.BasicDecoder(cell=cell, sampler=sampler)

        final_outputs, final_state, _ = my_decoder(
            decoder_inputs,
            initial_state=cell.get_initial_state(batch_size=self.batch,
                                                 dtype=dtype),
            sequence_length=self.decoder_sequence_length)
        self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
        self.assertEqual(final_outputs.rnn_output.dtype, dtype)
        self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
Example #3
def test_luong_scaled_dtype(dtype):
    dummy_data = DummyData2()
    # Test case for GitHub issue 18099
    encoder_outputs = dummy_data.encoder_outputs.astype(dtype)
    decoder_inputs = dummy_data.decoder_inputs.astype(dtype)
    attention_mechanism = wrapper.LuongAttention(
        units=dummy_data.units,
        memory=encoder_outputs,
        memory_sequence_length=dummy_data.encoder_sequence_length,
        scale=True,
        dtype=dtype,
    )
    cell = tf.keras.layers.LSTMCell(dummy_data.units,
                                    recurrent_activation="sigmoid",
                                    dtype=dtype)
    cell = wrapper.AttentionWrapper(cell, attention_mechanism, dtype=dtype)

    sampler = sampler_py.TrainingSampler()
    my_decoder = basic_decoder.BasicDecoder(cell=cell,
                                            sampler=sampler,
                                            dtype=dtype)

    final_outputs, final_state, _ = my_decoder(
        decoder_inputs,
        initial_state=cell.get_initial_state(batch_size=dummy_data.batch,
                                             dtype=dtype),
        sequence_length=dummy_data.decoder_sequence_length,
    )
    assert isinstance(final_outputs, basic_decoder.BasicDecoderOutput)
    assert final_outputs.rnn_output.dtype == dtype
    assert isinstance(final_state, wrapper.AttentionWrapperState)
Example #4
def test_custom_attention_layer():
    dummy_data = DummyData2()
    attention_mechanism = wrapper.LuongAttention(dummy_data.units)
    cell = tf.keras.layers.LSTMCell(dummy_data.units)
    attention_layer = tf.keras.layers.Dense(
        dummy_data.units * 2, use_bias=False, activation=tf.math.tanh
    )
    attention_wrapper = wrapper.AttentionWrapper(
        cell, attention_mechanism, attention_layer=attention_layer
    )
    with pytest.raises(ValueError):
        # Should fail because the attention mechanism has not been
        # initialized.
        attention_wrapper.get_initial_state(
            batch_size=dummy_data.batch, dtype=tf.float32
        )
    attention_mechanism.setup_memory(
        dummy_data.encoder_outputs.astype(np.float32),
        memory_sequence_length=dummy_data.encoder_sequence_length,
    )
    initial_state = attention_wrapper.get_initial_state(
        batch_size=dummy_data.batch, dtype=tf.float32
    )
    assert initial_state.attention.shape[-1] == dummy_data.units * 2
    first_input = dummy_data.decoder_inputs[:, 0].astype(np.float32)
    output, _ = attention_wrapper(first_input, initial_state)
    assert output.shape[-1] == dummy_data.units * 2
Example #5
def test_masking():
    memory = tf.ones([4, 4, 5], dtype=tf.float32)
    memory_sequence_length = tf.constant([1, 2, 3, 4], dtype=tf.int32)
    query = tf.ones([4, 5], dtype=tf.float32)
    state = None
    attention = wrapper.LuongAttention(5, memory, memory_sequence_length)
    alignment, _ = attention([query, state])
    assert np.sum(np.triu(alignment, k=1)) == 0
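
Why the strict upper triangle sums to zero: with memory_sequence_length=[1, 2, 3, 4], batch element i can only attend to the first i + 1 memory steps, so its normalized alignment row has no mass above the diagonal. A small illustrative sketch of the implied mask (not part of the original test):

import tensorflow as tf

# The boolean mask implied by the sequence lengths over 4 memory steps is
# lower-triangular: row i is True only in its first i + 1 columns.
mask = tf.sequence_mask([1, 2, 3, 4], maxlen=4)
print(mask.numpy().astype(int))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]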
Example #6
    def test_masking(self):
        memory = tf.ones([4, 4, 5], dtype=tf.float32)
        memory_sequence_length = tf.constant([1, 2, 3, 4], dtype=tf.int32)
        query = tf.ones([4, 5], dtype=tf.float32)
        state = None
        attention = wrapper.LuongAttention(5, memory, memory_sequence_length)
        alignment, _ = attention([query, state])
        self.evaluate(tf.compat.v1.global_variables_initializer())
        alignment = self.evaluate(alignment)
        self.assertEqual(np.sum(np.triu(alignment, k=1)), 0)
Example #7
def test_attention_wrapper_with_gru_cell():
    mechanism = wrapper.LuongAttention(units=3)
    cell = tf.keras.layers.GRUCell(3)
    cell = wrapper.AttentionWrapper(cell, mechanism)
    memory = tf.ones([2, 5, 3])
    inputs = tf.ones([2, 3])
    mechanism.setup_memory(memory)
    initial_state = cell.get_initial_state(inputs=inputs)
    _, state = cell(inputs, initial_state)
    tf.nest.assert_same_structure(initial_state, state)
Example #8
def test_basic_decoder_with_attention_wrapper():
    units = 32
    vocab_size = 1000
    attention_mechanism = attention_wrapper.LuongAttention(units)
    cell = tf.keras.layers.LSTMCell(units)
    cell = attention_wrapper.AttentionWrapper(cell, attention_mechanism)
    output_layer = tf.keras.layers.Dense(vocab_size)
    sampler = sampler_py.TrainingSampler()
    # BasicDecoder should accept a non-initialized AttentionWrapper.
    basic_decoder.BasicDecoder(cell, sampler, output_layer=output_layer)
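
A minimal follow-up sketch, continuing from the snippet above rather than reproducing the original test: once encoder outputs are available, the mechanism's memory can be set and the decoder called directly, mirroring the pattern of Examples #2 and #3. The shapes and tensors below are illustrative assumptions.

memory = tf.ones([8, 10, units])          # assumed encoder outputs [batch, time, depth]
attention_mechanism.setup_memory(memory)
decoder = basic_decoder.BasicDecoder(cell, sampler, output_layer=output_layer)
decoder_inputs = tf.ones([8, 7, units])   # assumed embedded decoder inputs
initial_state = cell.get_initial_state(batch_size=8, dtype=tf.float32)
outputs, state, lengths = decoder(
    decoder_inputs,
    initial_state=initial_state,
    sequence_length=tf.fill([8], 7),
)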
Example #9
def test_attention_state_with_keras_rnn():
    # See https://github.com/tensorflow/addons/issues/1095.
    cell = tf.keras.layers.LSTMCell(8)

    mechanism = wrapper.LuongAttention(units=8, memory=tf.ones((2, 4, 8)))

    cell = wrapper.AttentionWrapper(cell=cell, attention_mechanism=mechanism)

    layer = tf.keras.layers.RNN(cell)
    _ = layer(inputs=tf.ones((2, 4, 8)))

    # Make sure the explicit initial_state also works.
    initial_state = cell.get_initial_state(batch_size=2, dtype=tf.float32)
    _ = layer(inputs=tf.ones((2, 4, 8)), initial_state=initial_state)
Example #10
    def test_attention_state_with_variable_length_input(self):
        cell = tf.keras.layers.LSTMCell(3)
        mechanism = wrapper.LuongAttention(units=3)
        cell = wrapper.AttentionWrapper(cell, mechanism)

        var_len = tf.random.uniform(shape=(),
                                    minval=2,
                                    maxval=10,
                                    dtype=tf.int32)
        data = tf.ones(shape=(var_len, var_len, 3))

        mechanism.setup_memory(data)
        layer = tf.keras.layers.RNN(cell)

        _ = layer(data)
Example #11
    def testCustomAttentionLayer(self):
        attention_mechanism = wrapper.LuongAttention(self.units)
        cell = tf.keras.layers.LSTMCell(self.units)
        attention_layer = tf.keras.layers.Dense(
            self.units * 2, use_bias=False, activation=tf.math.tanh)
        attention_wrapper = wrapper.AttentionWrapper(
            cell, attention_mechanism, attention_layer=attention_layer)
        with self.assertRaises(ValueError):
            # Should fail because the attention mechanism has not been
            # initialized.
            attention_wrapper.get_initial_state(
                batch_size=self.batch, dtype=tf.float32)
        attention_mechanism.setup_memory(
            self.encoder_outputs.astype(np.float32),
            memory_sequence_length=self.encoder_sequence_length)
        initial_state = attention_wrapper.get_initial_state(
            batch_size=self.batch, dtype=tf.float32)
        self.assertEqual(initial_state.attention.shape[-1], self.units * 2)
        first_input = self.decoder_inputs[:, 0].astype(np.float32)
        output, next_state = attention_wrapper(first_input, initial_state)
        self.assertEqual(output.shape[-1], self.units * 2)