def testLuongNotNormalized(self):
    """Run LuongAttentionV2 (no extra constructor kwargs) with depth 9.

    The expected decoder output and final attention-wrapper state are given as
    reference (shape, dtype, mean) summaries and handed to the shared
    `_testWithAttention` helper together with `attention_mechanism_depth=9`.
    """
    f32, i32 = np.dtype("float32"), np.dtype("int32")

    output_summary = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(shape=(5, 3, 6), dtype=f32, mean=0.05481226),
        sample_id=ResultSummary(shape=(5, 3), dtype=i32, mean=3.13333333),
    )

    state_summary = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(shape=(5, 9), dtype=f32, mean=0.38453412),
            ResultSummary(shape=(5, 9), dtype=f32, mean=0.5785929),
        ],
        attention=ResultSummary(shape=(5, 6), dtype=f32, mean=0.16311775),
        time=3,
        alignments=ResultSummary(shape=(5, 8), dtype=f32, mean=0.125),
        attention_state=ResultSummary(shape=(5, 8), dtype=f32, mean=0.125),
        alignment_history=(),
    )

    self._testWithAttention(
        wrapper.LuongAttentionV2,
        output_summary,
        state_summary,
        attention_mechanism_depth=9,
    )
def testBahdanauMonotonicNormalized(self):
    """Run BahdanauMonotonicAttentionV2 with ones-initialized kernels and normalize=True.

    Alignment history is enabled, so a (3, 5, 8) alignment-history summary is
    passed to `_testWithAttention` alongside the output and final-state
    summaries; a query layer is requested via `create_query_layer=True`.
    """
    f32, i32 = np.dtype("float32"), np.dtype("int32")
    mechanism_kwargs = dict(kernel_initializer="ones", normalize=True)

    output_summary = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(shape=(5, 3, 6), dtype=f32, mean=0.043294173),
        sample_id=ResultSummary(shape=(5, 3), dtype=i32, mean=3.53333333),
    )

    state_summary = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(shape=(5, 9), dtype=f32, mean=0.40034312),
            ResultSummary(shape=(5, 9), dtype=f32, mean=0.5925445),
        ],
        attention=ResultSummary(shape=(5, 6), dtype=f32, mean=0.096119694),
        time=3,
        alignments=ResultSummary(shape=(5, 8), dtype=f32, mean=0.1211452),
        attention_state=ResultSummary(shape=(5, 8), dtype=f32, mean=0.1211452),
        alignment_history=(),
    )

    history_summary = ResultSummary(shape=(3, 5, 8), dtype=f32, mean=0.12258384)

    self._testWithAttention(
        wrapper.BahdanauMonotonicAttentionV2,
        output_summary,
        state_summary,
        alignment_history=True,
        expected_final_alignment_history=history_summary,
        create_query_layer=True,
        create_attention_kwargs=mechanism_kwargs,
    )
def testBahdanauNotNormalized(self):
    """Run BahdanauAttentionV2 with ones-initialized kernels and no normalization kwarg.

    Alignment history is enabled, so a (3, 5, 8) alignment-history summary is
    checked in addition to the decoder output and final-state summaries.
    """
    # Spell dtypes as strings, matching every sibling test in this file
    # (np.dtype(np.float32) == np.dtype("float32"), so values are unchanged).
    float32 = np.dtype("float32")
    int32 = np.dtype("int32")

    create_attention_mechanism = wrapper.BahdanauAttentionV2
    create_attention_kwargs = {"kernel_initializer": "ones"}

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(shape=(5, 3, 6), dtype=float32, mean=0.051747426),
        sample_id=ResultSummary(shape=(5, 3), dtype=int32, mean=3.33333333),
    )
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(shape=(5, 9), dtype=float32, mean=0.44189346),
            ResultSummary(shape=(5, 9), dtype=float32, mean=0.65429491),
        ],
        attention=ResultSummary(shape=(5, 6), dtype=float32, mean=0.073610783),
        time=3,
        alignments=ResultSummary(shape=(5, 8), dtype=float32, mean=0.125),
        attention_state=ResultSummary(shape=(5, 8), dtype=float32, mean=0.125),
        alignment_history=(),
    )
    expected_final_alignment_history = ResultSummary(
        shape=(3, 5, 8), dtype=float32, mean=0.125
    )

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        alignment_history=True,
        create_query_layer=True,
        expected_final_alignment_history=expected_final_alignment_history,
        create_attention_kwargs=create_attention_kwargs,
    )
def testLuongMonotonicScaled(self):
    """Run LuongMonotonicAttentionV2 with scale=True and depth 9.

    Alignment history is enabled, so a (3, 5, 8) alignment-history summary is
    passed to `_testWithAttention` alongside the output and final-state
    summaries.
    """
    f32, i32 = np.dtype("float32"), np.dtype("int32")
    mechanism_kwargs = dict(scale=True)

    output_summary = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(shape=(5, 3, 6), dtype=f32, mean=0.027387079),
        sample_id=ResultSummary(shape=(5, 3), dtype=i32, mean=3.13333333),
    )

    state_summary = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(shape=(5, 9), dtype=f32, mean=0.32660431),
            ResultSummary(shape=(5, 9), dtype=f32, mean=0.52464348),
        ],
        attention=ResultSummary(shape=(5, 6), dtype=f32, mean=0.089345723),
        time=3,
        alignments=ResultSummary(shape=(5, 8), dtype=f32, mean=0.11831035),
        attention_state=ResultSummary(shape=(5, 8), dtype=f32, mean=0.11831035),
        alignment_history=(),
    )

    history_summary = ResultSummary(shape=(3, 5, 8), dtype=f32, mean=0.12194442004)

    self._testWithAttention(
        wrapper.LuongMonotonicAttentionV2,
        output_summary,
        state_summary,
        attention_mechanism_depth=9,
        alignment_history=True,
        expected_final_alignment_history=history_summary,
        create_attention_kwargs=mechanism_kwargs,
    )
def testBahdanauMonotonicNotNormalized(self):
    """Run BahdanauMonotonicAttentionV2 with ones-initialized kernels only.

    Alignment history is enabled and a query layer is requested; the expected
    (shape, dtype, mean) summaries are handed to `_testWithAttention`.
    """
    f32, i32 = np.dtype("float32"), np.dtype("int32")
    mechanism_kwargs = dict(kernel_initializer="ones")

    output_summary = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(shape=(5, 3, 6), dtype=f32, mean=0.041342419),
        sample_id=ResultSummary(shape=(5, 3), dtype=i32, mean=3.53333333),
    )

    state_summary = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(shape=(5, 9), dtype=f32, mean=0.33866978),
            ResultSummary(shape=(5, 9), dtype=f32, mean=0.46913195),
        ],
        attention=ResultSummary(shape=(5, 6), dtype=f32, mean=0.092498459),
        time=3,
        alignments=ResultSummary(shape=(5, 8), dtype=f32, mean=0.12079944),
        attention_state=ResultSummary(shape=(5, 8), dtype=f32, mean=0.12079944),
        alignment_history=(),
    )

    history_summary = ResultSummary(shape=(3, 5, 8), dtype=f32, mean=0.121448785067)

    self._testWithAttention(
        wrapper.BahdanauMonotonicAttentionV2,
        output_summary,
        state_summary,
        alignment_history=True,
        expected_final_alignment_history=history_summary,
        create_query_layer=True,
        create_attention_kwargs=mechanism_kwargs,
    )
def testBahdanauNormalized(self):
    """Run BahdanauAttentionV2 with ones-initialized kernels and normalize=True.

    No alignment history is requested; only the decoder output and final
    state summaries are handed to `_testWithAttention`, with a query layer.
    """
    f32, i32 = np.dtype("float32"), np.dtype("int32")
    mechanism_kwargs = dict(kernel_initializer="ones", normalize=True)

    output_summary = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(shape=(5, 3, 6), dtype=f32, mean=0.047594748),
        sample_id=ResultSummary(shape=(5, 3), dtype=i32, mean=3.6),
    )

    state_summary = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(shape=(5, 9), dtype=f32, mean=0.41311637),
            ResultSummary(shape=(5, 9), dtype=f32, mean=0.61683208),
        ],
        attention=ResultSummary(shape=(5, 6), dtype=f32, mean=0.090581432),
        time=3,
        alignments=ResultSummary(shape=(5, 8), dtype=f32, mean=0.125),
        attention_state=ResultSummary(shape=(5, 8), dtype=f32, mean=0.125),
        alignment_history=(),
    )

    self._testWithAttention(
        wrapper.BahdanauAttentionV2,
        output_summary,
        state_summary,
        create_query_layer=True,
        create_attention_kwargs=mechanism_kwargs,
    )
def testNotUseAttentionLayer(self):
    """Run BahdanauAttentionV2 with `attention_layer_size=None`.

    Compared with the other Bahdanau tests here, the expected rnn_output and
    attention summaries are 10 wide instead of 6; presumably this reflects the
    attention layer being skipped (per the test name) — confirm against
    `_testWithAttention`.
    """
    f32, i32 = np.dtype("float32"), np.dtype("int32")
    mechanism_kwargs = dict(kernel_initializer="ones")

    output_summary = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(shape=(5, 3, 10), dtype=f32, mean=0.072406612),
        sample_id=ResultSummary(shape=(5, 3), dtype=i32, mean=3.86666666),
    )

    state_summary = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(shape=(5, 9), dtype=f32, mean=0.61177742),
            ResultSummary(shape=(5, 9), dtype=f32, mean=1.032002),
        ],
        attention=ResultSummary(shape=(5, 10), dtype=f32, mean=0.011346335),
        time=3,
        alignments=ResultSummary(shape=(5, 8), dtype=f32, mean=0.125),
        attention_state=ResultSummary(shape=(5, 8), dtype=f32, mean=0.125),
        alignment_history=(),
    )

    self._testWithAttention(
        wrapper.BahdanauAttentionV2,
        output_summary,
        state_summary,
        attention_layer_size=None,
        create_query_layer=True,
        create_attention_kwargs=mechanism_kwargs,
    )
def testAttentionWrapperState(self):
    """`clone(time=1)` yields a state with the new time and leaves the source untouched."""
    # Size the all-None constructor call from the namedtuple's own field list
    # so the test does not hard-code how many fields the state has.
    num_fields = len(wrapper.AttentionWrapperState._fields)  # pylint: disable=protected-access
    state = wrapper.AttentionWrapperState(*([None] * num_fields))
    new_state = state.clone(time=1)
    # assertIsNone checks identity with None (rather than ==) and reports a
    # clearer failure message than assertEqual(x, None).
    self.assertIsNone(state.time)
    self.assertEqual(new_state.time, 1)