def test_attention_wrapper_with_multiple_attention_mechanisms():
    """Smoke-test AttentionWrapper construction with two attention mechanisms.

    Both ways of configuring the per-mechanism attention projection are
    exercised: by output size (``attention_layer_size``) and by explicit
    layers (``attention_layer``). Only construction is checked; the test
    passes as long as neither constructor call raises.
    """
    lstm = tf.keras.layers.LSTMCell(5)
    luong_pair = [wrapper.LuongAttention(units=3), wrapper.LuongAttention(units=3)]

    # One projection size per mechanism.
    wrapper.AttentionWrapper(lstm, luong_pair, attention_layer_size=[4, 5])

    # One explicit projection layer per mechanism.
    projections = [tf.keras.layers.Dense(4), tf.keras.layers.Dense(5)]
    wrapper.AttentionWrapper(lstm, luong_pair, attention_layer=projections)
Beispiel #2
0
    def testLuongScaledDType(self, dtype):
        """Check that scaled Luong attention propagates a non-default dtype.

        Regression test for GitHub issue 18099. Every layer in the decoding
        stack (attention mechanism, LSTM cell, AttentionWrapper and
        BasicDecoder) is constructed with ``dtype``, and the decoder's RNN
        output is asserted to carry that dtype.

        Args:
          dtype: the dtype under test (presumably supplied by a parameterized
            decorator outside this view — TODO confirm).
        """
        encoder_outputs = self.encoder_outputs.astype(dtype)
        decoder_inputs = self.decoder_inputs.astype(dtype)
        attention_mechanism = wrapper.LuongAttention(
            units=self.units,
            memory=encoder_outputs,
            memory_sequence_length=self.encoder_sequence_length,
            scale=True,
            dtype=dtype,
        )
        # Build every layer with ``dtype`` (as the other dtype tests do) so no
        # layer in the stack is left at its default dtype while the test
        # asserts the output dtype.
        cell = keras.layers.LSTMCell(self.units,
                                     recurrent_activation="sigmoid",
                                     dtype=dtype)
        cell = wrapper.AttentionWrapper(cell, attention_mechanism, dtype=dtype)

        sampler = sampler_py.TrainingSampler()
        my_decoder = basic_decoder.BasicDecoder(cell=cell,
                                                sampler=sampler,
                                                dtype=dtype)

        final_outputs, final_state, _ = my_decoder(
            decoder_inputs,
            initial_state=cell.get_initial_state(batch_size=self.batch,
                                                 dtype=dtype),
            sequence_length=self.decoder_sequence_length)
        self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
        self.assertEqual(final_outputs.rnn_output.dtype, dtype)
        self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
Beispiel #3
0
def test_luong_scaled_dtype(dtype):
    """Scaled Luong attention must keep ``dtype`` through the whole decoder.

    Regression test for GitHub issue 18099: memory, inputs and every layer
    are built in ``dtype``, and the decoder's ``rnn_output`` is expected to
    come back in that same dtype.
    """
    data = DummyData2()
    memory = data.encoder_outputs.astype(dtype)
    inputs = data.decoder_inputs.astype(dtype)

    luong = wrapper.LuongAttention(
        units=data.units,
        memory=memory,
        memory_sequence_length=data.encoder_sequence_length,
        scale=True,
        dtype=dtype,
    )
    base_cell = tf.keras.layers.LSTMCell(
        data.units, recurrent_activation="sigmoid", dtype=dtype
    )
    attn_cell = wrapper.AttentionWrapper(base_cell, luong, dtype=dtype)

    decoder = basic_decoder.BasicDecoder(
        cell=attn_cell, sampler=sampler_py.TrainingSampler(), dtype=dtype
    )
    start_state = attn_cell.get_initial_state(batch_size=data.batch, dtype=dtype)
    outputs, state, _ = decoder(
        inputs,
        initial_state=start_state,
        sequence_length=data.decoder_sequence_length,
    )

    assert isinstance(outputs, basic_decoder.BasicDecoderOutput)
    assert outputs.rnn_output.dtype == dtype
    assert isinstance(state, wrapper.AttentionWrapperState)
Beispiel #4
0
    def testBahdanauNormalizedDType(self, dtype):
        """Normalized Bahdanau attention must preserve ``dtype`` end to end.

        Memory, inputs and every layer of the decoding stack are built in
        ``dtype``; the decoder's RNN output is expected to have that dtype.
        """
        memory = self.encoder_outputs.astype(dtype)
        inputs = self.decoder_inputs.astype(dtype)
        bahdanau = wrapper.BahdanauAttention(
            units=self.units,
            memory=memory,
            memory_sequence_length=self.encoder_sequence_length,
            normalize=True,
            dtype=dtype,
        )
        lstm = tf.keras.layers.LSTMCell(
            self.units, recurrent_activation="sigmoid", dtype=dtype
        )
        attn_cell = wrapper.AttentionWrapper(lstm, bahdanau, dtype=dtype)
        decoder = basic_decoder.BasicDecoder(
            cell=attn_cell,
            sampler=sampler_py.TrainingSampler(),
            dtype=dtype,
        )

        start_state = attn_cell.get_initial_state(batch_size=self.batch, dtype=dtype)
        outputs, state, _ = decoder(
            inputs,
            initial_state=start_state,
            sequence_length=self.decoder_sequence_length,
        )

        self.assertIsInstance(outputs, basic_decoder.BasicDecoderOutput)
        self.assertEqual(outputs.rnn_output.dtype, dtype)
        self.assertIsInstance(state, wrapper.AttentionWrapperState)
def test_custom_attention_layer():
    """A user-supplied attention layer determines the attention/output width.

    ``get_initial_state`` must raise ``ValueError`` until the mechanism's
    memory is set up. Afterwards, both the initial attention vector and the
    wrapper's output must have the Dense layer's width (``2 * units``).
    """
    data = DummyData2()
    width = data.units * 2
    mechanism = wrapper.LuongAttention(data.units)
    lstm = tf.keras.layers.LSTMCell(data.units)
    projection = tf.keras.layers.Dense(
        width, use_bias=False, activation=tf.math.tanh
    )
    attn_cell = wrapper.AttentionWrapper(lstm, mechanism, attention_layer=projection)

    # The mechanism has no memory yet, so building the initial state must fail.
    with pytest.raises(ValueError):
        attn_cell.get_initial_state(batch_size=data.batch, dtype=tf.float32)

    mechanism.setup_memory(
        data.encoder_outputs.astype(np.float32),
        memory_sequence_length=data.encoder_sequence_length,
    )
    state = attn_cell.get_initial_state(batch_size=data.batch, dtype=tf.float32)
    assert state.attention.shape[-1] == width

    step_input = data.decoder_inputs[:, 0].astype(np.float32)
    output, _ = attn_cell(step_input, state)
    assert output.shape[-1] == width
Beispiel #6
0
def test_attention_wrapper_with_gru_cell():
    """One step of an AttentionWrapper around a GRUCell keeps the state structure.

    The state returned by a single call must have the same nested structure
    as the initial state produced by ``get_initial_state``.
    """
    luong = wrapper.LuongAttention(units=3)
    gru_wrapper = wrapper.AttentionWrapper(tf.keras.layers.GRUCell(3), luong)

    luong.setup_memory(tf.ones([2, 5, 3]))
    step_input = tf.ones([2, 3])

    start = gru_wrapper.get_initial_state(inputs=step_input)
    _, next_state = gru_wrapper(step_input, start)
    tf.nest.assert_same_structure(start, next_state)
Beispiel #7
0
def test_basic_decoder_with_attention_wrapper():
    """BasicDecoder must accept an AttentionWrapper whose mechanism has no memory.

    The mechanism is created without ``setup_memory``; constructing the
    decoder must still succeed.
    """
    units, vocab_size = 32, 1000
    luong = attention_wrapper.LuongAttention(units)
    lstm = tf.keras.layers.LSTMCell(units)
    attn_cell = attention_wrapper.AttentionWrapper(lstm, luong)

    projection = tf.keras.layers.Dense(vocab_size)
    training_sampler = sampler_py.TrainingSampler()
    basic_decoder.BasicDecoder(attn_cell, training_sampler, output_layer=projection)
Beispiel #8
0
def test_attention_state_with_keras_rnn():
    """AttentionWrapper must run inside ``tf.keras.layers.RNN``.

    Regression test for https://github.com/tensorflow/addons/issues/1095.
    The RNN layer is called once with its implicit initial state and once
    with an explicit state built by ``AttentionWrapper.get_initial_state``.
    """
    lstm = tf.keras.layers.LSTMCell(8)
    luong = wrapper.LuongAttention(units=8, memory=tf.ones((2, 4, 8)))
    attn_cell = wrapper.AttentionWrapper(cell=lstm, attention_mechanism=luong)

    rnn = tf.keras.layers.RNN(attn_cell)
    _ = rnn(inputs=tf.ones((2, 4, 8)))

    # An explicitly supplied initial state must be accepted as well.
    explicit_state = attn_cell.get_initial_state(batch_size=2, dtype=tf.float32)
    _ = rnn(inputs=tf.ones((2, 4, 8)), initial_state=explicit_state)
Beispiel #9
0
    def test_attention_state_with_variable_length_input(self):
        """Run an AttentionWrapper in a Keras RNN over a randomly sized input.

        One all-ones tensor of shape (n, n, 3), with ``n`` drawn by
        ``tf.random.uniform`` from [2, 10), is used both as the attention
        memory and as the RNN input. The test passes if the call does not
        raise.
        """
        lstm = tf.keras.layers.LSTMCell(3)
        luong = wrapper.LuongAttention(units=3)
        attn_cell = wrapper.AttentionWrapper(lstm, luong)

        n = tf.random.uniform(shape=(), minval=2, maxval=10, dtype=tf.int32)
        ones = tf.ones(shape=(n, n, 3))

        luong.setup_memory(ones)
        _ = tf.keras.layers.RNN(attn_cell)(ones)
Beispiel #10
0
 def testCustomAttentionLayer(self):
     """A user-supplied attention layer determines the attention/output width.

     ``get_initial_state`` must raise ``ValueError`` before the mechanism's
     memory is set up. Afterwards, the initial attention vector and the
     wrapper's output must both have width ``2 * self.units``.
     """
     width = self.units * 2
     mechanism = wrapper.LuongAttention(self.units)
     lstm = tf.keras.layers.LSTMCell(self.units)
     projection = tf.keras.layers.Dense(
         width, use_bias=False, activation=tf.math.tanh)
     attn_cell = wrapper.AttentionWrapper(
         lstm, mechanism, attention_layer=projection)

     # No memory has been set up yet, so building the initial state must fail.
     with self.assertRaises(ValueError):
         attn_cell.get_initial_state(batch_size=self.batch, dtype=tf.float32)

     mechanism.setup_memory(
         self.encoder_outputs.astype(np.float32),
         memory_sequence_length=self.encoder_sequence_length)
     state = attn_cell.get_initial_state(batch_size=self.batch, dtype=tf.float32)
     self.assertEqual(state.attention.shape[-1], width)

     step_input = self.decoder_inputs[:, 0].astype(np.float32)
     output, _ = attn_cell(step_input, state)
     self.assertEqual(output.shape[-1], width)
Beispiel #11
0
    def _testWithMaybeMultiAttention(self,
                                     is_multi,
                                     create_attention_mechanisms,
                                     expected_final_output,
                                     expected_final_state,
                                     attention_mechanism_depths,
                                     alignment_history=False,
                                     expected_final_alignment_history=None,
                                     attention_layer_sizes=None,
                                     attention_layers=None,
                                     create_query_layer=False,
                                     create_memory_layer=True,
                                     create_attention_kwargs=None):
        """Decode through an AttentionWrapper and compare with expected results.

        Builds random encoder outputs and decoder inputs. Creates one attention
        mechanism per entry of ``create_attention_mechanisms`` and wraps an
        LSTMCell (with all-ones initializers) in an ``AttentionWrapper``. Runs a
        ``BasicDecoder`` with a ``TrainingSampler``, then checks output/state
        shapes. Finally, results are reduced with ``get_result_summary`` and
        compared against the expected summaries.

        Args:
          is_multi: pass the mechanisms (and per-mechanism layer settings) to
            the wrapper as lists. If False, only the first entry is used and
            exactly one mechanism creator must be given.
          create_attention_mechanisms: callables accepting ``units``,
            ``memory``, ``memory_sequence_length`` and the extra kwargs.
          expected_final_output: expected summary of the decoder outputs.
          expected_final_state: expected summary of the final state, with
            ``alignment_history`` removed. In eager mode its ``time`` field
            also gives the expected output time dimension.
          attention_mechanism_depths: ``units`` for each mechanism (zipped
            with ``create_attention_mechanisms``).
          alignment_history: whether the wrapper records alignment history.
          expected_final_alignment_history: expected summary of the stacked
            alignment history; only checked when ``alignment_history``.
          attention_layer_sizes: per-mechanism ``attention_layer_size``. A
            falsy entry counts as ``encoder_output_depth`` for the expected
            attention depth.
          attention_layers: per-mechanism explicit attention layers; only
            used for the expected depth when ``attention_layer_sizes`` is None.
          create_query_layer: pass an all-ones ``Dense`` as ``query_layer``.
          create_memory_layer: pass an all-ones ``Dense`` as ``memory_layer``.
          create_attention_kwargs: extra kwargs for each mechanism creator.
            The dict is copied and never mutated.
        """
        # Allow is_multi to be True with a single mechanism to enable test for
        # passing in a single mechanism in a list.
        assert len(create_attention_mechanisms) == 1 or is_multi
        encoder_sequence_length = [3, 2, 3, 1, 1]
        decoder_sequence_length = [2, 0, 1, 2, 3]
        batch_size = 5
        encoder_max_time = 8
        decoder_max_time = 4
        input_depth = 7
        encoder_output_depth = 10
        cell_depth = 9
        # Copy so the query/memory layers injected below never leak into a
        # dict owned (and possibly reused) by the caller.
        create_attention_kwargs = dict(create_attention_kwargs or {})

        if attention_layer_sizes is not None:
            # Compute sum of attention_layer_sizes. Use encoder_output_depth if
            # None.
            attention_depth = sum(
                attention_layer_size or encoder_output_depth
                for attention_layer_size in attention_layer_sizes)
        elif attention_layers is not None:
            # Sum of each attention layer's output depth, computed for an
            # input of width cell_depth + encoder_output_depth.
            attention_depth = sum(
                attention_layer.compute_output_shape(
                    [batch_size, cell_depth +
                     encoder_output_depth]).dims[-1].value
                for attention_layer in attention_layers)
        else:
            attention_depth = encoder_output_depth * len(
                create_attention_mechanisms)

        decoder_inputs = np.random.randn(batch_size, decoder_max_time,
                                         input_depth).astype(np.float32)
        encoder_outputs = np.random.randn(batch_size, encoder_max_time,
                                          encoder_output_depth).astype(
                                              np.float32)

        attention_mechanisms = []
        for creator, depth in zip(create_attention_mechanisms,
                                  attention_mechanism_depths):
            # Create a memory layer with deterministic initializer to avoid
            # randomness in the test between graph and eager.
            if create_query_layer:
                create_attention_kwargs["query_layer"] = keras.layers.Dense(
                    depth, kernel_initializer="ones", use_bias=False)
            if create_memory_layer:
                create_attention_kwargs["memory_layer"] = keras.layers.Dense(
                    depth, kernel_initializer="ones", use_bias=False)

            attention_mechanisms.append(
                creator(units=depth,
                        memory=encoder_outputs,
                        memory_sequence_length=encoder_sequence_length,
                        **create_attention_kwargs))

        with self.cached_session(use_gpu=True):
            attention_layer_size = attention_layer_sizes
            attention_layer = attention_layers
            if not is_multi:
                # Single-mechanism mode: unwrap the per-mechanism settings.
                if attention_layer_size is not None:
                    attention_layer_size = attention_layer_size[0]
                if attention_layer is not None:
                    attention_layer = attention_layer[0]
            cell = keras.layers.LSTMCell(cell_depth,
                                         recurrent_activation="sigmoid",
                                         kernel_initializer="ones",
                                         recurrent_initializer="ones")
            cell = wrapper.AttentionWrapper(
                cell,
                attention_mechanisms if is_multi else attention_mechanisms[0],
                attention_layer_size=attention_layer_size,
                alignment_history=alignment_history,
                attention_layer=attention_layer)
            # Seed the wrapper's attention projections for reproducibility.
            if cell._attention_layers is not None:
                for layer in cell._attention_layers:
                    layer.kernel_initializer = initializers.glorot_uniform(
                        seed=1337)

            sampler = sampler_py.TrainingSampler()
            my_decoder = basic_decoder.BasicDecoder(cell=cell, sampler=sampler)
            initial_state = cell.get_initial_state(dtype=tf.float32,
                                                   batch_size=batch_size)
            final_outputs, final_state, _ = my_decoder(
                decoder_inputs,
                initial_state=initial_state,
                sequence_length=decoder_sequence_length)

            self.assertIsInstance(final_outputs,
                                  basic_decoder.BasicDecoderOutput)
            self.assertIsInstance(final_state, wrapper.AttentionWrapperState)

            # The static time dimension is only known when executing eagerly.
            expected_time = (expected_final_state.time
                             if tf.executing_eagerly() else None)
            self.assertEqual(
                (batch_size, expected_time, attention_depth),
                tuple(final_outputs.rnn_output.get_shape().as_list()))
            self.assertEqual(
                (batch_size, expected_time),
                tuple(final_outputs.sample_id.get_shape().as_list()))

            self.assertEqual(
                (batch_size, attention_depth),
                tuple(final_state.attention.get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state[0].get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state[1].get_shape().as_list()))

            if alignment_history:
                if is_multi:
                    state_alignment_history = []
                    for history_array in final_state.alignment_history:
                        history = history_array.stack()
                        self.assertEqual(
                            (expected_time, batch_size, encoder_max_time),
                            tuple(history.get_shape().as_list()))
                        state_alignment_history.append(history)
                    state_alignment_history = tuple(state_alignment_history)
                else:
                    state_alignment_history = \
                        final_state.alignment_history.stack()
                    self.assertEqual(
                        (expected_time, batch_size, encoder_max_time),
                        tuple(state_alignment_history.get_shape().as_list()))
                tf.nest.assert_same_structure(
                    cell.state_size,
                    cell.get_initial_state(batch_size=batch_size,
                                           dtype=tf.float32))
                # Remove the history from final_state for purposes of the
                # remainder of the tests.
                final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
            else:
                state_alignment_history = ()

            self.evaluate(tf.compat.v1.global_variables_initializer())
            eval_result = self.evaluate({
                "final_outputs":
                final_outputs,
                "final_state":
                final_state,
                "state_alignment_history":
                state_alignment_history,
            })

            final_output_info = tf.nest.map_structure(
                get_result_summary, eval_result["final_outputs"])
            final_state_info = tf.nest.map_structure(
                get_result_summary, eval_result["final_state"])
            print("final_output_info: ", final_output_info)
            print("final_state_info: ", final_state_info)

            tf.nest.map_structure(self.assertAllCloseOrEqual,
                                  expected_final_output, final_output_info)
            tf.nest.map_structure(self.assertAllCloseOrEqual,
                                  expected_final_state, final_state_info)
            # by default, the wrapper emits attention as output
            if alignment_history:
                final_alignment_history_info = tf.nest.map_structure(
                    get_result_summary, eval_result["state_alignment_history"])
                print("final_alignment_history_info: ",
                      final_alignment_history_info)
                tf.nest.map_structure(
                    self.assertAllCloseOrEqual,
                    # outputs are batch major but the stacked TensorArray is
                    # time major
                    expected_final_alignment_history,
                    final_alignment_history_info)
Beispiel #12
0
def _test_with_attention(
    create_attention_mechanism,
    expected_final_output,
    expected_final_state,
    attention_mechanism_depth=3,
    alignment_history=False,
    expected_final_alignment_history=None,
    attention_layer_size=6,
    attention_layer=None,
    create_query_layer=False,
    create_memory_layer=True,
    create_attention_kwargs=None,
):
    """Decode through a single-mechanism AttentionWrapper and check results.

    Builds random encoder outputs and decoder inputs. Wraps an LSTMCell (with
    all-ones initializers) in an ``AttentionWrapper`` around the mechanism
    made by ``create_attention_mechanism``, then runs a ``BasicDecoder`` with a
    ``TrainingSampler`` and asserts output/state shapes. Finally, numeric
    results are reduced with ``get_result_summary`` and compared with
    ``assert_allclose_or_equal``.

    Args:
      create_attention_mechanism: callable accepting ``units``, ``memory``,
        ``memory_sequence_length`` and the extra kwargs; returns a mechanism.
      expected_final_output: expected summary of the decoder outputs.
      expected_final_state: expected summary of the final state, with
        ``alignment_history`` removed.
      attention_mechanism_depth: ``units`` for the mechanism and the output
        size of the optional query/memory layers.
      alignment_history: whether the wrapper records alignment history.
      expected_final_alignment_history: expected summary of the stacked
        alignment history; only checked when ``alignment_history`` is True.
      attention_layer_size: ``attention_layer_size`` for the wrapper. When
        not None it also gives the expected attention depth, with a falsy
        value falling back to the encoder output depth.
      attention_layer: optional explicit attention layer for the wrapper;
        only used for the expected depth when ``attention_layer_size`` is None.
      create_query_layer: pass an all-ones ``Dense`` as ``query_layer``.
      create_memory_layer: pass an all-ones ``Dense`` as ``memory_layer``.
      create_attention_kwargs: extra kwargs for the mechanism. The dict is
        copied and never mutated.
    """
    encoder_sequence_length = [3, 2, 3, 1, 1]
    decoder_sequence_length = [2, 0, 1, 2, 3]
    batch_size = 5
    encoder_max_time = 8
    decoder_max_time = 4
    input_depth = 7
    encoder_output_depth = 10
    cell_depth = 9
    # Copy so the query/memory layers injected below never leak into a dict
    # owned (and possibly reused) by the caller.
    create_attention_kwargs = dict(create_attention_kwargs or {})

    if attention_layer_size is not None:
        # A falsy size falls back to the encoder output depth.
        attention_depth = attention_layer_size or encoder_output_depth
    elif attention_layer is not None:
        # Output depth of the explicit attention layer, computed for an input
        # of width cell_depth + encoder_output_depth.
        attention_depth = attention_layer.compute_output_shape(
            [batch_size, cell_depth + encoder_output_depth]).dims[-1].value
    else:
        attention_depth = encoder_output_depth

    decoder_inputs = np.random.randn(batch_size, decoder_max_time,
                                     input_depth).astype(np.float32)
    encoder_outputs = np.random.randn(batch_size, encoder_max_time,
                                      encoder_output_depth).astype(np.float32)

    # Create query/memory layers with a deterministic initializer to avoid
    # randomness in the test between graph and eager.
    if create_query_layer:
        create_attention_kwargs["query_layer"] = tf.keras.layers.Dense(
            attention_mechanism_depth, kernel_initializer="ones", use_bias=False)
    if create_memory_layer:
        create_attention_kwargs["memory_layer"] = tf.keras.layers.Dense(
            attention_mechanism_depth, kernel_initializer="ones", use_bias=False)

    attention_mechanism = create_attention_mechanism(
        units=attention_mechanism_depth,
        memory=encoder_outputs,
        memory_sequence_length=encoder_sequence_length,
        **create_attention_kwargs,
    )

    cell = tf.keras.layers.LSTMCell(
        cell_depth,
        recurrent_activation="sigmoid",
        kernel_initializer="ones",
        recurrent_initializer="ones",
    )
    cell = wrapper.AttentionWrapper(
        cell,
        attention_mechanism,
        attention_layer_size=attention_layer_size,
        alignment_history=alignment_history,
        attention_layer=attention_layer,
    )
    # Seed the wrapper's attention projections for reproducibility.
    if cell._attention_layers is not None:
        for layer in cell._attention_layers:
            layer.kernel_initializer = tf.compat.v1.keras.initializers.glorot_uniform(
                seed=1337)

    sampler = sampler_py.TrainingSampler()
    my_decoder = basic_decoder.BasicDecoder(cell=cell, sampler=sampler)
    initial_state = cell.get_initial_state(dtype=tf.float32,
                                           batch_size=batch_size)
    final_outputs, final_state, _ = my_decoder(
        decoder_inputs,
        initial_state=initial_state,
        sequence_length=decoder_sequence_length,
    )

    assert isinstance(final_outputs, basic_decoder.BasicDecoderOutput)
    assert isinstance(final_state, wrapper.AttentionWrapperState)

    expected_time = max(decoder_sequence_length)
    assert (batch_size, expected_time, attention_depth) == tuple(
        final_outputs.rnn_output.get_shape().as_list())
    assert (batch_size, expected_time) == tuple(
        final_outputs.sample_id.get_shape().as_list())

    assert (batch_size, attention_depth) == tuple(
        final_state.attention.get_shape().as_list())
    assert (batch_size, cell_depth) == tuple(
        final_state.cell_state[0].get_shape().as_list())
    assert (batch_size, cell_depth) == tuple(
        final_state.cell_state[1].get_shape().as_list())

    if alignment_history:
        state_alignment_history = final_state.alignment_history.stack()
        assert (expected_time, batch_size, encoder_max_time) == tuple(
            state_alignment_history.get_shape().as_list())
        tf.nest.assert_same_structure(
            cell.state_size,
            cell.get_initial_state(batch_size=batch_size, dtype=tf.float32),
        )
        # Remove the history from final_state for purposes of the
        # remainder of the tests.
        final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
    else:
        state_alignment_history = ()

    final_outputs = tf.nest.map_structure(np.array, final_outputs)
    final_state = tf.nest.map_structure(np.array, final_state)
    state_alignment_history = tf.nest.map_structure(np.array,
                                                    state_alignment_history)
    final_output_info = tf.nest.map_structure(get_result_summary,
                                              final_outputs)

    final_state_info = tf.nest.map_structure(get_result_summary, final_state)

    tf.nest.map_structure(assert_allclose_or_equal, expected_final_output,
                          final_output_info)
    tf.nest.map_structure(assert_allclose_or_equal, expected_final_state,
                          final_state_info)
    # by default, the wrapper emits attention as output
    if alignment_history:
        final_alignment_history_info = tf.nest.map_structure(
            get_result_summary, state_alignment_history)
        tf.nest.map_structure(
            assert_allclose_or_equal,
            # outputs are batch major but the stacked TensorArray is
            # time major
            expected_final_alignment_history,
            final_alignment_history_info,
        )
Beispiel #13
0
    def _testDynamicDecodeRNN(self,
                              time_major,
                              has_attention,
                              with_alignment_history=False):
        """Smoke-test BeamSearchDecoder shapes, optionally with attention.

        Decodes for at most ``max(decoder_sequence_length)`` steps. When
        ``has_attention`` is set, the LSTMCell is wrapped in an
        AttentionWrapper over Bahdanau attention whose memory (and the cell's
        initial state) is tiled by ``beam_width``.

        Args:
          time_major: request time-major outputs from the decoder.
          has_attention: wrap the cell in an AttentionWrapper and enable the
            coverage penalty.
          with_alignment_history: record alignment history in the wrapper
            (only meaningful with ``has_attention``).
        """
        encoder_sequence_length = np.array([3, 2, 3, 1, 1])
        decoder_sequence_length = np.array([2, 0, 1, 2, 3])
        batch_size = 5
        decoder_max_time = 4
        input_depth = 7
        cell_depth = 9
        attention_depth = 6
        vocab_size = 20
        end_token = vocab_size - 1
        start_token = 0
        embedding_dim = 50
        max_out = max(decoder_sequence_length)
        output_layer = tf.keras.layers.Dense(vocab_size,
                                             use_bias=True,
                                             activation=None)
        beam_width = 3

        with self.cached_session():
            batch_size_tensor = tf.constant(batch_size)
            embedding = np.random.randn(vocab_size,
                                        embedding_dim).astype(np.float32)
            cell = tf.keras.layers.LSTMCell(cell_depth)
            initial_state = cell.get_initial_state(batch_size=batch_size,
                                                   dtype=tf.float32)
            coverage_penalty_weight = 0.0
            if has_attention:
                coverage_penalty_weight = 0.2
                inputs = tf.compat.v1.placeholder_with_default(
                    np.random.randn(batch_size, decoder_max_time,
                                    input_depth).astype(np.float32),
                    shape=(None, None, input_depth))
                # Memory, lengths and cell state are tiled so that every beam
                # has its own copy.
                tiled_inputs = beam_search_decoder.tile_batch(
                    inputs, multiplier=beam_width)
                tiled_sequence_length = beam_search_decoder.tile_batch(
                    encoder_sequence_length, multiplier=beam_width)
                attention_mechanism = attention_wrapper.BahdanauAttention(
                    units=attention_depth,
                    memory=tiled_inputs,
                    memory_sequence_length=tiled_sequence_length)
                initial_state = beam_search_decoder.tile_batch(
                    initial_state, multiplier=beam_width)
                cell = attention_wrapper.AttentionWrapper(
                    cell=cell,
                    attention_mechanism=attention_mechanism,
                    attention_layer_size=attention_depth,
                    alignment_history=with_alignment_history)
            cell_state = cell.get_initial_state(batch_size=batch_size_tensor *
                                                beam_width,
                                                dtype=tf.float32)
            if has_attention:
                # Replace the wrapped cell's state with the tiled one.
                cell_state = cell_state.clone(cell_state=initial_state)
            bsd = beam_search_decoder.BeamSearchDecoder(
                cell=cell,
                beam_width=beam_width,
                output_layer=output_layer,
                length_penalty_weight=0.0,
                coverage_penalty_weight=coverage_penalty_weight,
                output_time_major=time_major,
                maximum_iterations=max_out)

            final_outputs, final_state, final_sequence_lengths = bsd(
                embedding,
                start_tokens=tf.fill([batch_size_tensor], start_token),
                end_token=end_token,
                initial_state=cell_state)

            def _t(shape):
                # Swap batch and time axes for time-major outputs.
                if time_major:
                    return (shape[1], shape[0]) + shape[2:]
                return shape

            self.assertIsInstance(
                final_outputs,
                beam_search_decoder.FinalBeamSearchDecoderOutput)
            self.assertIsInstance(final_state,
                                  beam_search_decoder.BeamSearchDecoderState)

            beam_search_decoder_output = \
                final_outputs.beam_search_decoder_output
            # In eager mode the decoder is expected to run the full
            # ``maximum_iterations`` (== max_out) steps; in graph mode the
            # static time dimension is unknown.
            expected_seq_length = max_out if tf.executing_eagerly() else None
            self.assertEqual(
                _t((batch_size, expected_seq_length, beam_width)),
                tuple(beam_search_decoder_output.scores.get_shape().as_list()))
            self.assertEqual(
                _t((batch_size, expected_seq_length, beam_width)),
                tuple(final_outputs.predicted_ids.get_shape().as_list()))

            self.evaluate(tf.compat.v1.global_variables_initializer())
            eval_results = self.evaluate({
                'final_outputs':
                final_outputs,
                'final_sequence_lengths':
                final_sequence_lengths
            })

            max_sequence_length = np.max(
                eval_results['final_sequence_lengths'])

            # A smoke test
            self.assertEqual(
                _t((batch_size, max_sequence_length, beam_width)),
                eval_results['final_outputs'].beam_search_decoder_output.
                scores.shape)
            self.assertEqual(
                _t((batch_size, max_sequence_length, beam_width)),
                eval_results['final_outputs'].beam_search_decoder_output.
                predicted_ids.shape)
def test_beam_search_decoder(cell_class, time_major, has_attention,
                             with_alignment_history):
    """Smoke-test BeamSearchDecoder run inside a ``tf.function``.

    The decoder wraps ``cell_class``. When ``has_attention`` is set, that cell
    is wrapped in an AttentionWrapper over Bahdanau attention whose memory is
    set inside the traced function from beam-tiled inputs. Decoding runs for
    at most three steps. Only the result types and the shapes of the scores
    and predicted ids are checked.
    """
    lengths = np.array([3, 2, 3, 1, 1])
    batch_size = 5
    memory_time = 4
    input_depth = 7
    cell_depth = 9
    attention_depth = 6
    vocab_size = 20
    end_token = vocab_size - 1
    start_token = 0
    embedding_dim = 50
    maximum_iterations = 3
    beam_width = 3

    projection = tf.keras.layers.Dense(vocab_size,
                                       use_bias=True,
                                       activation=None)
    embedding = tf.random.normal([vocab_size, embedding_dim])
    decoder_cell = cell_class(cell_depth)

    coverage_penalty_weight = 0.2 if has_attention else 0.0
    if has_attention:
        bahdanau = attention_wrapper.BahdanauAttention(units=attention_depth)
        decoder_cell = attention_wrapper.AttentionWrapper(
            cell=decoder_cell,
            attention_mechanism=bahdanau,
            attention_layer_size=attention_depth,
            alignment_history=with_alignment_history,
        )

    beam_decoder = beam_search_decoder.BeamSearchDecoder(
        cell=decoder_cell,
        beam_width=beam_width,
        output_layer=projection,
        length_penalty_weight=0.0,
        coverage_penalty_weight=coverage_penalty_weight,
        output_time_major=time_major,
        maximum_iterations=maximum_iterations,
    )

    signature = (
        tf.TensorSpec([None, None, input_depth], dtype=tf.float32),
        tf.TensorSpec([None], dtype=tf.int32),
    )

    @tf.function(input_signature=signature)
    def _beam_decode_from(memory, memory_sequence_length):
        n = tf.shape(memory)[0]
        if has_attention:
            # Each beam needs its own copy of the memory and its lengths.
            bahdanau.setup_memory(
                beam_search_decoder.tile_batch(memory, multiplier=beam_width),
                memory_sequence_length=beam_search_decoder.tile_batch(
                    memory_sequence_length, multiplier=beam_width),
            )
        start_state = decoder_cell.get_initial_state(batch_size=n * beam_width,
                                                     dtype=tf.float32)
        return beam_decoder(
            embedding,
            start_tokens=tf.fill([n], start_token),
            end_token=end_token,
            initial_state=start_state,
        )

    memory = tf.random.normal([batch_size, memory_time, input_depth])
    memory_lengths = tf.constant(lengths, dtype=tf.int32)
    final_outputs, final_state, final_sequence_lengths = _beam_decode_from(
        memory, memory_lengths)

    def _t(shape):
        # Swap batch and time axes for time-major outputs.
        return (shape[1], shape[0]) + shape[2:] if time_major else shape

    assert isinstance(final_outputs,
                      beam_search_decoder.FinalBeamSearchDecoderOutput)
    assert isinstance(final_state, beam_search_decoder.BeamSearchDecoderState)

    longest = np.max(final_sequence_lengths.numpy())
    expected_shape = _t((batch_size, longest, beam_width))
    scores = final_outputs.beam_search_decoder_output.scores
    assert expected_shape == tuple(scores.shape.as_list())
    assert expected_shape == tuple(final_outputs.predicted_ids.shape.as_list())
Beispiel #15
0
def test_dynamic_decode_rnn(cell_class, time_major, has_attention,
                            with_alignment_history):
    """Smoke-test `BeamSearchDecoder` end to end, optionally with attention.

    Builds a `cell_class` cell, wraps it in a Bahdanau `AttentionWrapper` when
    `has_attention` is true, and decodes with beam search from a random
    embedding matrix. It then checks the types and shapes of the outputs.

    Args:
      cell_class: Callable taking a unit count and returning an RNN cell
        (presumably a Keras cell class such as `tf.keras.layers.LSTMCell`).
      time_major: Forwarded as `output_time_major`; when true, the expected
        output shapes have their first two dimensions swapped.
      has_attention: Whether to wrap the cell with attention over beam-tiled
        random memory and enable the coverage penalty.
      with_alignment_history: Forwarded to `AttentionWrapper` as
        `alignment_history`; only used when `has_attention` is true.
    """
    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
    batch_size = 5
    decoder_max_time = 4
    input_depth = 7
    cell_depth = 9
    attention_depth = 6
    vocab_size = 20
    end_token = vocab_size - 1
    start_token = 0
    embedding_dim = 50
    # Used as the decoder's `maximum_iterations` below.
    max_out = max(decoder_sequence_length)
    output_layer = tf.keras.layers.Dense(vocab_size,
                                         use_bias=True,
                                         activation=None)
    beam_width = 3

    batch_size_tensor = tf.constant(batch_size)
    embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
    cell = cell_class(cell_depth)
    initial_state = cell.get_initial_state(batch_size=batch_size,
                                           dtype=tf.float32)
    coverage_penalty_weight = 0.0
    if has_attention:
        # The coverage penalty is only enabled together with attention
        # (presumably because it is computed from attention alignments).
        coverage_penalty_weight = 0.2
        inputs = tf.compat.v1.placeholder_with_default(
            np.random.randn(batch_size, decoder_max_time,
                            input_depth).astype(np.float32),
            shape=(None, None, input_depth),
        )
        # Replicate memory, memory lengths and the inner cell state
        # `beam_width` times so their batch dimension matches the
        # `batch_size * beam_width` decoder state created below.
        tiled_inputs = beam_search_decoder.tile_batch(inputs,
                                                      multiplier=beam_width)
        tiled_sequence_length = beam_search_decoder.tile_batch(
            encoder_sequence_length, multiplier=beam_width)
        attention_mechanism = attention_wrapper.BahdanauAttention(
            units=attention_depth,
            memory=tiled_inputs,
            memory_sequence_length=tiled_sequence_length,
        )
        initial_state = beam_search_decoder.tile_batch(initial_state,
                                                       multiplier=beam_width)
        cell = attention_wrapper.AttentionWrapper(
            cell=cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=attention_depth,
            alignment_history=with_alignment_history,
        )
    cell_state = cell.get_initial_state(batch_size=batch_size_tensor *
                                        beam_width,
                                        dtype=tf.float32)
    if has_attention:
        # Replace the wrapper's freshly-initialised inner cell state with the
        # beam-tiled state computed from the unwrapped cell above.
        cell_state = cell_state.clone(cell_state=initial_state)
    bsd = beam_search_decoder.BeamSearchDecoder(
        cell=cell,
        beam_width=beam_width,
        output_layer=output_layer,
        length_penalty_weight=0.0,
        coverage_penalty_weight=coverage_penalty_weight,
        output_time_major=time_major,
        maximum_iterations=max_out,
    )

    final_outputs, final_state, final_sequence_lengths = bsd(
        embedding,
        start_tokens=tf.fill([batch_size_tensor], start_token),
        end_token=end_token,
        initial_state=cell_state,
    )

    def _t(shape):
        """Swap batch and time axes of `shape` when outputs are time-major."""
        if time_major:
            return (shape[1], shape[0]) + shape[2:]
        return shape

    assert isinstance(final_outputs,
                      beam_search_decoder.FinalBeamSearchDecoderOutput)
    assert isinstance(final_state, beam_search_decoder.BeamSearchDecoderState)

    beam_search_decoder_output = final_outputs.beam_search_decoder_output
    # The time dimension is only statically known when executing eagerly.
    # The decoder is then expected to run for all `max_out` steps.
    expected_seq_length = int(max_out) if tf.executing_eagerly() else None
    assert _t((batch_size, expected_seq_length, beam_width)) == tuple(
        beam_search_decoder_output.scores.shape.as_list())
    assert _t(
        (batch_size, expected_seq_length,
         beam_width)) == tuple(final_outputs.predicted_ids.shape.as_list())

    # A smoke test: the output time dimension equals the longest decoded
    # sequence length reported by the decoder.
    max_sequence_length = np.max(final_sequence_lengths.numpy())
    assert (_t((batch_size, max_sequence_length, beam_width)) ==
            beam_search_decoder_output.scores.shape)
    assert (_t((batch_size, max_sequence_length, beam_width)) ==
            beam_search_decoder_output.predicted_ids.shape)