Beispiel #1
0
    def testLuongMonotonicScaled(self):
        """LuongMonotonicAttention with ``scale=True``, recording alignment history.

        Builds expected (shape, dtype, mean) summaries for the decoder output,
        the final wrapper state and the stacked alignment history, and passes
        them to ``self._testWithAttention`` together with the mechanism class.
        """
        f32 = np.dtype("float32")
        i32 = np.dtype("int32")

        def summary(shape, mean, dtype=f32):
            return ResultSummary(shape=shape, dtype=dtype, mean=mean)

        output = basic_decoder.BasicDecoderOutput(
            rnn_output=summary((5, 3, 6), 0.003664831),
            sample_id=summary((5, 3), 3.06666666, dtype=i32))
        state = wrapper.AttentionWrapperState(
            cell_state=[summary((5, 9), 0.54318606),
                        summary((5, 9), 1.12592840)],
            attention=summary((5, 6), 0.059128221),
            time=3,
            alignments=summary((5, 8), 0.05112994),
            attention_state=summary((5, 8), 0.05112994),
            alignment_history=())
        history = summary((3, 5, 8), 0.06994973868)

        self._testWithAttention(
            wrapper.LuongMonotonicAttention,
            output,
            state,
            attention_mechanism_depth=9,
            alignment_history=True,
            expected_final_alignment_history=history,
            create_attention_kwargs={"scale": True})
Beispiel #2
0
    def testLuongNotNormalized(self):
        """LuongAttention without normalization, using an attention depth of 9.

        All float summaries are float32; alignments and attention_state are
        both expected to have mean 0.125.
        """
        float_type = np.dtype("float32")
        int_type = np.dtype("int32")

        rnn_output = ResultSummary(shape=(5, 3, 6), dtype=float_type, mean=-0.06124732)
        sample_id = ResultSummary(shape=(5, 3), dtype=int_type, mean=2.73333333)
        first_cell = ResultSummary(shape=(5, 9), dtype=float_type, mean=0.52021580)
        second_cell = ResultSummary(shape=(5, 9), dtype=float_type, mean=1.0964939)
        attention = ResultSummary(shape=(5, 6), dtype=float_type, mean=-0.0318060)
        alignments = ResultSummary(shape=(5, 8), dtype=float_type, mean=0.125)
        attention_state = ResultSummary(shape=(5, 8), dtype=float_type, mean=0.125)

        output = basic_decoder.BasicDecoderOutput(rnn_output=rnn_output,
                                                  sample_id=sample_id)
        state = wrapper.AttentionWrapperState(
            cell_state=[first_cell, second_cell],
            attention=attention,
            time=3,
            alignments=alignments,
            attention_state=attention_state,
            alignment_history=())

        self._testWithAttention(wrapper.LuongAttention,
                                output,
                                state,
                                attention_mechanism_depth=9)
Beispiel #3
0
    def testBahdanauNormalized(self):
        """BahdanauAttention with ``normalize=True`` and an all-ones kernel initializer.

        A query layer is requested (``create_query_layer=True``); the expected
        state omits ``time``, matching the original expectations.
        """
        f32, i32 = np.dtype("float32"), np.dtype("int32")

        def summary(shape, mean, dtype=f32):
            return ResultSummary(shape=shape, dtype=dtype, mean=mean)

        output = basic_decoder.BasicDecoderOutput(
            rnn_output=summary((5, 3, 6), -0.008089137),
            sample_id=summary((5, 3), 2.8, dtype=i32),
        )
        state = wrapper.AttentionWrapperState(
            cell_state=[summary((5, 9), 0.49166861), summary((5, 9), 1.01068615)],
            attention=summary((5, 6), 0.042427111),
            alignments=summary((5, 8), 0.125),
            attention_state=summary((5, 8), 0.125),
            alignment_history=(),
        )

        options = {
            "create_query_layer": True,
            "create_attention_kwargs": {"kernel_initializer": "ones", "normalize": True},
        }
        self._testWithAttention(wrapper.BahdanauAttention, output, state, **options)
Beispiel #4
0
    def testNotUseAttentionLayer(self):
        """BahdanauAttention with ``attention_layer_size=None``.

        Without an attention layer the expected rnn_output and attention
        summaries have a last dimension of 10 rather than 6.
        """
        float_type = np.dtype("float32")

        expected_output = basic_decoder.BasicDecoderOutput(
            rnn_output=ResultSummary(shape=(5, 3, 10), dtype=float_type, mean=0.078317143),
            sample_id=ResultSummary(shape=(5, 3), dtype=np.dtype("int32"), mean=4.2),
        )
        cells = [
            ResultSummary(shape=(5, 9), dtype=float_type, mean=mean)
            for mean in (0.89382392, 1.722382)
        ]
        expected_state = wrapper.AttentionWrapperState(
            cell_state=cells,
            attention=ResultSummary(shape=(5, 10), dtype=float_type, mean=0.026356646),
            alignments=ResultSummary(shape=(5, 8), dtype=float_type, mean=0.125),
            attention_state=ResultSummary(shape=(5, 8), dtype=float_type, mean=0.125),
            alignment_history=(),
        )

        self._testWithAttention(
            wrapper.BahdanauAttention,
            expected_output,
            expected_state,
            attention_layer_size=None,
            create_query_layer=True,
            create_attention_kwargs={"kernel_initializer": "ones"},
        )
def test_luong_not_normalized():
    """LuongAttention without normalization, checked under the global precision policy.

    Float-valued summaries take the active policy's ``compute_dtype``;
    ``sample_id`` is always expected as int32. Random state is seeded first.
    """
    set_random_state_for_tf_and_np()
    compute_dtype = tf.keras.mixed_precision.experimental.global_policy().compute_dtype

    def summary(shape, mean, dtype=compute_dtype):
        return ResultSummary(shape=shape, dtype=dtype, mean=mean)

    expected_output = basic_decoder.BasicDecoderOutput(
        rnn_output=summary((5, 3, 6), -0.06124732),
        sample_id=summary((5, 3), 2.73333333, dtype=np.dtype("int32")),
    )
    expected_state = wrapper.AttentionWrapperState(
        cell_state=[summary((5, 9), 0.52021580), summary((5, 9), 1.0964939)],
        attention=summary((5, 6), -0.0318060),
        alignments=summary((5, 8), 0.125),
        attention_state=summary((5, 8), 0.125),
        alignment_history=(),
    )

    _test_with_attention(
        wrapper.LuongAttention,
        expected_output,
        expected_state,
        attention_mechanism_depth=9,
    )
def test_luong_scaled():
    """LuongAttention with ``scale=True`` and an attention depth of 9.

    NOTE(review): these expected means equal those of the unscaled Luong
    test in this file; presumably the learned scale starts at 1.0 — confirm.
    """
    set_random_state_for_tf_and_np()
    f32 = np.dtype("float32")

    rnn_output = ResultSummary(shape=(5, 3, 6), dtype=f32, mean=-0.06124732)
    sample_id = ResultSummary(shape=(5, 3), dtype=np.dtype("int32"), mean=2.73333333)
    expected_output = basic_decoder.BasicDecoderOutput(
        rnn_output=rnn_output, sample_id=sample_id
    )

    expected_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(shape=(5, 9), dtype=f32, mean=0.52021580),
            ResultSummary(shape=(5, 9), dtype=f32, mean=1.0964939),
        ],
        attention=ResultSummary(shape=(5, 6), dtype=f32, mean=-0.0318060),
        alignments=ResultSummary(shape=(5, 8), dtype=f32, mean=0.125),
        attention_state=ResultSummary(shape=(5, 8), dtype=f32, mean=0.125),
        alignment_history=(),
    )

    call_options = dict(
        attention_mechanism_depth=9,
        create_attention_kwargs={"scale": True},
    )
    _test_with_attention(
        wrapper.LuongAttention, expected_output, expected_state, **call_options
    )
    def testLuongNotNormalized(self):
        """LuongAttention without normalization (attention depth 9).

        Expected summaries are float32 except ``sample_id`` (int32); the final
        state records ``time=3``.
        """
        f32 = np.dtype("float32")

        def summary(shape, mean, dtype=f32):
            return ResultSummary(shape=shape, dtype=dtype, mean=mean)

        output = basic_decoder.BasicDecoderOutput(
            rnn_output=summary((5, 3, 6), 0.05481226),
            sample_id=summary((5, 3), 3.13333333, dtype=np.dtype("int32")))
        state = wrapper.AttentionWrapperState(
            cell_state=[summary((5, 9), 0.38453412),
                        summary((5, 9), 0.5785929)],
            attention=summary((5, 6), 0.16311775),
            time=3,
            alignments=summary((5, 8), 0.125),
            attention_state=summary((5, 8), 0.125),
            alignment_history=())

        self._testWithAttention(wrapper.LuongAttention, output, state,
                                attention_mechanism_depth=9)
Beispiel #8
0
    def testBahdanauNotNormalized(self):
        """BahdanauAttention (no normalization) with ones-initialized kernels.

        Also records the alignment history and requests a query layer.
        """
        float_t = np.dtype(np.float32)
        int_t = np.dtype(np.int32)

        def summary(shape, mean, dtype=float_t):
            return ResultSummary(shape=shape, dtype=dtype, mean=mean)

        output = basic_decoder.BasicDecoderOutput(
            rnn_output=summary((5, 3, 6), -0.003204414),
            sample_id=summary((5, 3), 3.2, dtype=int_t))
        state = wrapper.AttentionWrapperState(
            cell_state=[summary((5, 9), 0.40868404),
                        summary((5, 9), 0.89017969)],
            attention=summary((5, 6), 0.041453815),
            time=3,
            alignments=summary((5, 8), 0.125),
            attention_state=summary((5, 8), 0.125),
            alignment_history=())
        history = summary((3, 5, 8), 0.125)

        self._testWithAttention(
            wrapper.BahdanauAttention,
            output,
            state,
            alignment_history=True,
            create_query_layer=True,
            expected_final_alignment_history=history,
            create_attention_kwargs={"kernel_initializer": "ones"})
Beispiel #9
0
    def testBahdanauMonotonicNotNormalized(self):
        """BahdanauMonotonicAttention (no normalization) with ones-initialized kernels.

        Checks the final output, final state and the recorded alignment
        history; a query layer is requested.
        """
        f32 = np.dtype("float32")
        align_mean = 0.09778417

        rnn_output = ResultSummary(shape=(5, 3, 6), dtype=f32, mean=-0.009921653)
        sample_id = ResultSummary(shape=(5, 3), dtype=np.dtype("int32"), mean=3.13333333)
        output = basic_decoder.BasicDecoderOutput(rnn_output=rnn_output,
                                                  sample_id=sample_id)

        state = wrapper.AttentionWrapperState(
            cell_state=[
                ResultSummary(shape=(5, 9), dtype=f32, mean=0.44612807),
                ResultSummary(shape=(5, 9), dtype=f32, mean=0.95786464),
            ],
            attention=ResultSummary(shape=(5, 6), dtype=f32, mean=0.038682378),
            time=3,
            alignments=ResultSummary(shape=(5, 8), dtype=f32, mean=align_mean),
            attention_state=ResultSummary(shape=(5, 8), dtype=f32, mean=align_mean),
            alignment_history=())
        history = ResultSummary(shape=(3, 5, 8), dtype=f32, mean=0.10261579603)

        options = {
            "alignment_history": True,
            "expected_final_alignment_history": history,
            "create_query_layer": True,
            "create_attention_kwargs": {"kernel_initializer": "ones"},
        }
        self._testWithAttention(wrapper.BahdanauMonotonicAttention, output,
                                state, **options)
Beispiel #10
0
    def testLuongMonotonicNotNormalized(self):
        """LuongMonotonicAttention without normalization (currently skipped).

        ``skipTest`` is raised first, so the expectations below only document
        the intended check until the linked issue is resolved.
        """
        self.skipTest(
            "Resolve https://github.com/tensorflow/addons/issues/781")
        f32 = np.dtype("float32")

        def summary(shape, mean, dtype=f32):
            return ResultSummary(shape=shape, dtype=dtype, mean=mean)

        output = basic_decoder.BasicDecoderOutput(
            rnn_output=summary((5, 3, 6), 0.003664831),
            sample_id=summary((5, 3), 3.06666666, dtype=np.dtype("int32")))
        state = wrapper.AttentionWrapperState(
            cell_state=[summary((5, 9), 0.54318606),
                        summary((5, 9), 1.12592840)],
            attention=summary((5, 6), 0.059128221),
            time=3,
            alignments=summary((5, 8), 0.05112994),
            attention_state=summary((5, 8), 0.05112994),
            alignment_history=())
        history = summary((3, 5, 8), 0.06994973868)

        self._testWithAttention(
            wrapper.LuongMonotonicAttention,
            output,
            state,
            attention_mechanism_depth=9,
            alignment_history=True,
            expected_final_alignment_history=history)
Beispiel #11
0
    def testBahdanauMonotonicNormalized(self):
        """BahdanauMonotonicAttention with ``normalize=True`` and ones-initialized kernels.

        Records the alignment history and requests a query layer.
        """
        float_type = np.dtype("float32")

        def summary(shape, mean, dtype=float_type):
            return ResultSummary(shape=shape, dtype=dtype, mean=mean)

        expected_output = basic_decoder.BasicDecoderOutput(
            rnn_output=summary((5, 3, 6), 0.007140680),
            sample_id=summary((5, 3), 3.26666666, dtype=np.dtype("int32")),
        )
        expected_state = wrapper.AttentionWrapperState(
            cell_state=[summary((5, 9), 0.47012400), summary((5, 9), 1.0249618)],
            attention=summary((5, 6), 0.068432882),
            time=3,
            alignments=summary((5, 8), 0.0615656),
            attention_state=summary((5, 8), 0.0615656),
            alignment_history=(),
        )
        expected_history = summary((3, 5, 8), 0.07909643)

        self._testWithAttention(
            wrapper.BahdanauMonotonicAttention,
            expected_output,
            expected_state,
            alignment_history=True,
            expected_final_alignment_history=expected_history,
            create_query_layer=True,
            create_attention_kwargs={"kernel_initializer": "ones", "normalize": True},
        )
Beispiel #12
0
def test_bahdanau_not_normalized():
    """BahdanauAttention (no normalization) with ones-initialized kernels.

    Seeds random state, then checks output, final state and alignment
    history with a query layer enabled.
    """
    set_random_state_for_tf_and_np()
    f32 = np.dtype(np.float32)

    def summary(shape, mean, dtype=f32):
        return ResultSummary(shape=shape, dtype=dtype, mean=mean)

    expected_output = basic_decoder.BasicDecoderOutput(
        rnn_output=summary((5, 3, 6), -0.003204414),
        sample_id=summary((5, 3), 3.2, dtype=np.dtype(np.int32)),
    )
    expected_state = wrapper.AttentionWrapperState(
        cell_state=[summary((5, 9), 0.40868404), summary((5, 9), 0.89017969)],
        attention=summary((5, 6), 0.041453815),
        alignments=summary((5, 8), 0.125),
        attention_state=summary((5, 8), 0.125),
        alignment_history=(),
    )
    expected_history = summary((3, 5, 8), 0.125)

    _test_with_attention(
        wrapper.BahdanauAttention,
        expected_output,
        expected_state,
        alignment_history=True,
        create_query_layer=True,
        expected_final_alignment_history=expected_history,
        create_attention_kwargs={"kernel_initializer": "ones"},
    )
Beispiel #13
0
def test_luong_monotonic_not_normalized():
    """LuongMonotonicAttention without normalization, attention depth 9.

    Seeds random state, then checks output, final state and alignment history.
    """
    set_random_state_for_tf_and_np()
    float_type = np.dtype("float32")

    rnn_output = ResultSummary(shape=(5, 3, 6), dtype=float_type, mean=0.003664831)
    sample_id = ResultSummary(shape=(5, 3), dtype=np.dtype("int32"), mean=3.06666666)
    expected_output = basic_decoder.BasicDecoderOutput(
        rnn_output=rnn_output, sample_id=sample_id
    )

    cell_means = (0.54318606, 1.12592840)
    expected_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(shape=(5, 9), dtype=float_type, mean=m) for m in cell_means
        ],
        attention=ResultSummary(shape=(5, 6), dtype=float_type, mean=0.059128221),
        alignments=ResultSummary(shape=(5, 8), dtype=float_type, mean=0.05112994),
        attention_state=ResultSummary(shape=(5, 8), dtype=float_type, mean=0.05112994),
        alignment_history=(),
    )
    expected_history = ResultSummary(
        shape=(3, 5, 8), dtype=float_type, mean=0.06994973868
    )

    _test_with_attention(
        wrapper.LuongMonotonicAttention,
        expected_output,
        expected_state,
        attention_mechanism_depth=9,
        alignment_history=True,
        expected_final_alignment_history=expected_history,
    )
    def testBahdanauMonotonicNormalized(self):
        """BahdanauMonotonicAttention with ``normalize=True`` and ones-initialized kernels.

        Checks output, final state (``time=3``) and alignment history with a
        query layer enabled.
        """
        f32 = np.dtype("float32")

        def summary(shape, mean, dtype=f32):
            return ResultSummary(shape=shape, dtype=dtype, mean=mean)

        output = basic_decoder.BasicDecoderOutput(
            rnn_output=summary((5, 3, 6), 0.043294173),
            sample_id=summary((5, 3), 3.53333333, dtype=np.dtype("int32")))
        state = wrapper.AttentionWrapperState(
            cell_state=[summary((5, 9), 0.40034312),
                        summary((5, 9), 0.5925445)],
            attention=summary((5, 6), 0.096119694),
            time=3,
            alignments=summary((5, 8), 0.1211452),
            attention_state=summary((5, 8), 0.1211452),
            alignment_history=())
        history = summary((3, 5, 8), 0.12258384)

        attention_kwargs = {"kernel_initializer": "ones", "normalize": True}
        self._testWithAttention(
            wrapper.BahdanauMonotonicAttention,
            output,
            state,
            alignment_history=True,
            expected_final_alignment_history=history,
            create_query_layer=True,
            create_attention_kwargs=attention_kwargs)
    def testLuongMonotonicScaled(self):
        """LuongMonotonicAttention with ``scale=True`` and attention depth 9.

        Checks output, final state (``time=3``) and the stacked alignment history.
        """
        float_type = np.dtype("float32")
        alignment_mean = 0.11831035

        output = basic_decoder.BasicDecoderOutput(
            rnn_output=ResultSummary(shape=(5, 3, 6), dtype=float_type,
                                     mean=0.027387079),
            sample_id=ResultSummary(shape=(5, 3), dtype=np.dtype("int32"),
                                    mean=3.13333333))
        state = wrapper.AttentionWrapperState(
            cell_state=[
                ResultSummary(shape=(5, 9), dtype=float_type, mean=0.32660431),
                ResultSummary(shape=(5, 9), dtype=float_type, mean=0.52464348),
            ],
            attention=ResultSummary(shape=(5, 6), dtype=float_type,
                                    mean=0.089345723),
            time=3,
            alignments=ResultSummary(shape=(5, 8), dtype=float_type,
                                     mean=alignment_mean),
            attention_state=ResultSummary(shape=(5, 8), dtype=float_type,
                                          mean=alignment_mean),
            alignment_history=())
        history = ResultSummary(shape=(3, 5, 8), dtype=float_type,
                                mean=0.12194442004)

        options = dict(
            attention_mechanism_depth=9,
            alignment_history=True,
            expected_final_alignment_history=history,
            create_attention_kwargs={"scale": True},
        )
        self._testWithAttention(wrapper.LuongMonotonicAttention, output, state,
                                **options)
    def testBahdanauMonotonicNotNormalized(self):
        """BahdanauMonotonicAttention (no normalization) with ones-initialized kernels.

        Checks output, final state (``time=3``) and alignment history with a
        query layer enabled.
        """
        f32 = np.dtype("float32")

        def summary(shape, mean, dtype=f32):
            return ResultSummary(shape=shape, dtype=dtype, mean=mean)

        output = basic_decoder.BasicDecoderOutput(
            rnn_output=summary((5, 3, 6), 0.041342419),
            sample_id=summary((5, 3), 3.53333333, dtype=np.dtype("int32")))
        state = wrapper.AttentionWrapperState(
            cell_state=[summary((5, 9), 0.33866978),
                        summary((5, 9), 0.46913195)],
            attention=summary((5, 6), 0.092498459),
            time=3,
            alignments=summary((5, 8), 0.12079944),
            attention_state=summary((5, 8), 0.12079944),
            alignment_history=())
        history = summary((3, 5, 8), 0.121448785067)

        self._testWithAttention(
            wrapper.BahdanauMonotonicAttention,
            output,
            state,
            alignment_history=True,
            expected_final_alignment_history=history,
            create_query_layer=True,
            create_attention_kwargs={"kernel_initializer": "ones"})
    def testBahdanauNormalized(self):
        """BahdanauAttention with ``normalize=True`` and ones-initialized kernels.

        Checks output and final state (``time=3``) with a query layer enabled.
        """
        float_type = np.dtype("float32")
        int_type = np.dtype("int32")

        rnn_output = ResultSummary(shape=(5, 3, 6), dtype=float_type, mean=0.047594748)
        sample_id = ResultSummary(shape=(5, 3), dtype=int_type, mean=3.6)
        first_cell = ResultSummary(shape=(5, 9), dtype=float_type, mean=0.41311637)
        second_cell = ResultSummary(shape=(5, 9), dtype=float_type, mean=0.61683208)
        attention = ResultSummary(shape=(5, 6), dtype=float_type, mean=0.090581432)
        alignments = ResultSummary(shape=(5, 8), dtype=float_type, mean=0.125)
        attention_state = ResultSummary(shape=(5, 8), dtype=float_type, mean=0.125)

        output = basic_decoder.BasicDecoderOutput(rnn_output=rnn_output,
                                                  sample_id=sample_id)
        state = wrapper.AttentionWrapperState(
            cell_state=[first_cell, second_cell],
            attention=attention,
            time=3,
            alignments=alignments,
            attention_state=attention_state,
            alignment_history=())

        self._testWithAttention(
            wrapper.BahdanauAttention,
            output,
            state,
            create_query_layer=True,
            create_attention_kwargs={"kernel_initializer": "ones",
                                     "normalize": True})
    def testNotUseAttentionLayer(self):
        """BahdanauAttention with ``attention_layer_size=None``.

        Without an attention layer the expected rnn_output and attention
        summaries have a last dimension of 10; the state records ``time=3``.
        """
        f32 = np.dtype("float32")

        def summary(shape, mean, dtype=f32):
            return ResultSummary(shape=shape, dtype=dtype, mean=mean)

        output = basic_decoder.BasicDecoderOutput(
            rnn_output=summary((5, 3, 10), 0.072406612),
            sample_id=summary((5, 3), 3.86666666, dtype=np.dtype("int32")))
        state = wrapper.AttentionWrapperState(
            cell_state=[summary((5, 9), 0.61177742),
                        summary((5, 9), 1.032002)],
            attention=summary((5, 10), 0.011346335),
            time=3,
            alignments=summary((5, 8), 0.125),
            attention_state=summary((5, 8), 0.125),
            alignment_history=())

        options = {
            "attention_layer_size": None,
            "create_query_layer": True,
            "create_attention_kwargs": {"kernel_initializer": "ones"},
        }
        self._testWithAttention(wrapper.BahdanauAttention, output, state,
                                **options)