コード例 #1
0
    def testLuongNotNormalized(self):
        """Run LuongAttentionV2 through the shared harness and check summaries.

        No ``create_attention_kwargs`` are supplied here, and the mechanism
        depth is set to 9. The means below are recorded regression values
        that are only meaningful for the fixed setup inside
        ``_testWithAttention``.
        """
        def summary(shape, mean, dtype="float32"):
            # Build a fresh ResultSummary for each expected field.
            return ResultSummary(shape=shape, dtype=np.dtype(dtype), mean=mean)

        output_expected = basic_decoder.BasicDecoderOutput(
            rnn_output=summary((5, 3, 6), 0.05481226),
            sample_id=summary((5, 3), 3.13333333, dtype="int32"),
        )
        # alignments and attention_state are expected to carry equal summaries.
        state_expected = wrapper.AttentionWrapperState(
            cell_state=[
                summary((5, 9), 0.38453412),
                summary((5, 9), 0.5785929),
            ],
            attention=summary((5, 6), 0.16311775),
            time=3,
            alignments=summary((5, 8), 0.125),
            attention_state=summary((5, 8), 0.125),
            alignment_history=(),
        )

        self._testWithAttention(
            wrapper.LuongAttentionV2,
            output_expected,
            state_expected,
            attention_mechanism_depth=9,
        )
コード例 #2
0
    def testBahdanauMonotonicNormalized(self):
        """Run BahdanauMonotonicAttentionV2 with ``normalize=True``.

        The mechanism is built with ``kernel_initializer="ones"`` and
        ``normalize=True``. A query layer is requested, and the alignment
        history is recorded and checked. All means are recorded regression
        values tied to the setup in ``_testWithAttention``.
        """
        def summary(shape, mean, dtype="float32"):
            # Build a fresh ResultSummary for each expected field.
            return ResultSummary(shape=shape, dtype=np.dtype(dtype), mean=mean)

        mechanism_kwargs = {
            "kernel_initializer": "ones",
            "normalize": True
        }

        output_expected = basic_decoder.BasicDecoderOutput(
            rnn_output=summary((5, 3, 6), 0.043294173),
            sample_id=summary((5, 3), 3.53333333, dtype="int32"),
        )
        # alignments and attention_state are expected to carry equal summaries.
        state_expected = wrapper.AttentionWrapperState(
            cell_state=[
                summary((5, 9), 0.40034312),
                summary((5, 9), 0.5925445),
            ],
            attention=summary((5, 6), 0.096119694),
            time=3,
            alignments=summary((5, 8), 0.1211452),
            attention_state=summary((5, 8), 0.1211452),
            alignment_history=(),
        )
        # The history stacks one (5, 8) alignment per decoding step (3 steps).
        history_expected = summary((3, 5, 8), 0.12258384)

        self._testWithAttention(
            wrapper.BahdanauMonotonicAttentionV2,
            output_expected,
            state_expected,
            alignment_history=True,
            expected_final_alignment_history=history_expected,
            create_query_layer=True,
            create_attention_kwargs=mechanism_kwargs,
        )
コード例 #3
0
    def testBahdanauNotNormalized(self):
        """Run BahdanauAttentionV2 without ``normalize`` and check summaries.

        Only ``kernel_initializer="ones"`` is passed to the mechanism. A query
        layer is requested, and the alignment history is recorded and
        checked. The means are recorded regression values.
        """
        def summary(shape, mean, dtype=np.float32):
            # Build a fresh ResultSummary for each expected field.
            return ResultSummary(shape=shape, dtype=np.dtype(dtype), mean=mean)

        mechanism_kwargs = {"kernel_initializer": "ones"}

        output_expected = basic_decoder.BasicDecoderOutput(
            rnn_output=summary((5, 3, 6), 0.051747426),
            sample_id=summary((5, 3), 3.33333333, dtype=np.int32),
        )
        # alignments and attention_state are expected to carry equal summaries.
        state_expected = wrapper.AttentionWrapperState(
            cell_state=[
                summary((5, 9), 0.44189346),
                summary((5, 9), 0.65429491),
            ],
            attention=summary((5, 6), 0.073610783),
            time=3,
            alignments=summary((5, 8), 0.125),
            attention_state=summary((5, 8), 0.125),
            alignment_history=(),
        )
        history_expected = summary((3, 5, 8), 0.125)

        self._testWithAttention(
            wrapper.BahdanauAttentionV2,
            output_expected,
            state_expected,
            alignment_history=True,
            create_query_layer=True,
            expected_final_alignment_history=history_expected,
            create_attention_kwargs=mechanism_kwargs,
        )
コード例 #4
0
    def testLuongMonotonicScaled(self):
        """Run LuongMonotonicAttentionV2 with ``scale=True``.

        The mechanism depth is set to 9, and the alignment history is
        recorded and checked. No query layer flag is passed. The means are
        recorded regression values.
        """
        def summary(shape, mean, dtype="float32"):
            # Build a fresh ResultSummary for each expected field.
            return ResultSummary(shape=shape, dtype=np.dtype(dtype), mean=mean)

        mechanism_kwargs = {"scale": True}

        output_expected = basic_decoder.BasicDecoderOutput(
            rnn_output=summary((5, 3, 6), 0.027387079),
            sample_id=summary((5, 3), 3.13333333, dtype="int32"),
        )
        # alignments and attention_state are expected to carry equal summaries.
        state_expected = wrapper.AttentionWrapperState(
            cell_state=[
                summary((5, 9), 0.32660431),
                summary((5, 9), 0.52464348),
            ],
            attention=summary((5, 6), 0.089345723),
            time=3,
            alignments=summary((5, 8), 0.11831035),
            attention_state=summary((5, 8), 0.11831035),
            alignment_history=(),
        )
        history_expected = summary((3, 5, 8), 0.12194442004)

        self._testWithAttention(
            wrapper.LuongMonotonicAttentionV2,
            output_expected,
            state_expected,
            attention_mechanism_depth=9,
            alignment_history=True,
            expected_final_alignment_history=history_expected,
            create_attention_kwargs=mechanism_kwargs,
        )
コード例 #5
0
    def testBahdanauMonotonicNotNormalized(self):
        """Run BahdanauMonotonicAttentionV2 without ``normalize``.

        Only ``kernel_initializer="ones"`` is passed to the mechanism. A query
        layer is requested, and the alignment history is recorded and
        checked. The means are recorded regression values.
        """
        def summary(shape, mean, dtype="float32"):
            # Build a fresh ResultSummary for each expected field.
            return ResultSummary(shape=shape, dtype=np.dtype(dtype), mean=mean)

        mechanism_kwargs = {"kernel_initializer": "ones"}

        output_expected = basic_decoder.BasicDecoderOutput(
            rnn_output=summary((5, 3, 6), 0.041342419),
            sample_id=summary((5, 3), 3.53333333, dtype="int32"),
        )
        # alignments and attention_state are expected to carry equal summaries.
        state_expected = wrapper.AttentionWrapperState(
            cell_state=[
                summary((5, 9), 0.33866978),
                summary((5, 9), 0.46913195),
            ],
            attention=summary((5, 6), 0.092498459),
            time=3,
            alignments=summary((5, 8), 0.12079944),
            attention_state=summary((5, 8), 0.12079944),
            alignment_history=(),
        )
        history_expected = summary((3, 5, 8), 0.121448785067)

        self._testWithAttention(
            wrapper.BahdanauMonotonicAttentionV2,
            output_expected,
            state_expected,
            alignment_history=True,
            expected_final_alignment_history=history_expected,
            create_query_layer=True,
            create_attention_kwargs=mechanism_kwargs,
        )
コード例 #6
0
    def testBahdanauNormalized(self):
        """Run BahdanauAttentionV2 with ``normalize=True``.

        The mechanism is built with ``kernel_initializer="ones"`` and
        ``normalize=True``, and a query layer is requested. Alignment history
        is not checked here. The means are recorded regression values.
        """
        def summary(shape, mean, dtype="float32"):
            # Build a fresh ResultSummary for each expected field.
            return ResultSummary(shape=shape, dtype=np.dtype(dtype), mean=mean)

        mechanism_kwargs = {
            "kernel_initializer": "ones",
            "normalize": True
        }

        output_expected = basic_decoder.BasicDecoderOutput(
            rnn_output=summary((5, 3, 6), 0.047594748),
            sample_id=summary((5, 3), 3.6, dtype="int32"),
        )
        # alignments and attention_state are expected to carry equal summaries.
        state_expected = wrapper.AttentionWrapperState(
            cell_state=[
                summary((5, 9), 0.41311637),
                summary((5, 9), 0.61683208),
            ],
            attention=summary((5, 6), 0.090581432),
            time=3,
            alignments=summary((5, 8), 0.125),
            attention_state=summary((5, 8), 0.125),
            alignment_history=(),
        )

        self._testWithAttention(
            wrapper.BahdanauAttentionV2,
            output_expected,
            state_expected,
            create_query_layer=True,
            create_attention_kwargs=mechanism_kwargs,
        )
コード例 #7
0
    def testNotUseAttentionLayer(self):
        """Run BahdanauAttentionV2 with ``attention_layer_size=None``.

        Without an attention layer, both the expected ``rnn_output`` and the
        ``attention`` summary have a last dimension of 10. The mechanism gets
        ``kernel_initializer="ones"``, and a query layer is requested. The
        means are recorded regression values.
        """
        def summary(shape, mean, dtype="float32"):
            # Build a fresh ResultSummary for each expected field.
            return ResultSummary(shape=shape, dtype=np.dtype(dtype), mean=mean)

        mechanism_kwargs = {"kernel_initializer": "ones"}

        output_expected = basic_decoder.BasicDecoderOutput(
            rnn_output=summary((5, 3, 10), 0.072406612),
            sample_id=summary((5, 3), 3.86666666, dtype="int32"),
        )
        # alignments and attention_state are expected to carry equal summaries.
        state_expected = wrapper.AttentionWrapperState(
            cell_state=[
                summary((5, 9), 0.61177742),
                summary((5, 9), 1.032002),
            ],
            attention=summary((5, 10), 0.011346335),
            time=3,
            alignments=summary((5, 8), 0.125),
            attention_state=summary((5, 8), 0.125),
            alignment_history=(),
        )

        self._testWithAttention(
            wrapper.BahdanauAttentionV2,
            output_expected,
            state_expected,
            attention_layer_size=None,
            create_query_layer=True,
            create_attention_kwargs=mechanism_kwargs,
        )
コード例 #8
0
 def testAttentionWrapperState(self):
   """Checks that ``clone`` overrides a field without touching the original.

   Builds an ``AttentionWrapperState`` with every field set to None, clones
   it with ``time=1``, and verifies that the original still has
   ``time=None`` while the clone has ``time=1``.
   """
   num_fields = len(wrapper.AttentionWrapperState._fields)  # pylint: disable=protected-access
   state = wrapper.AttentionWrapperState(*([None] * num_fields))
   new_state = state.clone(time=1)
   # assertIsNone expresses the identity check directly and reports a
   # clearer failure message than assertEqual(..., None).
   self.assertIsNone(state.time)
   self.assertEqual(new_state.time, 1)