def testLuongMonotonicScaled(self):
    create_attention_mechanism = wrapper.LuongMonotonicAttention
    create_attention_kwargs = {"scale": True}

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.003664831),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=3.06666666))
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.54318606),
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=1.12592840)
        ],
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype("float32"), mean=0.059128221),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.05112994),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.05112994),
        alignment_history=())
    expected_final_alignment_history = ResultSummary(
        shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.06994973868)

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        attention_mechanism_depth=9,
        alignment_history=True,
        expected_final_alignment_history=expected_final_alignment_history,
        create_attention_kwargs=create_attention_kwargs)

def testLuongNotNormalized(self):
    create_attention_mechanism = wrapper.LuongAttention

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype("float32"), mean=-0.06124732),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=2.73333333))
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.52021580),
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=1.0964939)
        ],
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype("float32"), mean=-0.0318060),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.125),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.125),
        alignment_history=())

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        attention_mechanism_depth=9)

def testBahdanauNormalized(self):
    create_attention_mechanism = wrapper.BahdanauAttention
    create_attention_kwargs = {"kernel_initializer": "ones", "normalize": True}

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype("float32"), mean=-0.008089137
        ),
        sample_id=ResultSummary(shape=(5, 3), dtype=np.dtype("int32"), mean=2.8),
    )
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(shape=(5, 9), dtype=np.dtype("float32"), mean=0.49166861),
            ResultSummary(shape=(5, 9), dtype=np.dtype("float32"), mean=1.01068615),
        ],
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype("float32"), mean=0.042427111
        ),
        alignments=ResultSummary(shape=(5, 8), dtype=np.dtype("float32"), mean=0.125),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.125
        ),
        alignment_history=(),
    )

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        create_query_layer=True,
        create_attention_kwargs=create_attention_kwargs,
    )

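# Note on the test above: `normalize=True` enables the weight-normalized form
# of the Bahdanau score (Salimans & Kingma, 2016), in which the projection
# vector v is reparameterized as g * v / ||v|| before being dotted with
# tanh(W_q q + W_m m). A minimal sketch of the construction the harness is
# expected to perform (hypothetical shapes; `units` must match the depth the
# query layer projects to):
#
#     mechanism = wrapper.BahdanauAttention(
#         units=9,
#         memory=encoder_outputs,  # [batch, max_time, depth]
#         normalize=True,
#         kernel_initializer="ones",
#     )
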
def testNotUseAttentionLayer(self):
    create_attention_mechanism = wrapper.BahdanauAttention
    create_attention_kwargs = {"kernel_initializer": "ones"}

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 10), dtype=np.dtype("float32"), mean=0.078317143
        ),
        sample_id=ResultSummary(shape=(5, 3), dtype=np.dtype("int32"), mean=4.2),
    )
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(shape=(5, 9), dtype=np.dtype("float32"), mean=0.89382392),
            ResultSummary(shape=(5, 9), dtype=np.dtype("float32"), mean=1.722382),
        ],
        attention=ResultSummary(
            shape=(5, 10), dtype=np.dtype("float32"), mean=0.026356646
        ),
        alignments=ResultSummary(shape=(5, 8), dtype=np.dtype("float32"), mean=0.125),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.125
        ),
        alignment_history=(),
    )

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        attention_layer_size=None,
        create_query_layer=True,
        create_attention_kwargs=create_attention_kwargs,
    )

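# Note on the test above: with `attention_layer_size=None` the AttentionWrapper
# applies no output projection, so the emitted attention is the raw context
# vector and its depth equals the memory depth (10 here) rather than a
# projected attention size. Sketch of the contrast (hypothetical names):
#
#     projected = wrapper.AttentionWrapper(cell, mechanism, attention_layer_size=6)
#     raw = wrapper.AttentionWrapper(cell, mechanism, attention_layer_size=None)
#     # projected attention: [batch, 6]; raw attention: [batch, memory_depth]
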
def test_luong_not_normalized():
    set_random_state_for_tf_and_np()
    policy = tf.keras.mixed_precision.experimental.global_policy()
    create_attention_mechanism = wrapper.LuongAttention

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=policy.compute_dtype, mean=-0.06124732
        ),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=2.73333333
        ),
    )
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(shape=(5, 9), dtype=policy.compute_dtype, mean=0.52021580),
            ResultSummary(shape=(5, 9), dtype=policy.compute_dtype, mean=1.0964939),
        ],
        attention=ResultSummary(
            shape=(5, 6), dtype=policy.compute_dtype, mean=-0.0318060
        ),
        alignments=ResultSummary(shape=(5, 8), dtype=policy.compute_dtype, mean=0.125),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=policy.compute_dtype, mean=0.125
        ),
        alignment_history=(),
    )

    _test_with_attention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        attention_mechanism_depth=9,
    )

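# The test above derives its expected float dtypes from the active Keras
# mixed-precision policy, so the same summaries validate both float32 and
# mixed_float16 runs. A minimal sketch of how the policy drives
# `compute_dtype` (assumes the TF 2.x experimental mixed-precision API used
# above):
#
#     tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
#     policy = tf.keras.mixed_precision.experimental.global_policy()
#     assert policy.compute_dtype == "float16"  # variables remain float32
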
def test_luong_scaled():
    set_random_state_for_tf_and_np()
    create_attention_mechanism = wrapper.LuongAttention
    create_attention_kwargs = {"scale": True}

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype("float32"), mean=-0.06124732
        ),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=2.73333333
        ),
    )
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(shape=(5, 9), dtype=np.dtype("float32"), mean=0.52021580),
            ResultSummary(shape=(5, 9), dtype=np.dtype("float32"), mean=1.0964939),
        ],
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype("float32"), mean=-0.0318060
        ),
        alignments=ResultSummary(shape=(5, 8), dtype=np.dtype("float32"), mean=0.125),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.125
        ),
        alignment_history=(),
    )

    _test_with_attention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        attention_mechanism_depth=9,
        create_attention_kwargs=create_attention_kwargs,
    )

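# Note on the test above: `scale=True` multiplies the Luong dot-product score
# by a single trainable scalar, which starts at 1. An untrained scaled
# mechanism therefore scores identically to the unscaled one, which is why
# these expected summaries match test_luong_not_normalized exactly.
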
def testLuongNotNormalized(self):
    create_attention_mechanism = wrapper.LuongAttention

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.05481226),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=3.13333333))
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.38453412),
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.5785929)
        ],
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype("float32"), mean=0.16311775),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.125),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.125),
        alignment_history=())

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        attention_mechanism_depth=9)

def testBahdanauNotNormalized(self):
    create_attention_mechanism = wrapper.BahdanauAttention
    create_attention_kwargs = {"kernel_initializer": "ones"}

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype(np.float32), mean=-0.003204414),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype(np.int32), mean=3.2))
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(
                shape=(5, 9), dtype=np.dtype(np.float32), mean=0.40868404),
            ResultSummary(
                shape=(5, 9), dtype=np.dtype(np.float32), mean=0.89017969)
        ],
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype(np.float32), mean=0.041453815),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype(np.float32), mean=0.125),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype(np.float32), mean=0.125),
        alignment_history=())
    expected_final_alignment_history = ResultSummary(
        shape=(3, 5, 8), dtype=np.dtype(np.float32), mean=0.125)

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        alignment_history=True,
        create_query_layer=True,
        expected_final_alignment_history=expected_final_alignment_history,
        create_attention_kwargs=create_attention_kwargs)

def testBahdanauMonotonicNotNormalized(self):
    create_attention_mechanism = wrapper.BahdanauMonotonicAttention
    create_attention_kwargs = {"kernel_initializer": "ones"}

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype("float32"), mean=-0.009921653),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=3.13333333))
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.44612807),
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.95786464)
        ],
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype("float32"), mean=0.038682378),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.09778417),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.09778417),
        alignment_history=())
    expected_final_alignment_history = ResultSummary(
        shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.10261579603)

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        alignment_history=True,
        expected_final_alignment_history=expected_final_alignment_history,
        create_query_layer=True,
        create_attention_kwargs=create_attention_kwargs)

def testLuongMonotonicNotNormalized(self):
    self.skipTest("Resolve https://github.com/tensorflow/addons/issues/781")

    create_attention_mechanism = wrapper.LuongMonotonicAttention

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.003664831),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=3.06666666))
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.54318606),
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=1.12592840)
        ],
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype("float32"), mean=0.059128221),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.05112994),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.05112994),
        alignment_history=())
    expected_final_alignment_history = ResultSummary(
        shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.06994973868)

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        attention_mechanism_depth=9,
        alignment_history=True,
        expected_final_alignment_history=expected_final_alignment_history)

def testBahdanauMonotonicNormalized(self):
    create_attention_mechanism = wrapper.BahdanauMonotonicAttention
    create_attention_kwargs = {"kernel_initializer": "ones", "normalize": True}

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.007140680
        ),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=3.26666666
        ),
    )
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(shape=(5, 9), dtype=np.dtype("float32"), mean=0.47012400),
            ResultSummary(shape=(5, 9), dtype=np.dtype("float32"), mean=1.0249618),
        ],
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype("float32"), mean=0.068432882
        ),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.0615656
        ),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.0615656
        ),
        alignment_history=(),
    )
    expected_final_alignment_history = ResultSummary(
        shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.07909643)

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        alignment_history=True,
        expected_final_alignment_history=expected_final_alignment_history,
        create_query_layer=True,
        create_attention_kwargs=create_attention_kwargs,
    )

def test_bahdanau_not_normalized():
    set_random_state_for_tf_and_np()
    create_attention_mechanism = wrapper.BahdanauAttention
    create_attention_kwargs = {"kernel_initializer": "ones"}

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype(np.float32), mean=-0.003204414
        ),
        sample_id=ResultSummary(shape=(5, 3), dtype=np.dtype(np.int32), mean=3.2),
    )
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(shape=(5, 9), dtype=np.dtype(np.float32), mean=0.40868404),
            ResultSummary(shape=(5, 9), dtype=np.dtype(np.float32), mean=0.89017969),
        ],
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype(np.float32), mean=0.041453815
        ),
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype(np.float32), mean=0.125
        ),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype(np.float32), mean=0.125
        ),
        alignment_history=(),
    )
    expected_final_alignment_history = ResultSummary(
        shape=(3, 5, 8), dtype=np.dtype(np.float32), mean=0.125)

    _test_with_attention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        alignment_history=True,
        create_query_layer=True,
        expected_final_alignment_history=expected_final_alignment_history,
        create_attention_kwargs=create_attention_kwargs,
    )

def test_luong_monotonic_not_normalized():
    set_random_state_for_tf_and_np()
    create_attention_mechanism = wrapper.LuongMonotonicAttention

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.003664831
        ),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=3.06666666
        ),
    )
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(shape=(5, 9), dtype=np.dtype("float32"), mean=0.54318606),
            ResultSummary(shape=(5, 9), dtype=np.dtype("float32"), mean=1.12592840),
        ],
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype("float32"), mean=0.059128221
        ),
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.05112994
        ),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.05112994
        ),
        alignment_history=(),
    )
    expected_final_alignment_history = ResultSummary(
        shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.06994973868)

    _test_with_attention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        attention_mechanism_depth=9,
        alignment_history=True,
        expected_final_alignment_history=expected_final_alignment_history,
    )

def testBahdanauMonotonicNormalized(self):
    create_attention_mechanism = wrapper.BahdanauMonotonicAttention
    create_attention_kwargs = {"kernel_initializer": "ones", "normalize": True}

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.043294173),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=3.53333333))
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.40034312),
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.5925445)
        ],
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype("float32"), mean=0.096119694),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.1211452),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.1211452),
        alignment_history=())
    expected_final_alignment_history = ResultSummary(
        shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.12258384)

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        alignment_history=True,
        expected_final_alignment_history=expected_final_alignment_history,
        create_query_layer=True,
        create_attention_kwargs=create_attention_kwargs)

def testLuongMonotonicScaled(self):
    create_attention_mechanism = wrapper.LuongMonotonicAttention
    create_attention_kwargs = {"scale": True}

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.027387079),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=3.13333333))
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.32660431),
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.52464348)
        ],
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype("float32"), mean=0.089345723),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.11831035),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.11831035),
        alignment_history=())
    expected_final_alignment_history = ResultSummary(
        shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.12194442004)

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        attention_mechanism_depth=9,
        alignment_history=True,
        expected_final_alignment_history=expected_final_alignment_history,
        create_attention_kwargs=create_attention_kwargs)

def testBahdanauMonotonicNotNormalized(self):
    create_attention_mechanism = wrapper.BahdanauMonotonicAttention
    create_attention_kwargs = {"kernel_initializer": "ones"}

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.041342419),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=3.53333333))
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.33866978),
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.46913195)
        ],
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype("float32"), mean=0.092498459),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.12079944),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.12079944),
        alignment_history=())
    expected_final_alignment_history = ResultSummary(
        shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.121448785067)

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        alignment_history=True,
        expected_final_alignment_history=expected_final_alignment_history,
        create_query_layer=True,
        create_attention_kwargs=create_attention_kwargs)

def testBahdanauNormalized(self):
    create_attention_mechanism = wrapper.BahdanauAttention
    create_attention_kwargs = {"kernel_initializer": "ones", "normalize": True}

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.047594748),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=3.6))
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.41311637),
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.61683208)
        ],
        attention=ResultSummary(
            shape=(5, 6), dtype=np.dtype("float32"), mean=0.090581432),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.125),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.125),
        alignment_history=())

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        create_query_layer=True,
        create_attention_kwargs=create_attention_kwargs)

def testNotUseAttentionLayer(self):
    create_attention_mechanism = wrapper.BahdanauAttention
    create_attention_kwargs = {"kernel_initializer": "ones"}

    expected_final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(
            shape=(5, 3, 10), dtype=np.dtype("float32"), mean=0.072406612),
        sample_id=ResultSummary(
            shape=(5, 3), dtype=np.dtype("int32"), mean=3.86666666))
    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=[
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=0.61177742),
            ResultSummary(
                shape=(5, 9), dtype=np.dtype("float32"), mean=1.032002)
        ],
        attention=ResultSummary(
            shape=(5, 10), dtype=np.dtype("float32"), mean=0.011346335),
        time=3,
        alignments=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.125),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=np.dtype("float32"), mean=0.125),
        alignment_history=())

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        attention_layer_size=None,
        create_query_layer=True,
        create_attention_kwargs=create_attention_kwargs)

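# End-to-end sketch (not part of the test suite) of how the pieces these tests
# exercise fit together. A minimal illustration assuming the public
# tensorflow_addons API; shapes follow the batch=5, encoder max_time=8,
# encoder depth=6, cell units=9, attention_layer_size=6 convention above:
#
#     import tensorflow as tf
#     import tensorflow_addons as tfa
#
#     memory = tf.random.normal([5, 8, 6])  # encoder outputs
#     mechanism = tfa.seq2seq.LuongAttention(units=9, memory=memory)
#     cell = tfa.seq2seq.AttentionWrapper(
#         tf.keras.layers.LSTMCell(9),
#         mechanism,
#         attention_layer_size=6,
#         alignment_history=True,
#     )
#     state = cell.get_initial_state(batch_size=5, dtype=tf.float32)
#     inputs = tf.random.normal([5, 6])   # one decoder step
#     output, state = cell(inputs, state)  # output: [5, 6] attention vector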