Example #1
  def testBahdanauMonotonicNormalized(self):
    create_attention_mechanism = wrapper.BahdanauMonotonicAttentionV2
    create_attention_kwargs = {"kernel_initializer": "ones",
                               "normalize": True}
    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype("float32"), mean=4.5706983),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=rnn_cell.LSTMStateTuple(
            c=ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=1.6005473),
            h=ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.77863038)),
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype("float32"), mean=7.3326721),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.12258384),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.12258384),
        alignment_history=())
    expected_final_alignment_history = ResultSummary(
        shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.12258384)

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        alignment_history=True,
        expected_final_alignment_history=expected_final_alignment_history,
        create_query_layer=True,
        create_attention_kwargs=create_attention_kwargs)
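
These tests compare decoder results against compact `ResultSummary` records rather than full tensors. The helper itself is not shown in these excerpts; a minimal sketch of what it is assumed to look like (a namedtuple plus a summarizing function) is:

import collections
import numpy as np

# Assumed shape of the summary record used throughout these tests: it keeps
# only the shape, dtype, and mean of a tensor, so expected values stay compact.
ResultSummary = collections.namedtuple("ResultSummary", ("shape", "dtype", "mean"))

def get_result_summary(x):
  # Summarize a numpy array; non-array results (e.g. the empty alignment
  # history tuple) are passed through unchanged.
  if isinstance(x, np.ndarray):
    return ResultSummary(x.shape, x.dtype, x.mean())
  return x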
Example #2
    def testLuongNotNormalized(self):
        create_attention_mechanism = wrapper.LuongAttentionV2

        expected_final_output = basic_decoder.BasicDecoderOutput(
            rnn_output=ResultSummary(shape=(5, 3, 6),
                                     dtype=np.dtype("float32"),
                                     mean=0.05481226),
            sample_id=ResultSummary(shape=(5, 3),
                                    dtype=np.dtype("int32"),
                                    mean=3.13333333))
        expected_final_state = wrapper.AttentionWrapperState(
            cell_state=[
                ResultSummary(shape=(5, 9),
                              dtype=np.dtype("float32"),
                              mean=0.38453412),
                ResultSummary(shape=(5, 9),
                              dtype=np.dtype("float32"),
                              mean=0.5785929)
            ],
            attention=ResultSummary(shape=(5, 6),
                                    dtype=np.dtype("float32"),
                                    mean=0.16311775),
            time=3,
            alignments=ResultSummary(shape=(5, 8),
                                     dtype=np.dtype("float32"),
                                     mean=0.125),
            attention_state=ResultSummary(shape=(5, 8),
                                          dtype=np.dtype("float32"),
                                          mean=0.125),
            alignment_history=())

        self._testWithAttention(create_attention_mechanism,
                                expected_final_output,
                                expected_final_state,
                                attention_mechanism_depth=9)
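
For context, the `create_attention_mechanism` factory and `attention_mechanism_depth` argument above are presumably turned into a concrete mechanism inside `_testWithAttention` roughly as follows. This is only a sketch under assumptions: the import path, memory shape (batch=5, max_time=8, inferred from the expected alignment shapes), and keyword names come from the contrib seq2seq V2 attention API, not from these excerpts.

import tensorflow as tf
# Assumed import; in these tests `wrapper` refers to the seq2seq attention_wrapper module.
from tensorflow.contrib.seq2seq.python.ops import attention_wrapper as wrapper

create_attention_mechanism = wrapper.LuongAttentionV2   # as in the test above
attention_mechanism_depth = 9

# Fake encoder outputs: [batch=5, max_time=8, depth=attention_mechanism_depth].
memory = tf.ones([5, 8, attention_mechanism_depth])
attention_mechanism = create_attention_mechanism(
    units=attention_mechanism_depth, memory=memory)
# The mechanism is then wrapped around a 9-unit LSTM cell with a 6-unit
# attention layer, decoded for 3 steps, and the results are summarized.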
Example #3
  def testLuongMonotonicScaled(self):
    create_attention_mechanism = wrapper.LuongMonotonicAttentionV2
    create_attention_kwargs = {"scale": True}

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.027387079),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=3.13333333))
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.32660431),
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.52464348)],
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype("float32"), mean=0.089345723),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.11831035),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.11831035),
        alignment_history=())
    expected_final_alignment_history = ResultSummary(
        shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.12194442004)

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        attention_mechanism_depth=9,
        alignment_history=True,
        expected_final_alignment_history=expected_final_alignment_history,
        create_attention_kwargs=create_attention_kwargs)
Example #4
  def testBahdanauMonotonicNotNormalized(self):
    create_attention_mechanism = wrapper.BahdanauMonotonicAttentionV2
    create_attention_kwargs = {"kernel_initializer": "ones"}

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.041342419),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=3.53333333))
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.33866978),
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.46913195)],
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype("float32"), mean=0.092498459),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.12079944),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.12079944),
        alignment_history=())
    expected_final_alignment_history = ResultSummary(
        shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.121448785067)

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        alignment_history=True,
        expected_final_alignment_history=expected_final_alignment_history,
        create_query_layer=True,
        create_attention_kwargs=create_attention_kwargs)
Example #5
  def testBahdanauMonotonicNormalized(self):
    create_attention_mechanism = wrapper.BahdanauMonotonicAttentionV2
    create_attention_kwargs = {"kernel_initializer": "ones",
                               "normalize": True}
    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.043294173),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=3.53333333))
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.40034312),
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.5925445)],
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype("float32"), mean=0.096119694),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.1211452),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.1211452),
        alignment_history=())
    expected_final_alignment_history = ResultSummary(
        shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.12258384)

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        alignment_history=True,
        expected_final_alignment_history=expected_final_alignment_history,
        create_query_layer=True,
        create_attention_kwargs=create_attention_kwargs)
Example #6
  def testBahdanauNormalized(self):
    create_attention_mechanism = wrapper.BahdanauAttentionV2
    create_attention_kwargs = {"kernel_initializer": "ones", "normalize": True}

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.047594748),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=3.6))
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.41311637),
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.61683208)],
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype("float32"), mean=0.090581432),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.125),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.125),
        alignment_history=())

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        create_query_layer=True,
        create_attention_kwargs=create_attention_kwargs)
Example #7
  def testNotUseAttentionLayer(self):
    create_attention_mechanism = wrapper.BahdanauAttentionV2
    create_attention_kwargs = {"kernel_initializer": "ones"}

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 10), dtype=np.dtype("float32"), mean=0.072406612),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=3.86666666))
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.61177742),
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=1.032002)],
        attention=ResultSummary(
            shape=(5, 10), dtype=np.dtype("float32"), mean=0.011346335),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.125),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.125),
        alignment_history=())

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        attention_layer_size=None,
        create_query_layer=True,
        create_attention_kwargs=create_attention_kwargs)
Example #8
  def testBahdanauMonotonicNotNormalized(self):
    create_attention_mechanism = wrapper.BahdanauMonotonicAttentionV2
    create_attention_kwargs = {"kernel_initializer": "ones"}

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype("float32"), mean=5.9850435),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=rnn_cell.LSTMStateTuple(
            c=ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=1.6752492),
            h=ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.76052248)),
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype("float32"), mean=8.361186),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.10989678),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.10989678),
        alignment_history=())
    expected_final_alignment_history = ResultSummary(
        shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.117412611)

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        alignment_history=True,
        expected_final_alignment_history=expected_final_alignment_history,
        create_query_layer=True,
        create_attention_kwargs=create_attention_kwargs)
Example #9
  def testBahdanauNotNormalized(self):
    create_attention_mechanism = wrapper.BahdanauAttentionV2
    create_attention_kwargs = {"kernel_initializer": "ones"}
    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype(np.float32), mean=0.051747426),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype(np.int32), mean=3.33333333))
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(
                shape=(5, 9), dtype=np.dtype(np.float32), mean=0.44189346),
            ResultSummary(
                shape=(5, 9), dtype=np.dtype(np.float32), mean=0.65429491)],
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype(np.float32), mean=0.073610783),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype(np.float32), mean=0.125),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype(np.float32), mean=0.125),
        alignment_history=())
    expected_final_alignment_history = ResultSummary(
        shape=(3, 5, 8), dtype=np.dtype(np.float32), mean=0.125)

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        alignment_history=True,
        create_query_layer=True,
        expected_final_alignment_history=expected_final_alignment_history,
        create_attention_kwargs=create_attention_kwargs)
Example #10
  def testBahdanauNormalized(self):
    create_attention_mechanism = wrapper.BahdanauAttentionV2
    create_attention_kwargs = {"kernel_initializer": "ones", "normalize": True}

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype("float32"), mean=3.9548259),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=rnn_cell.LSTMStateTuple(
            c=ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=1.4652209),
            h=ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.70997983)),
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype("float32"), mean=6.3075728),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.125),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.125),
        alignment_history=())

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        create_query_layer=True,
        create_attention_kwargs=create_attention_kwargs)
Example #11
  def testLuongScaled(self):
    create_attention_mechanism = wrapper.LuongAttentionV2
    create_attention_kwargs = {"scale": True}

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype("float32"), mean=2.6605489),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=rnn_cell.LSTMStateTuple(
            c=ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547),
            h=ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088)),
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype("float32"), mean=4.0846314),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.125),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.125),
        alignment_history=())

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        attention_mechanism_depth=9,
        create_attention_kwargs=create_attention_kwargs)
Example #12
  def testBahdanauNotNormalized(self):
    create_attention_mechanism = wrapper.BahdanauAttentionV2
    create_attention_kwargs = {"kernel_initializer": "ones"}
    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype(np.float32), mean=4.8290324),
        sample_id=ResultSummary(shape=(5, 3), dtype=np.dtype(np.int32), mean=0))
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=rnn_cell.LSTMStateTuple(
            c=ResultSummary(
                shape=(5, 9), dtype=np.dtype(np.float32), mean=1.6432636),
            h=ResultSummary(
                shape=(5, 9), dtype=np.dtype(np.float32), mean=0.75866824)),
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype(np.float32), mean=6.7445569),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype(np.float32), mean=0.125),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype(np.float32), mean=0.125),
        alignment_history=())
    expected_final_alignment_history = ResultSummary(
        shape=(3, 5, 8), dtype=np.dtype(np.float32), mean=0.125)

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        alignment_history=True,
        create_query_layer=True,
        expected_final_alignment_history=expected_final_alignment_history,
        create_attention_kwargs=create_attention_kwargs)
Example #13
  def testLuongMonotonicScaled(self):
    create_attention_mechanism = wrapper.LuongMonotonicAttentionV2
    create_attention_kwargs = {"scale": True}

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype("float32"), mean=3.159497),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=rnn_cell.LSTMStateTuple(
            c=ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384),
            h=ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038)),
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype("float32"), mean=5.3079605),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.11467695),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.11467695),
        alignment_history=())
    expected_final_alignment_history = ResultSummary(
        shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.11899644)

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        attention_mechanism_depth=9,
        alignment_history=True,
        expected_final_alignment_history=expected_final_alignment_history,
        create_attention_kwargs=create_attention_kwargs)
Example #14
    def testBahdanauNormalized(self):
        create_attention_mechanism = functools.partial(
            wrapper.BahdanauAttention,
            normalize=True,
            attention_r_initializer=2.0)

        array = np.array
        float32 = np.float32
        int32 = np.int32

        expected_final_output = basic_decoder.BasicDecoderOutput(
            rnn_output=array(
                [[[
                    1.72670335e-02, -5.83671592e-03, 6.38638902e-03,
                    -8.11776379e-04, 1.12681929e-03, -1.24236047e-02
                ],
                  [
                      1.75918192e-02, -5.73426578e-03, 6.29768707e-03,
                      -8.63141613e-04, 2.03352375e-03, -1.21420780e-02
                  ],
                  [
                      1.72424167e-02, -5.66471322e-03, 6.63427915e-03,
                      -6.23903936e-04, 1.68706616e-03, -1.22524602e-02
                  ]],
                 [[
                     1.79958157e-02, -9.80986748e-03, 4.73218597e-03,
                     -3.89962713e-03, 1.41502675e-02, -1.48344040e-02
                 ],
                  [
                      1.82184577e-02, -9.88379307e-03, 4.90130857e-03,
                      -3.91892251e-03, 1.36479288e-02, -1.53291579e-02
                  ],
                  [
                      1.83001235e-02, -1.00617753e-02, 4.97077405e-03,
                      -3.94908339e-03, 1.37211196e-02, -1.52311027e-02
                  ]],
                 [[
                     7.93476030e-03, -8.46967567e-03, -7.16930721e-04,
                     4.37953044e-04, 1.04503892e-03, -1.82424393e-02
                 ],
                  [
                      7.90629163e-03, -8.48819874e-03, -5.57833235e-04,
                      5.02390554e-04, 6.79406337e-04, -1.84837580e-02
                  ],
                  [
                      8.14734399e-03, -8.23053624e-03, -5.92814526e-04,
                      4.16347990e-04, 1.29250437e-03, -1.84548404e-02
                  ]],
                 [[
                     1.21026095e-02, -1.26739489e-02, 1.78718648e-04,
                     2.68748170e-03, 7.80996867e-03, -9.69076063e-04
                 ],
                  [
                      1.17978491e-02, -1.32678337e-02, 6.00410858e-05,
                      2.66301399e-03, 7.00691342e-03, -1.10030361e-03
                  ],
                  [
                      1.15651665e-02, -1.30795036e-02, -2.74205930e-04,
                      2.48012133e-03, 6.94250735e-03, -8.47495161e-04
                  ]],
                 [[
                     1.02377674e-02, -8.72955937e-03, 1.22555892e-03,
                     2.03830865e-03, 8.93574394e-03, -7.28237582e-03
                 ],
                  [
                      1.05115287e-02, -8.92531779e-03, 1.14568521e-03,
                      1.91635895e-03, 8.94328393e-03, -7.39541650e-03
                  ],
                  [
                      1.07398070e-02, -8.56867433e-03, 1.52354129e-03,
                      2.06834078e-03, 9.36511997e-03, -7.64556089e-03
                  ]]],
                dtype=float32),
            sample_id=array(
                [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
                dtype=int32))

        expected_final_state = wrapper.AttentionWrapperState(
            time=3,
            attention_history=(),
            cell_state=core_rnn_cell.LSTMStateTuple(
                c=array([[
                    -0.02209264, -0.00794879, -0.00157153, 0.01614309,
                    -0.01383773, -0.00750943, -0.00824213, -0.01210296,
                    0.01794949
                ],
                         [
                             0.01726926, -0.01418139, -0.0040099, 0.0319339,
                             -0.03545783, -0.02142831, -0.00609501,
                             -0.00195033, -0.01938949
                         ],
                         [
                             -0.01159083, 0.0087524, -0.01639001, -0.01400012,
                             0.01342422, -0.01041037, 0.00620991, -0.00960796,
                             -0.00650131
                         ],
                         [
                             -0.04763237, -0.01192762, -0.00019377, 0.04103839,
                             -0.00138058, 0.02126443, -0.02793917, -0.05467755,
                             -0.02912025
                         ],
                         [
                             0.02241185, -0.00141741, 0.01911988, 0.00547728,
                             -0.01280068, -0.00307024, -0.00494239, 0.02169247,
                             0.01631995
                         ]],
                        dtype=float32),
                h=array([[
                    -1.10821165e-02, -3.92766716e-03, -7.99638336e-04,
                    7.92923011e-03, -7.04019284e-03, -3.77124036e-03,
                    -4.19876305e-03, -6.17261464e-03, 8.95325281e-03
                ],
                         [
                             8.60597286e-03, -7.16368994e-03, -1.94644753e-03,
                             1.62479617e-02, -1.76739115e-02, -1.06403306e-02,
                             -3.01484042e-03, -9.74688213e-04, -9.96260438e-03
                         ],
                         [
                             -5.78098884e-03, 4.48751403e-03, -8.12216662e-03,
                             -6.94991415e-03, 6.72604749e-03, -5.10144979e-03,
                             3.08637507e-03, -4.71517537e-03, -3.20256175e-03
                         ],
                         [
                             -2.38018110e-02, -5.89307398e-03, -9.74484938e-05,
                             2.01694984e-02, -6.82370039e-04, 1.07099237e-02,
                             -1.42087601e-02, -2.70793457e-02, -1.44684138e-02
                         ],
                         [
                             1.11825848e-02, -6.99267141e-04, 9.82748345e-03,
                             2.74566701e-03, -6.56377291e-03, -1.53681310e-03,
                             -2.48806458e-03, 1.10462429e-02, 7.97568541e-03
                         ]],
                        dtype=float32)),
            attention=array([[
                0.01724242, -0.00566471, 0.00663428, -0.0006239, 0.00168707,
                -0.01225246
            ],
                             [
                                 0.01830012, -0.01006178, 0.00497077,
                                 -0.00394908, 0.01372112, -0.0152311
                             ],
                             [
                                 0.00814734, -0.00823054, -0.00059281,
                                 0.00041635, 0.0012925, -0.01845484
                             ],
                             [
                                 0.01156517, -0.0130795, -0.00027421,
                                 0.00248012, 0.00694251, -0.0008475
                             ],
                             [
                                 0.01073981, -0.00856867, 0.00152354,
                                 0.00206834, 0.00936512, -0.00764556
                             ]],
                            dtype=float32))

        self._testWithAttention(create_attention_mechanism,
                                expected_final_output, expected_final_state)
Example #15
  def call(self, inputs, state):
    """Perform a step of attention-wrapped RNN.

    - Step 1: Mix the `inputs` and previous step's `attention` output via
      `cell_input_fn`.
    - Step 2: Call the wrapped `cell` with this input and its previous state.
    - Step 3: Score the cell's output with `attention_mechanism`.
    - Step 4: Calculate the alignments by passing the score through the
      `normalizer`.
    - Step 5: Calculate the context vector as the inner product between the
      alignments and the attention_mechanism's values (memory).
    - Step 6: Calculate the attention output by concatenating the cell output
      and context through the attention layer (a linear layer with
      `attention_layer_size` outputs).

    Args:
      inputs: (Possibly nested tuple of) Tensor, the input at this time step.
      state: An instance of `AttentionWrapperState` containing
        tensors from the previous time step.

    Returns:
      A tuple `(attention_or_cell_output, next_state)`, where:

      - `attention_or_cell_output` is the attention if `output_attention`
        is `True`, otherwise the cell output.
      - `next_state` is an instance of `AttentionWrapperState`
         containing the state calculated at this time step.

    Raises:
      TypeError: If `state` is not an instance of `AttentionWrapperState`.
    """
    if not isinstance(state, attention_wrapper.AttentionWrapperState):
      raise TypeError("Expected state to be instance of AttentionWrapperState. "
                      "Received type %s instead."  % type(state))

    # Step 1: Calculate the true inputs to the cell based on the
    # previous attention value.
    cell_inputs = self._cell_input_fn(inputs, state.attention)
    cell_state = state.cell_state
    cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

    cell_batch_size = (
        cell_output.shape[0].value or array_ops.shape(cell_output)[0])
    error_message = "Memory batch size must be equal to 1."
    with ops.control_dependencies(
        self._batch_size_checks(cell_batch_size, error_message)):
      cell_output = array_ops.identity(
          cell_output, name="checked_cell_output")

    if self._is_multi:
      previous_attention_state = state.attention_state
      previous_alignment_history = state.alignment_history
    else:
      previous_attention_state = [state.attention_state]
      previous_alignment_history = [state.alignment_history]

    all_alignments = []
    all_attentions = []
    all_attention_states = []
    maybe_all_histories = []
    for i, attention_mechanism in enumerate(self._attention_mechanisms):
      attention, alignments, next_attention_state = _compute_attention(
          attention_mechanism, cell_output, previous_attention_state[i],
          self._attention_layers[i] if self._attention_layers else None)
      alignment_history = previous_alignment_history[i].write(
          state.time, alignments) if self._alignment_history else ()

      all_attention_states.append(next_attention_state)
      all_alignments.append(alignments)
      all_attentions.append(attention)
      maybe_all_histories.append(alignment_history)

    attention = array_ops.concat(all_attentions, 1)
    next_state = attention_wrapper.AttentionWrapperState(
        time=state.time + 1,
        cell_state=next_cell_state,
        attention=attention,
        attention_state=self._item_or_tuple(all_attention_states),
        alignments=self._item_or_tuple(all_alignments),
        alignment_history=self._item_or_tuple(maybe_all_histories))

    if self._output_attention:
      return attention, next_state
    else:
      return cell_output, next_state
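
The `_compute_attention` helper used above is not reproduced in this excerpt. A minimal sketch consistent with Steps 3-6 of the docstring, mirroring the V1 contrib helper (treat the exact signature and imports as assumptions), is:

from tensorflow.python.ops import array_ops, math_ops

def _compute_attention(attention_mechanism, cell_output, attention_state,
                       attention_layer):
  # Steps 3-4: score the cell output and normalize the scores into alignments.
  alignments, next_attention_state = attention_mechanism(
      cell_output, state=attention_state)

  # Step 5: context = alignments x values, reducing over the memory time axis.
  # alignments: [batch, memory_time] -> [batch, 1, memory_time]
  expanded_alignments = array_ops.expand_dims(alignments, 1)
  context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
  context = array_ops.squeeze(context, [1])

  # Step 6: optionally project [cell_output; context] down to the attention size.
  if attention_layer is not None:
    attention = attention_layer(array_ops.concat([cell_output, context], 1))
  else:
    attention = context

  return attention, alignments, next_attention_state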
Example #16
    def testLuongNormalized(self):
        create_attention_mechanism = functools.partial(
            wrapper.LuongAttention,
            normalize=True,
            attention_r_initializer=2.0)

        array = np.array
        float32 = np.float32
        int32 = np.int32

        expected_final_output = basic_decoder.BasicDecoderOutput(
            rnn_output=array(
                [[[
                    1.23956744e-02, -6.88115368e-03, 3.15234554e-03,
                    -1.97300944e-03, 4.79680905e-03, -1.38076628e-02
                ],
                  [
                      1.28376717e-02, -6.78718928e-03, 3.07988771e-03,
                      -2.03956687e-03, 5.68403490e-03, -1.35601182e-02
                  ],
                  [
                      1.23463338e-02, -6.76322030e-03, 3.28891934e-03,
                      -1.86874042e-03, 5.47897862e-03, -1.37654068e-02
                  ]],
                 [[
                     1.54412268e-02, -1.07613346e-02, 4.43824846e-03,
                     -8.81063985e-04, 1.26828086e-02, -1.21067995e-02
                 ],
                  [
                      1.57206059e-02, -1.08218864e-02, 4.61952807e-03,
                      -9.61483689e-04, 1.22140013e-02, -1.26614980e-02
                  ],
                  [
                      1.57821011e-02, -1.09842420e-02, 4.66934917e-03,
                      -9.85997496e-04, 1.22719472e-02, -1.25438003e-02
                  ]],
                 [[
                     9.27361846e-03, -9.66077764e-03, -9.69522633e-04,
                     1.48308463e-05, 3.88664147e-03, -1.64083000e-02
                 ],
                  [
                      9.26287938e-03, -9.74234194e-03, -8.32488062e-04,
                      5.83778601e-05, 3.52663640e-03, -1.66827720e-02
                  ],
                  [
                      9.50474478e-03, -9.49789397e-03, -8.71829456e-04,
                      -3.09986062e-05, 4.13423358e-03, -1.66635048e-02
                  ]],
                 [[
                     1.21398102e-02, -1.27454493e-02, 1.57688977e-04,
                     2.70034792e-03, 7.79653806e-03, -8.36936757e-04
                 ],
                  [
                      1.18234595e-02, -1.33170560e-02, 4.55579720e-05,
                      2.67185434e-03, 6.99766818e-03, -1.00935437e-03
                  ],
                  [
                      1.16009805e-02, -1.31483339e-02, -2.94458936e-04,
                      2.49248254e-03, 6.92958105e-03, -7.20315147e-04
                  ]],
                 [[
                     1.02377674e-02, -8.72955937e-03, 1.22555892e-03,
                     2.03830865e-03, 8.93574394e-03, -7.28237582e-03
                 ],
                  [
                      1.05115287e-02, -8.92531779e-03, 1.14568521e-03,
                      1.91635895e-03, 8.94328393e-03, -7.39541650e-03
                  ],
                  [
                      1.07398070e-02, -8.56867433e-03, 1.52354129e-03,
                      2.06834078e-03, 9.36511997e-03, -7.64556089e-03
                  ]]],
                dtype=float32),
            sample_id=array(
                [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
                dtype=int32))
        expected_final_state = wrapper.AttentionWrapperState(
            time=3,
            attention_history=(),
            cell_state=core_rnn_cell.LSTMStateTuple(
                c=array([[
                    -0.02204949, -0.00805957, -0.001603, 0.01609283,
                    -0.01380462, -0.0074945, -0.00816895, -0.01210009,
                    0.01795324
                ],
                         [
                             0.01727016, -0.01420708, -0.00399973, 0.03195432,
                             -0.03547529, -0.02138673, -0.00610332,
                             -0.00191565, -0.01937822
                         ],
                         [
                             -0.01160676, 0.00876512, -0.01641791, -0.01400807,
                             0.01347767, -0.01036341, 0.00627499, -0.00963627,
                             -0.00650573
                         ],
                         [
                             -0.04763342, -0.01192671, -0.00019402, 0.04103871,
                             -0.00138017, 0.02126611, -0.02793773, -0.05467714,
                             -0.02912043
                         ],
                         [
                             0.02241185, -0.00141741, 0.01911988, 0.00547728,
                             -0.01280068, -0.00307024, -0.00494239, 0.02169247,
                             0.01631995
                         ]],
                        dtype=float32),
                h=array([[
                    -1.10610286e-02, -3.98253463e-03, -8.15684092e-04,
                    7.90454168e-03, -7.02364743e-03, -3.76377185e-03,
                    -4.16135695e-03, -6.17104582e-03, 8.95532966e-03
                ],
                         [
                             8.60653073e-03, -7.17685232e-03, -1.94147974e-03,
                             1.62585936e-02, -1.76823437e-02, -1.06195193e-02,
                             -3.01911240e-03, -9.57308919e-04, -9.95720550e-03
                         ],
                         [
                             -5.78888878e-03, 4.49400023e-03, -8.13617278e-03,
                             -6.95386063e-03, 6.75271638e-03, -5.07823005e-03,
                             3.11873178e-03, -4.72912844e-03, -3.20472987e-03
                         ],
                         [
                             -2.38023344e-02, -5.89262368e-03, -9.75721487e-05,
                             2.01696623e-02, -6.82163402e-04, 1.07107637e-02,
                             -1.42080421e-02, -2.70791352e-02, -1.44685050e-02
                         ],
                         [
                             1.11825848e-02, -6.99267141e-04, 9.82748345e-03,
                             2.74566701e-03, -6.56377291e-03, -1.53681310e-03,
                             -2.48806458e-03, 1.10462429e-02, 7.97568541e-03
                         ]],
                        dtype=float32)),
            attention=array(
                [[
                    1.23463338e-02, -6.76322030e-03, 3.28891934e-03,
                    -1.86874042e-03, 5.47897862e-03, -1.37654068e-02
                ],
                 [
                     1.57821011e-02, -1.09842420e-02, 4.66934917e-03,
                     -9.85997496e-04, 1.22719472e-02, -1.25438003e-02
                 ],
                 [
                     9.50474478e-03, -9.49789397e-03, -8.71829456e-04,
                     -3.09986062e-05, 4.13423358e-03, -1.66635048e-02
                 ],
                 [
                     1.16009805e-02, -1.31483339e-02, -2.94458936e-04,
                     2.49248254e-03, 6.92958105e-03, -7.20315147e-04
                 ],
                 [
                     1.07398070e-02, -8.56867433e-03, 1.52354129e-03,
                     2.06834078e-03, 9.36511997e-03, -7.64556089e-03
                 ]],
                dtype=float32))
        self._testWithAttention(create_attention_mechanism,
                                expected_final_output,
                                expected_final_state,
                                attention_mechanism_depth=9)
Example #17
    def testLuongNotNormalized(self):
        create_attention_mechanism = wrapper.LuongAttention

        array = np.array
        float32 = np.float32
        int32 = np.int32

        expected_final_output = basic_decoder.BasicDecoderOutput(
            rnn_output=array(
                [[[
                    1.23641128e-02, -6.82715839e-03, 3.24165262e-03,
                    -1.90772023e-03, 4.69654519e-03, -1.37025211e-02
                ],
                  [
                      1.29463980e-02, -6.79699238e-03, 3.10124992e-03,
                      -2.02869414e-03, 5.66399656e-03, -1.35517996e-02
                  ],
                  [
                      1.22659411e-02, -6.81970268e-03, 3.15135531e-03,
                      -1.96937821e-03, 5.62768336e-03, -1.39173865e-02
                  ]],
                 [[
                     1.53944232e-02, -1.07725551e-02, 4.42822604e-03,
                     -8.30623554e-04, 1.26549732e-02, -1.20573286e-02
                 ],
                  [
                      1.57453734e-02, -1.08157266e-02, 4.62466478e-03,
                      -9.88351414e-04, 1.22286947e-02, -1.26876952e-02
                  ],
                  [
                      1.57857724e-02, -1.09536834e-02, 4.64798324e-03,
                      -1.01319887e-03, 1.22695938e-02, -1.25500849e-02
                  ]],
                 [[
                     9.23123397e-03, -9.42669343e-03, -9.09919385e-04,
                     6.09827694e-05, 3.90436035e-03, -1.63374804e-02
                 ],
                  [
                      9.22935922e-03, -9.57853813e-03, -7.92966573e-04,
                      8.89014918e-05, 3.52671882e-03, -1.66499857e-02
                  ],
                  [
                      9.49526206e-03, -9.39475093e-03, -8.49372707e-04,
                      -1.72815053e-05, 4.16132808e-03, -1.66336838e-02
                  ]],
                 [[
                     1.21248290e-02, -1.27166547e-02, 1.66158192e-04,
                     2.69516627e-03, 7.80194718e-03, -8.90152063e-04
                 ],
                  [
                      1.17861275e-02, -1.32453050e-02, 6.66640699e-05,
                      2.65894993e-03, 7.01114535e-03, -1.14195189e-03
                  ],
                  [
                      1.15833860e-02, -1.31145213e-02, -2.84505659e-04,
                      2.48642010e-03, 6.93593081e-03, -7.82784075e-04
                  ]],
                 [[
                     1.02377674e-02, -8.72955937e-03, 1.22555892e-03,
                     2.03830865e-03, 8.93574394e-03, -7.28237582e-03
                 ],
                  [
                      1.05115287e-02, -8.92531779e-03, 1.14568521e-03,
                      1.91635895e-03, 8.94328393e-03, -7.39541650e-03
                  ],
                  [
                      1.07398070e-02, -8.56867433e-03, 1.52354129e-03,
                      2.06834078e-03, 9.36511997e-03, -7.64556089e-03
                  ]]],
                dtype=float32),
            sample_id=array(
                [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
                dtype=int32))
        expected_final_state = wrapper.AttentionWrapperState(
            time=3,
            attention_history=(),
            cell_state=core_rnn_cell.LSTMStateTuple(
                c=array([[
                    -0.02204997, -0.00805805, -0.00160245, 0.01609369,
                    -0.01380494, -0.00749439, -0.00817, -0.01209992, 0.01795316
                ],
                         [
                             0.01727016, -0.01420713, -0.00399972, 0.03195436,
                             -0.03547532, -0.02138666, -0.00610335,
                             -0.00191557, -0.01937821
                         ],
                         [
                             -0.01160429, 0.00876595, -0.01641685, -0.01400784,
                             0.01348004, -0.01036458, 0.00627241, -0.00963544,
                             -0.00650568
                         ],
                         [
                             -0.04763246, -0.01192755, -0.00019379, 0.04103841,
                             -0.00138055, 0.02126456, -0.02793905, -0.0546775,
                             -0.02912027
                         ],
                         [
                             0.02241185, -0.00141741, 0.01911988, 0.00547728,
                             -0.01280068, -0.00307024, -0.00494239, 0.02169247,
                             0.01631995
                         ]],
                        dtype=float32),
                h=array([[
                    -1.10612623e-02, -3.98178305e-03, -8.15406092e-04,
                    7.90496264e-03, -7.02379830e-03, -3.76371504e-03,
                    -4.16189339e-03, -6.17096573e-03, 8.95528216e-03
                ],
                         [
                             8.60652886e-03, -7.17687514e-03, -1.94147555e-03,
                             1.62586085e-02, -1.76823605e-02, -1.06194830e-02,
                             -3.01912241e-03, -9.57269047e-04, -9.95719433e-03
                         ],
                         [
                             -5.78764686e-03, 4.49441886e-03, -8.13564472e-03,
                             -6.95375400e-03, 6.75391173e-03, -5.07880514e-03,
                             3.11744539e-03, -4.72871540e-03, -3.20470310e-03
                         ],
                         [
                             -2.38018595e-02, -5.89303859e-03, -9.74571449e-05,
                             2.01695058e-02, -6.82353624e-04, 1.07099945e-02,
                             -1.42086931e-02, -2.70793252e-02, -1.44684194e-02
                         ],
                         [
                             1.11825848e-02, -6.99267141e-04, 9.82748345e-03,
                             2.74566701e-03, -6.56377291e-03, -1.53681310e-03,
                             -2.48806458e-03, 1.10462429e-02, 7.97568541e-03
                         ]],
                        dtype=float32)),
            attention=array(
                [[
                    1.22659411e-02, -6.81970268e-03, 3.15135531e-03,
                    -1.96937821e-03, 5.62768336e-03, -1.39173865e-02
                ],
                 [
                     1.57857724e-02, -1.09536834e-02, 4.64798324e-03,
                     -1.01319887e-03, 1.22695938e-02, -1.25500849e-02
                 ],
                 [
                     9.49526206e-03, -9.39475093e-03, -8.49372707e-04,
                     -1.72815053e-05, 4.16132808e-03, -1.66336838e-02
                 ],
                 [
                     1.15833860e-02, -1.31145213e-02, -2.84505659e-04,
                     2.48642010e-03, 6.93593081e-03, -7.82784075e-04
                 ],
                 [
                     1.07398070e-02, -8.56867433e-03, 1.52354129e-03,
                     2.06834078e-03, 9.36511997e-03, -7.64556089e-03
                 ]],
                dtype=float32))

        self._testWithAttention(create_attention_mechanism,
                                expected_final_output,
                                expected_final_state,
                                attention_mechanism_depth=9)
Example #18
  def testAttentionWrapperState(self):
    num_fields = len(wrapper.AttentionWrapperState._fields)  # pylint: disable=protected-access
    state = wrapper.AttentionWrapperState(*([None] * num_fields))
    new_state = state.clone(time=1)
    self.assertEqual(state.time, None)
    self.assertEqual(new_state.time, 1)
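
`clone()` here behaves like keyword replacement on a namedtuple: every field keeps its value except the ones overridden. A rough sketch of the equivalent operation (the real method may additionally check shape compatibility of overridden tensors):

# Roughly equivalent to state.clone(time=1) for a namedtuple-based state:
new_state = state._replace(time=1)   # all other fields keep their previous values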
Example #19
    def call(self, inputs, prev_state):
        """
        Perform a step of attention-wrapped RNN.

        This method assumes `inputs` is the word embedding vector.

        This method overrides the original `call()` method.
        """
        _attn_mech = self._attention_mechanisms[0]
        attn_size = _attn_mech._num_units
        batch_size = _attn_mech.batch_size
        dtype = inputs.dtype

        # Step 1: Calculate the true inputs to the cell based on the
        # previous attention value.
        # `_cell_input_fn` defaults to
        # `lambda inputs, attention: array_ops.concat([inputs, attention], -1)`
        _dprint('{}: prev_state received by call(): {}'.format(
            self.__class__.__name__, prev_state))
        cell_inputs = self._cell_input_fn(inputs, prev_state.attention)
        prev_cell_state = prev_state.cell_state
        cell_output, curr_cell_state = self._cell(cell_inputs, prev_cell_state)

        cell_batch_size = (cell_output.shape[0].value
                           or tf.shape(cell_output)[0])
        error_message = (
            "When applying AttentionWrapper %s: " % self.name +
            "Non-matching batch sizes between the memory (encoder output) "
            "and the query (decoder output). Are you using the "
            "BeamSearchDecoder? You may need to tile your memory input via "
            "the tf.contrib.seq2seq.tile_batch function with argument "
            "multiple=beam_width.")
        with tf.control_dependencies([
                tf.assert_equal(cell_batch_size,
                                _attn_mech.batch_size,
                                message=error_message)
        ]):
            cell_output = tf.identity(cell_output, name="checked_cell_output")

        dtype = cell_output.dtype
        assert len(self._attention_mechanisms) == 1
        _attn_mech = self._attention_mechanisms[0]
        alignments, attention_state = _attn_mech(cell_output, state=None)

        if self._alignments_keep_prob < 1.:
            alignments = tf.contrib.layers.dropout(
                inputs=alignments,
                keep_prob=self._alignments_keep_prob,
                noise_shape=None,
                is_training=True)

        if len(_shape(alignments)) == 3:
            # Multi-head attention
            # Expand from [batch_size, num_heads, memory_time] to [batch_size, num_heads, 1, memory_time]
            expanded_alignments = tf.expand_dims(alignments, 2)
            # attention_mechanism.values shape is
            #     [batch_size, num_heads, memory_time, num_units / num_heads]
            # the batched matmul is over memory_time, so the output shape is
            #     [batch_size, num_heads, 1, num_units / num_heads].
            # we then combine the heads
            #     [batch_size, 1, attention_mechanism.num_units]
            attention_mechanism_values = _attn_mech.values_split
            context = tf.matmul(expanded_alignments,
                                attention_mechanism_values)
            attention = tf.squeeze(combine_heads(context), [1])
        else:
            # Expand from [batch_size, memory_time] to [batch_size, 1, memory_time]
            expanded_alignments = tf.expand_dims(alignments, 1)
            # Context is the inner product of alignments and values along the
            # memory time dimension.
            # alignments shape is
            #     [batch_size, 1, memory_time]
            # attention_mechanism.values shape is
            #     [batch_size, memory_time, attention_mechanism.num_units]
            # the batched matmul is over memory_time, so the output shape is
            #     [batch_size, 1, attention_mechanism.num_units].
            # we then squeeze out the singleton dim.
            attention_mechanism_values = _attn_mech.values
            context = tf.matmul(expanded_alignments,
                                attention_mechanism_values)
            attention = tf.squeeze(context, [1])

        # Context projection
        if self._context_layer:
            # noinspection PyCallingNonCallable
            attention = self._dense_layer(name='a_layer',
                                          units=_attn_mech._num_units,
                                          use_bias=False,
                                          activation=None,
                                          dtype=dtype,
                                          **self._mask_params)(attention)

        if self._alignment_history:
            alignments = tf.reshape(alignments, [cell_batch_size, -1])
            alignment_history = prev_state.alignment_history.write(
                prev_state.time, alignments)
        else:
            alignment_history = ()

        curr_state = attention_wrapper.AttentionWrapperState(
            time=prev_state.time + 1,
            cell_state=curr_cell_state,
            attention=attention,
            attention_state=alignments,
            alignments=alignments,
            alignment_history=alignment_history)
        return cell_output, curr_state
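
`combine_heads` is referenced in the multi-head branch above but not defined in this excerpt. A minimal sketch of the usual head-merging helper it is assumed to be (the inverse of splitting a tensor into attention heads):

import tensorflow as tf

def combine_heads(x):
    # x: [batch_size, num_heads, time, depth_per_head]
    x = tf.transpose(x, [0, 2, 1, 3])  # -> [batch_size, time, num_heads, depth_per_head]
    shape = tf.shape(x)
    # -> [batch_size, time, num_heads * depth_per_head]
    return tf.reshape(x, [shape[0], shape[1], shape[2] * shape[3]])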
Example #20
    def testBahdanauNotNormalized(self):
        create_attention_mechanism = wrapper.BahdanauAttention

        array = np.array
        float32 = np.float32
        int32 = np.int32

        expected_final_outputs = basic_decoder.BasicDecoderOutput(
            rnn_output=array(
                [[[
                    1.25166783e-02, -6.88887993e-03, 3.17239435e-03,
                    -1.98234897e-03, 4.77387803e-03, -1.38330357e-02
                ],
                  [
                      1.28883058e-02, -6.76271692e-03, 3.13419267e-03,
                      -2.02183682e-03, 5.62057737e-03, -1.35373026e-02
                  ],
                  [
                      1.24917831e-02, -6.71574520e-03, 3.42238229e-03,
                      -1.79501204e-03, 5.33161033e-03, -1.36620644e-02
                  ]],
                 [[
                     1.55150667e-02, -1.07274549e-02, 4.44198400e-03,
                     -9.73310322e-04, 1.27242506e-02, -1.21861566e-02
                 ],
                  [
                      1.57585666e-02, -1.07965544e-02, 4.61554807e-03,
                      -1.01510016e-03, 1.22341057e-02, -1.27029382e-02
                  ],
                  [
                      1.58304181e-02, -1.09712025e-02, 4.67861444e-03,
                      -1.03920139e-03, 1.23004699e-02, -1.25949886e-02
                  ]],
                 [[
                     9.26700700e-03, -9.75431874e-03, -9.95740294e-04,
                     -1.27463136e-06, 3.81659716e-03, -1.64887272e-02
                 ],
                  [
                      9.25191958e-03, -9.80092678e-03, -8.48566880e-04,
                      5.02091134e-05, 3.46567202e-03, -1.67435352e-02
                  ],
                  [
                      9.48173273e-03, -9.52653307e-03, -8.79382715e-04,
                      -3.07094306e-05, 4.05955408e-03, -1.67226996e-02
                  ]],
                 [[
                     1.21462569e-02, -1.27578378e-02, 1.54045003e-04,
                     2.70257704e-03, 7.79421115e-03, -8.14041123e-04
                 ],
                  [
                      1.18412934e-02, -1.33513296e-02, 3.54760559e-05,
                      2.67801876e-03, 6.99122995e-03, -9.46014654e-04
                  ],
                  [
                      1.16087487e-02, -1.31632648e-02, -2.98853614e-04,
                      2.49515846e-03, 6.92677684e-03, -6.92734495e-04
                  ]],
                 [[
                     1.02377674e-02, -8.72955937e-03, 1.22555892e-03,
                     2.03830865e-03, 8.93574394e-03, -7.28237582e-03
                 ],
                  [
                      1.05115287e-02, -8.92531779e-03, 1.14568521e-03,
                      1.91635895e-03, 8.94328393e-03, -7.39541650e-03
                  ],
                  [
                      1.07398070e-02, -8.56867433e-03, 1.52354129e-03,
                      2.06834078e-03, 9.36511997e-03, -7.64556089e-03
                  ]]],
                dtype=float32),
            sample_id=array(
                [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
                dtype=int32))

        expected_final_state = wrapper.AttentionWrapperState(
            time=3,
            attention_history=(),
            cell_state=core_rnn_cell.LSTMStateTuple(
                c=array([[
                    -0.0220502, -0.008058, -0.00160266, 0.01609341,
                    -0.01380513, -0.00749483, -0.00816989, -0.01210028,
                    0.01795324
                ],
                         [
                             0.01727026, -0.0142065, -0.00399991, 0.03195379,
                             -0.03547479, -0.02138772, -0.00610318,
                             -0.00191625, -0.01937846
                         ],
                         [
                             -0.0116077, 0.00876439, -0.01641787, -0.01400803,
                             0.01347527, -0.01036386, 0.00627491, -0.0096361,
                             -0.00650565
                         ],
                         [
                             -0.04763387, -0.01192631, -0.00019412, 0.04103886,
                             -0.00137999, 0.02126684, -0.02793711, -0.05467696,
                             -0.02912051
                         ],
                         [
                             0.02241185, -0.00141741, 0.01911988, 0.00547728,
                             -0.01280068, -0.00307024, -0.00494239, 0.02169247,
                             0.01631995
                         ]],
                        dtype=float32),
                h=array([[
                    -1.10613741e-02, -3.98175791e-03, -8.15514475e-04,
                    7.90482666e-03, -7.02390168e-03, -3.76394135e-03,
                    -4.16183751e-03, -6.17114361e-03, 8.95532221e-03
                ],
                         [
                             8.60657450e-03, -7.17655150e-03, -1.94156705e-03,
                             1.62583217e-02, -1.76821016e-02, -1.06200138e-02,
                             -3.01904045e-03, -9.57608980e-04, -9.95732192e-03
                         ],
                         [
                             -5.78935863e-03, 4.49362956e-03, -8.13615043e-03,
                             -6.95384294e-03, 6.75151078e-03, -5.07845683e-03,
                             3.11869266e-03, -4.72904649e-03, -3.20469099e-03
                         ],
                         [
                             -2.38025561e-02, -5.89242764e-03, -9.76260417e-05,
                             2.01697368e-02, -6.82076614e-04, 1.07111251e-02,
                             -1.42077375e-02, -2.70790439e-02, -1.44685479e-02
                         ],
                         [
                             1.11825848e-02, -6.99267141e-04, 9.82748345e-03,
                             2.74566701e-03, -6.56377291e-03, -1.53681310e-03,
                             -2.48806458e-03, 1.10462429e-02, 7.97568541e-03
                         ]],
                        dtype=float32)),
            attention=array(
                [[
                    1.24917831e-02, -6.71574520e-03, 3.42238229e-03,
                    -1.79501204e-03, 5.33161033e-03, -1.36620644e-02
                ],
                 [
                     1.58304181e-02, -1.09712025e-02, 4.67861444e-03,
                     -1.03920139e-03, 1.23004699e-02, -1.25949886e-02
                 ],
                 [
                     9.48173273e-03, -9.52653307e-03, -8.79382715e-04,
                     -3.07094306e-05, 4.05955408e-03, -1.67226996e-02
                 ],
                 [
                     1.16087487e-02, -1.31632648e-02, -2.98853614e-04,
                     2.49515846e-03, 6.92677684e-03, -6.92734495e-04
                 ],
                 [
                     1.07398070e-02, -8.56867433e-03, 1.52354129e-03,
                     2.06834078e-03, 9.36511997e-03, -7.64556089e-03
                 ]],
                dtype=float32))
        self._testWithAttention(create_attention_mechanism,
                                expected_final_outputs,
                                expected_final_state,
                                attention_history=True)