def test_attention_wrapper_with_multiple_attention_mechanisms():
    cell = tf.keras.layers.LSTMCell(5)
    mechanisms = [wrapper.LuongAttention(units=3), wrapper.LuongAttention(units=3)]
    # We simply test that the wrapper creation raises no error.
    wrapper.AttentionWrapper(cell, mechanisms, attention_layer_size=[4, 5])
    wrapper.AttentionWrapper(
        cell,
        mechanisms,
        attention_layer=[tf.keras.layers.Dense(4), tf.keras.layers.Dense(5)],
    )

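# A minimal follow-up sketch (not part of the original suite): with both
# memories set up, a multi-mechanism wrapper can actually step once. The
# memory/input shapes below are illustrative assumptions; the emitted
# attention is the concatenation of the per-mechanism projections (4 + 5).
def test_attention_wrapper_with_multiple_attention_mechanisms_step():
    cell = tf.keras.layers.LSTMCell(3)
    mechanisms = [wrapper.LuongAttention(units=3), wrapper.LuongAttention(units=3)]
    attention_cell = wrapper.AttentionWrapper(
        cell, mechanisms, attention_layer_size=[4, 5]
    )
    memory = tf.ones([2, 6, 3])
    for mechanism in mechanisms:
        mechanism.setup_memory(memory)
    inputs = tf.ones([2, 3])
    initial_state = attention_cell.get_initial_state(inputs=inputs)
    output, _ = attention_cell(inputs, initial_state)
    assert output.shape[-1] == 4 + 5
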
def testLuongScaledDType(self, dtype):
    # Test case for GitHub issue 18099
    encoder_outputs = self.encoder_outputs.astype(dtype)
    decoder_inputs = self.decoder_inputs.astype(dtype)
    attention_mechanism = wrapper.LuongAttention(
        units=self.units,
        memory=encoder_outputs,
        memory_sequence_length=self.encoder_sequence_length,
        scale=True,
        dtype=dtype,
    )
    # Propagate dtype through cell, wrapper, and decoder so float16/float64
    # runs do not silently fall back to float32.
    cell = keras.layers.LSTMCell(
        self.units, recurrent_activation="sigmoid", dtype=dtype)
    cell = wrapper.AttentionWrapper(cell, attention_mechanism, dtype=dtype)
    sampler = sampler_py.TrainingSampler()
    my_decoder = basic_decoder.BasicDecoder(
        cell=cell, sampler=sampler, dtype=dtype)
    final_outputs, final_state, _ = my_decoder(
        decoder_inputs,
        initial_state=cell.get_initial_state(
            batch_size=self.batch, dtype=dtype),
        sequence_length=self.decoder_sequence_length)
    self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
    self.assertEqual(final_outputs.rnn_output.dtype, dtype)
    self.assertIsInstance(final_state, wrapper.AttentionWrapperState)

def test_luong_scaled_dtype(dtype):
    dummy_data = DummyData2()
    # Test case for GitHub issue 18099
    encoder_outputs = dummy_data.encoder_outputs.astype(dtype)
    decoder_inputs = dummy_data.decoder_inputs.astype(dtype)
    attention_mechanism = wrapper.LuongAttention(
        units=dummy_data.units,
        memory=encoder_outputs,
        memory_sequence_length=dummy_data.encoder_sequence_length,
        scale=True,
        dtype=dtype,
    )
    cell = tf.keras.layers.LSTMCell(
        dummy_data.units, recurrent_activation="sigmoid", dtype=dtype
    )
    cell = wrapper.AttentionWrapper(cell, attention_mechanism, dtype=dtype)
    sampler = sampler_py.TrainingSampler()
    my_decoder = basic_decoder.BasicDecoder(cell=cell, sampler=sampler, dtype=dtype)
    final_outputs, final_state, _ = my_decoder(
        decoder_inputs,
        initial_state=cell.get_initial_state(
            batch_size=dummy_data.batch, dtype=dtype
        ),
        sequence_length=dummy_data.decoder_sequence_length,
    )
    assert isinstance(final_outputs, basic_decoder.BasicDecoderOutput)
    assert final_outputs.rnn_output.dtype == dtype
    assert isinstance(final_state, wrapper.AttentionWrapperState)

def testBahdanauNormalizedDType(self, dtype):
    encoder_outputs = self.encoder_outputs.astype(dtype)
    decoder_inputs = self.decoder_inputs.astype(dtype)
    attention_mechanism = wrapper.BahdanauAttention(
        units=self.units,
        memory=encoder_outputs,
        memory_sequence_length=self.encoder_sequence_length,
        normalize=True,
        dtype=dtype,
    )
    cell = tf.keras.layers.LSTMCell(
        self.units, recurrent_activation="sigmoid", dtype=dtype
    )
    cell = wrapper.AttentionWrapper(cell, attention_mechanism, dtype=dtype)
    sampler = sampler_py.TrainingSampler()
    my_decoder = basic_decoder.BasicDecoder(cell=cell, sampler=sampler, dtype=dtype)
    final_outputs, final_state, _ = my_decoder(
        decoder_inputs,
        initial_state=cell.get_initial_state(batch_size=self.batch, dtype=dtype),
        sequence_length=self.decoder_sequence_length,
    )
    self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
    self.assertEqual(final_outputs.rnn_output.dtype, dtype)
    self.assertIsInstance(final_state, wrapper.AttentionWrapperState)

def test_custom_attention_layer():
    dummy_data = DummyData2()
    attention_mechanism = wrapper.LuongAttention(dummy_data.units)
    cell = tf.keras.layers.LSTMCell(dummy_data.units)
    attention_layer = tf.keras.layers.Dense(
        dummy_data.units * 2, use_bias=False, activation=tf.math.tanh
    )
    attention_wrapper = wrapper.AttentionWrapper(
        cell, attention_mechanism, attention_layer=attention_layer
    )
    with pytest.raises(ValueError):
        # Should fail because the attention mechanism has not been
        # initialized.
        attention_wrapper.get_initial_state(
            batch_size=dummy_data.batch, dtype=tf.float32
        )
    attention_mechanism.setup_memory(
        dummy_data.encoder_outputs.astype(np.float32),
        memory_sequence_length=dummy_data.encoder_sequence_length,
    )
    initial_state = attention_wrapper.get_initial_state(
        batch_size=dummy_data.batch, dtype=tf.float32
    )
    assert initial_state.attention.shape[-1] == dummy_data.units * 2
    first_input = dummy_data.decoder_inputs[:, 0].astype(np.float32)
    output, _ = attention_wrapper(first_input, initial_state)
    assert output.shape[-1] == dummy_data.units * 2

def test_attention_wrapper_with_gru_cell():
    mechanism = wrapper.LuongAttention(units=3)
    cell = tf.keras.layers.GRUCell(3)
    cell = wrapper.AttentionWrapper(cell, mechanism)
    memory = tf.ones([2, 5, 3])
    inputs = tf.ones([2, 3])
    mechanism.setup_memory(memory)
    initial_state = cell.get_initial_state(inputs=inputs)
    _, state = cell(inputs, initial_state)
    tf.nest.assert_same_structure(initial_state, state)

def test_basic_decoder_with_attention_wrapper():
    units = 32
    vocab_size = 1000
    attention_mechanism = attention_wrapper.LuongAttention(units)
    cell = tf.keras.layers.LSTMCell(units)
    cell = attention_wrapper.AttentionWrapper(cell, attention_mechanism)
    output_layer = tf.keras.layers.Dense(vocab_size)
    sampler = sampler_py.TrainingSampler()
    # BasicDecoder should accept an uninitialized AttentionWrapper.
    basic_decoder.BasicDecoder(cell, sampler, output_layer=output_layer)

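# A hypothetical continuation (assumed, not from the original tests): once
# setup_memory() is called, the same decoder can run a full forward pass.
# Batch size, memory length, and decoder length below are arbitrary choices.
def test_basic_decoder_with_attention_wrapper_run():
    units = 32
    vocab_size = 1000
    batch_size = 2
    attention_mechanism = attention_wrapper.LuongAttention(units)
    cell = tf.keras.layers.LSTMCell(units)
    cell = attention_wrapper.AttentionWrapper(cell, attention_mechanism)
    output_layer = tf.keras.layers.Dense(vocab_size)
    sampler = sampler_py.TrainingSampler()
    my_decoder = basic_decoder.BasicDecoder(cell, sampler, output_layer=output_layer)
    attention_mechanism.setup_memory(tf.ones([batch_size, 7, units]))
    initial_state = cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
    final_outputs, _, _ = my_decoder(
        tf.ones([batch_size, 4, units]),
        initial_state=initial_state,
        sequence_length=tf.fill([batch_size], 4),
    )
    # The output layer projects the attention output to the vocabulary size.
    assert final_outputs.rnn_output.shape[-1] == vocab_size
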
def test_attention_state_with_keras_rnn():
    # See https://github.com/tensorflow/addons/issues/1095.
    cell = tf.keras.layers.LSTMCell(8)
    mechanism = wrapper.LuongAttention(units=8, memory=tf.ones((2, 4, 8)))
    cell = wrapper.AttentionWrapper(cell=cell, attention_mechanism=mechanism)
    layer = tf.keras.layers.RNN(cell)
    _ = layer(inputs=tf.ones((2, 4, 8)))

    # Make sure the explicit initial_state also works.
    initial_state = cell.get_initial_state(batch_size=2, dtype=tf.float32)
    _ = layer(inputs=tf.ones((2, 4, 8)), initial_state=initial_state)

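# A minimal variant (an assumption, mirroring the test above): the same RNN
# layer also works when the mechanism's memory carries an explicit
# memory_sequence_length mask, so padded timesteps are ignored.
def test_attention_state_with_keras_rnn_masked_memory():
    cell = tf.keras.layers.LSTMCell(8)
    mechanism = wrapper.LuongAttention(
        units=8,
        memory=tf.ones((2, 4, 8)),
        memory_sequence_length=tf.constant([3, 2], dtype=tf.int32),
    )
    cell = wrapper.AttentionWrapper(cell=cell, attention_mechanism=mechanism)
    layer = tf.keras.layers.RNN(cell)
    _ = layer(inputs=tf.ones((2, 4, 8)))
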
def test_attention_state_with_variable_length_input():
    cell = tf.keras.layers.LSTMCell(3)
    mechanism = wrapper.LuongAttention(units=3)
    cell = wrapper.AttentionWrapper(cell, mechanism)

    var_len = tf.random.uniform(shape=(), minval=2, maxval=10, dtype=tf.int32)
    data = tf.ones(shape=(var_len, var_len, 3))

    mechanism.setup_memory(data)
    layer = tf.keras.layers.RNN(cell)
    _ = layer(data)

def testCustomAttentionLayer(self):
    attention_mechanism = wrapper.LuongAttention(self.units)
    cell = tf.keras.layers.LSTMCell(self.units)
    attention_layer = tf.keras.layers.Dense(
        self.units * 2, use_bias=False, activation=tf.math.tanh)
    attention_wrapper = wrapper.AttentionWrapper(
        cell, attention_mechanism, attention_layer=attention_layer)
    with self.assertRaises(ValueError):
        # Should fail because the attention mechanism has not been
        # initialized.
        attention_wrapper.get_initial_state(
            batch_size=self.batch, dtype=tf.float32)
    attention_mechanism.setup_memory(
        self.encoder_outputs.astype(np.float32),
        memory_sequence_length=self.encoder_sequence_length)
    initial_state = attention_wrapper.get_initial_state(
        batch_size=self.batch, dtype=tf.float32)
    self.assertEqual(initial_state.attention.shape[-1], self.units * 2)
    first_input = self.decoder_inputs[:, 0].astype(np.float32)
    output, _ = attention_wrapper(first_input, initial_state)
    self.assertEqual(output.shape[-1], self.units * 2)

def _testWithMaybeMultiAttention(self,
                                 is_multi,
                                 create_attention_mechanisms,
                                 expected_final_output,
                                 expected_final_state,
                                 attention_mechanism_depths,
                                 alignment_history=False,
                                 expected_final_alignment_history=None,
                                 attention_layer_sizes=None,
                                 attention_layers=None,
                                 create_query_layer=False,
                                 create_memory_layer=True,
                                 create_attention_kwargs=None):
    # Allow is_multi to be True with a single mechanism to enable test for
    # passing in a single mechanism in a list.
    assert len(create_attention_mechanisms) == 1 or is_multi
    encoder_sequence_length = [3, 2, 3, 1, 1]
    decoder_sequence_length = [2, 0, 1, 2, 3]
    batch_size = 5
    encoder_max_time = 8
    decoder_max_time = 4
    input_depth = 7
    encoder_output_depth = 10
    cell_depth = 9
    create_attention_kwargs = create_attention_kwargs or {}

    if attention_layer_sizes is not None:
        # Compute sum of attention_layer_sizes. Use encoder_output_depth if
        # None.
        attention_depth = sum(
            attention_layer_size or encoder_output_depth
            for attention_layer_size in attention_layer_sizes)
    elif attention_layers is not None:
        # Compute sum of attention_layers output depth.
        attention_depth = sum(
            attention_layer.compute_output_shape(
                [batch_size, cell_depth + encoder_output_depth]).dims[-1].value
            for attention_layer in attention_layers)
    else:
        attention_depth = encoder_output_depth * len(
            create_attention_mechanisms)

    decoder_inputs = np.random.randn(batch_size, decoder_max_time,
                                     input_depth).astype(np.float32)
    encoder_outputs = np.random.randn(batch_size, encoder_max_time,
                                      encoder_output_depth).astype(np.float32)

    attention_mechanisms = []
    for creator, depth in zip(create_attention_mechanisms,
                              attention_mechanism_depths):
        # Create a memory layer with deterministic initializer to avoid
        # randomness in the test between graph and eager.
        if create_query_layer:
            create_attention_kwargs["query_layer"] = keras.layers.Dense(
                depth, kernel_initializer="ones", use_bias=False)
        if create_memory_layer:
            create_attention_kwargs["memory_layer"] = keras.layers.Dense(
                depth, kernel_initializer="ones", use_bias=False)
        attention_mechanisms.append(
            creator(units=depth,
                    memory=encoder_outputs,
                    memory_sequence_length=encoder_sequence_length,
                    **create_attention_kwargs))

    with self.cached_session(use_gpu=True):
        attention_layer_size = attention_layer_sizes
        attention_layer = attention_layers
        if not is_multi:
            if attention_layer_size is not None:
                attention_layer_size = attention_layer_size[0]
            if attention_layer is not None:
                attention_layer = attention_layer[0]
        cell = keras.layers.LSTMCell(cell_depth,
                                     recurrent_activation="sigmoid",
                                     kernel_initializer="ones",
                                     recurrent_initializer="ones")
        cell = wrapper.AttentionWrapper(
            cell,
            attention_mechanisms if is_multi else attention_mechanisms[0],
            attention_layer_size=attention_layer_size,
            alignment_history=alignment_history,
            attention_layer=attention_layer)
        if cell._attention_layers is not None:
            for layer in cell._attention_layers:
                layer.kernel_initializer = initializers.glorot_uniform(
                    seed=1337)

        sampler = sampler_py.TrainingSampler()
        my_decoder = basic_decoder.BasicDecoder(cell=cell, sampler=sampler)
        initial_state = cell.get_initial_state(dtype=tf.float32,
                                               batch_size=batch_size)
        final_outputs, final_state, _ = my_decoder(
            decoder_inputs,
            initial_state=initial_state,
            sequence_length=decoder_sequence_length)

        self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
        self.assertIsInstance(final_state, wrapper.AttentionWrapperState)

        expected_time = (expected_final_state.time
                         if tf.executing_eagerly() else None)
        self.assertEqual(
            (batch_size, expected_time, attention_depth),
            tuple(final_outputs.rnn_output.get_shape().as_list()))
        self.assertEqual(
            (batch_size, expected_time),
            tuple(final_outputs.sample_id.get_shape().as_list()))
        self.assertEqual(
            (batch_size, attention_depth),
            tuple(final_state.attention.get_shape().as_list()))
        self.assertEqual(
            (batch_size, cell_depth),
            tuple(final_state.cell_state[0].get_shape().as_list()))
        self.assertEqual(
            (batch_size, cell_depth),
            tuple(final_state.cell_state[1].get_shape().as_list()))

        if alignment_history:
            if is_multi:
                state_alignment_history = []
                for history_array in final_state.alignment_history:
                    history = history_array.stack()
                    self.assertEqual(
                        (expected_time, batch_size, encoder_max_time),
                        tuple(history.get_shape().as_list()))
                    state_alignment_history.append(history)
                state_alignment_history = tuple(state_alignment_history)
            else:
                state_alignment_history = \
                    final_state.alignment_history.stack()
                self.assertEqual(
                    (expected_time, batch_size, encoder_max_time),
                    tuple(state_alignment_history.get_shape().as_list()))
            tf.nest.assert_same_structure(
                cell.state_size,
                cell.get_initial_state(batch_size=batch_size,
                                       dtype=tf.float32))
            # Remove the history from final_state for purposes of the
            # remainder of the tests.
            final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
        else:
            state_alignment_history = ()

        self.evaluate(tf.compat.v1.global_variables_initializer())
        eval_result = self.evaluate({
            "final_outputs": final_outputs,
            "final_state": final_state,
            "state_alignment_history": state_alignment_history,
        })

        final_output_info = tf.nest.map_structure(
            get_result_summary, eval_result["final_outputs"])
        final_state_info = tf.nest.map_structure(
            get_result_summary, eval_result["final_state"])
        print("final_output_info: ", final_output_info)
        print("final_state_info: ", final_state_info)
        tf.nest.map_structure(self.assertAllCloseOrEqual,
                              expected_final_output, final_output_info)
        tf.nest.map_structure(self.assertAllCloseOrEqual,
                              expected_final_state, final_state_info)
        # by default, the wrapper emits attention as output
        if alignment_history:
            final_alignment_history_info = tf.nest.map_structure(
                get_result_summary, eval_result["state_alignment_history"])
            print("final_alignment_history_info: ",
                  final_alignment_history_info)
            tf.nest.map_structure(
                self.assertAllCloseOrEqual,
                # outputs are batch major but the stacked TensorArray is
                # time major
                expected_final_alignment_history,
                final_alignment_history_info)

def _test_with_attention(
    create_attention_mechanism,
    expected_final_output,
    expected_final_state,
    attention_mechanism_depth=3,
    alignment_history=False,
    expected_final_alignment_history=None,
    attention_layer_size=6,
    attention_layer=None,
    create_query_layer=False,
    create_memory_layer=True,
    create_attention_kwargs=None,
):
    attention_layer_sizes = (
        [attention_layer_size] if attention_layer_size is not None else None)
    attention_layers = (
        [attention_layer] if attention_layer is not None else None)
    create_attention_mechanisms = [create_attention_mechanism]
    attention_mechanism_depths = [attention_mechanism_depth]
    assert len(create_attention_mechanisms) == 1

    encoder_sequence_length = [3, 2, 3, 1, 1]
    decoder_sequence_length = [2, 0, 1, 2, 3]
    batch_size = 5
    encoder_max_time = 8
    decoder_max_time = 4
    input_depth = 7
    encoder_output_depth = 10
    cell_depth = 9
    create_attention_kwargs = create_attention_kwargs or {}

    if attention_layer_sizes is not None:
        # Compute sum of attention_layer_sizes. Use encoder_output_depth if
        # None.
        attention_depth = sum(
            attention_layer_size or encoder_output_depth
            for attention_layer_size in attention_layer_sizes)
    elif attention_layers is not None:
        # Compute sum of attention_layers output depth.
        attention_depth = sum(
            attention_layer.compute_output_shape(
                [batch_size, cell_depth + encoder_output_depth]).dims[-1].value
            for attention_layer in attention_layers)
    else:
        attention_depth = encoder_output_depth * len(
            create_attention_mechanisms)

    decoder_inputs = np.random.randn(batch_size, decoder_max_time,
                                     input_depth).astype(np.float32)
    encoder_outputs = np.random.randn(batch_size, encoder_max_time,
                                      encoder_output_depth).astype(np.float32)

    attention_mechanisms = []
    for creator, depth in zip(create_attention_mechanisms,
                              attention_mechanism_depths):
        # Create a memory layer with deterministic initializer to avoid
        # randomness in the test between graph and eager.
        if create_query_layer:
            create_attention_kwargs["query_layer"] = tf.keras.layers.Dense(
                depth, kernel_initializer="ones", use_bias=False)
        if create_memory_layer:
            create_attention_kwargs["memory_layer"] = tf.keras.layers.Dense(
                depth, kernel_initializer="ones", use_bias=False)
        attention_mechanisms.append(
            creator(
                units=depth,
                memory=encoder_outputs,
                memory_sequence_length=encoder_sequence_length,
                **create_attention_kwargs,
            ))

    attention_layer_size = attention_layer_sizes
    attention_layer = attention_layers
    if attention_layer_size is not None:
        attention_layer_size = attention_layer_size[0]
    if attention_layer is not None:
        attention_layer = attention_layer[0]
    cell = tf.keras.layers.LSTMCell(
        cell_depth,
        recurrent_activation="sigmoid",
        kernel_initializer="ones",
        recurrent_initializer="ones",
    )
    cell = wrapper.AttentionWrapper(
        cell,
        attention_mechanisms[0],
        attention_layer_size=attention_layer_size,
        alignment_history=alignment_history,
        attention_layer=attention_layer,
    )
    if cell._attention_layers is not None:
        for layer in cell._attention_layers:
            layer.kernel_initializer = tf.compat.v1.keras.initializers.glorot_uniform(
                seed=1337)

    sampler = sampler_py.TrainingSampler()
    my_decoder = basic_decoder.BasicDecoder(cell=cell, sampler=sampler)
    initial_state = cell.get_initial_state(dtype=tf.float32,
                                           batch_size=batch_size)
    final_outputs, final_state, _ = my_decoder(
        decoder_inputs,
        initial_state=initial_state,
        sequence_length=decoder_sequence_length,
    )

    assert isinstance(final_outputs, basic_decoder.BasicDecoderOutput)
    assert isinstance(final_state, wrapper.AttentionWrapperState)

    expected_time = max(decoder_sequence_length)
    assert (batch_size, expected_time, attention_depth) == tuple(
        final_outputs.rnn_output.get_shape().as_list())
    assert (batch_size, expected_time) == tuple(
        final_outputs.sample_id.get_shape().as_list())
    assert (batch_size, attention_depth) == tuple(
        final_state.attention.get_shape().as_list())
    assert (batch_size, cell_depth) == tuple(
        final_state.cell_state[0].get_shape().as_list())
    assert (batch_size, cell_depth) == tuple(
        final_state.cell_state[1].get_shape().as_list())

    if alignment_history:
        state_alignment_history = final_state.alignment_history.stack()
        assert (expected_time, batch_size, encoder_max_time) == tuple(
            state_alignment_history.get_shape().as_list())
        tf.nest.assert_same_structure(
            cell.state_size,
            cell.get_initial_state(batch_size=batch_size, dtype=tf.float32),
        )
        # Remove the history from final_state for purposes of the
        # remainder of the tests.
        final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
    else:
        state_alignment_history = ()

    final_outputs = tf.nest.map_structure(np.array, final_outputs)
    final_state = tf.nest.map_structure(np.array, final_state)
    state_alignment_history = tf.nest.map_structure(np.array,
                                                    state_alignment_history)
    final_output_info = tf.nest.map_structure(get_result_summary,
                                              final_outputs)
    final_state_info = tf.nest.map_structure(get_result_summary, final_state)

    tf.nest.map_structure(assert_allclose_or_equal, expected_final_output,
                          final_output_info)
    tf.nest.map_structure(assert_allclose_or_equal, expected_final_state,
                          final_state_info)
    # by default, the wrapper emits attention as output
    if alignment_history:
        final_alignment_history_info = tf.nest.map_structure(
            get_result_summary, state_alignment_history)
        tf.nest.map_structure(
            assert_allclose_or_equal,
            # outputs are batch major but the stacked TensorArray is
            # time major
            expected_final_alignment_history,
            final_alignment_history_info,
        )

def _testDynamicDecodeRNN(self, time_major, has_attention,
                          with_alignment_history=False):
    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
    batch_size = 5
    decoder_max_time = 4
    input_depth = 7
    cell_depth = 9
    attention_depth = 6
    vocab_size = 20
    end_token = vocab_size - 1
    start_token = 0
    embedding_dim = 50
    max_out = max(decoder_sequence_length)
    output_layer = tf.keras.layers.Dense(vocab_size, use_bias=True,
                                         activation=None)
    beam_width = 3

    with self.cached_session():
        batch_size_tensor = tf.constant(batch_size)
        embedding = np.random.randn(vocab_size,
                                    embedding_dim).astype(np.float32)
        cell = tf.keras.layers.LSTMCell(cell_depth)
        initial_state = cell.get_initial_state(batch_size=batch_size,
                                               dtype=tf.float32)
        coverage_penalty_weight = 0.0
        if has_attention:
            coverage_penalty_weight = 0.2
            inputs = tf.compat.v1.placeholder_with_default(
                np.random.randn(batch_size, decoder_max_time,
                                input_depth).astype(np.float32),
                shape=(None, None, input_depth))
            tiled_inputs = beam_search_decoder.tile_batch(
                inputs, multiplier=beam_width)
            tiled_sequence_length = beam_search_decoder.tile_batch(
                encoder_sequence_length, multiplier=beam_width)
            attention_mechanism = attention_wrapper.BahdanauAttention(
                units=attention_depth,
                memory=tiled_inputs,
                memory_sequence_length=tiled_sequence_length)
            initial_state = beam_search_decoder.tile_batch(
                initial_state, multiplier=beam_width)
            cell = attention_wrapper.AttentionWrapper(
                cell=cell,
                attention_mechanism=attention_mechanism,
                attention_layer_size=attention_depth,
                alignment_history=with_alignment_history)
        cell_state = cell.get_initial_state(
            batch_size=batch_size_tensor * beam_width, dtype=tf.float32)
        if has_attention:
            cell_state = cell_state.clone(cell_state=initial_state)
        bsd = beam_search_decoder.BeamSearchDecoder(
            cell=cell,
            beam_width=beam_width,
            output_layer=output_layer,
            length_penalty_weight=0.0,
            coverage_penalty_weight=coverage_penalty_weight,
            output_time_major=time_major,
            maximum_iterations=max_out)
        final_outputs, final_state, final_sequence_lengths = bsd(
            embedding,
            start_tokens=tf.fill([batch_size_tensor], start_token),
            end_token=end_token,
            initial_state=cell_state)

        def _t(shape):
            if time_major:
                return (shape[1], shape[0]) + shape[2:]
            return shape

        self.assertIsInstance(
            final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput)
        self.assertIsInstance(final_state,
                              beam_search_decoder.BeamSearchDecoderState)

        beam_search_decoder_output = final_outputs.beam_search_decoder_output
        expected_seq_length = 3 if tf.executing_eagerly() else None
        self.assertEqual(
            _t((batch_size, expected_seq_length, beam_width)),
            tuple(beam_search_decoder_output.scores.get_shape().as_list()))
        self.assertEqual(
            _t((batch_size, expected_seq_length, beam_width)),
            tuple(final_outputs.predicted_ids.get_shape().as_list()))

        self.evaluate(tf.compat.v1.global_variables_initializer())
        eval_results = self.evaluate({
            'final_outputs': final_outputs,
            'final_sequence_lengths': final_sequence_lengths
        })

        max_sequence_length = np.max(eval_results['final_sequence_lengths'])
        # A smoke test
        self.assertEqual(
            _t((batch_size, max_sequence_length, beam_width)),
            eval_results['final_outputs'].beam_search_decoder_output
            .scores.shape)
        self.assertEqual(
            _t((batch_size, max_sequence_length, beam_width)),
            eval_results['final_outputs'].beam_search_decoder_output
            .predicted_ids.shape)

def test_beam_search_decoder(cell_class, time_major, has_attention,
                             with_alignment_history):
    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
    batch_size = 5
    decoder_max_time = 4
    input_depth = 7
    cell_depth = 9
    attention_depth = 6
    vocab_size = 20
    end_token = vocab_size - 1
    start_token = 0
    embedding_dim = 50
    maximum_iterations = 3
    output_layer = tf.keras.layers.Dense(vocab_size, use_bias=True,
                                         activation=None)
    beam_width = 3
    embedding = tf.random.normal([vocab_size, embedding_dim])
    cell = cell_class(cell_depth)

    if has_attention:
        attention_mechanism = attention_wrapper.BahdanauAttention(
            units=attention_depth,
        )
        cell = attention_wrapper.AttentionWrapper(
            cell=cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=attention_depth,
            alignment_history=with_alignment_history,
        )
        coverage_penalty_weight = 0.2
    else:
        coverage_penalty_weight = 0.0

    bsd = beam_search_decoder.BeamSearchDecoder(
        cell=cell,
        beam_width=beam_width,
        output_layer=output_layer,
        length_penalty_weight=0.0,
        coverage_penalty_weight=coverage_penalty_weight,
        output_time_major=time_major,
        maximum_iterations=maximum_iterations,
    )

    @tf.function(input_signature=(
        tf.TensorSpec([None, None, input_depth], dtype=tf.float32),
        tf.TensorSpec([None], dtype=tf.int32),
    ))
    def _beam_decode_from(memory, memory_sequence_length):
        batch_size_tensor = tf.shape(memory)[0]
        if has_attention:
            tiled_memory = beam_search_decoder.tile_batch(
                memory, multiplier=beam_width)
            tiled_memory_sequence_length = beam_search_decoder.tile_batch(
                memory_sequence_length, multiplier=beam_width)
            attention_mechanism.setup_memory(
                tiled_memory,
                memory_sequence_length=tiled_memory_sequence_length)
        cell_state = cell.get_initial_state(
            batch_size=batch_size_tensor * beam_width, dtype=tf.float32)
        return bsd(
            embedding,
            start_tokens=tf.fill([batch_size_tensor], start_token),
            end_token=end_token,
            initial_state=cell_state,
        )

    memory = tf.random.normal([batch_size, decoder_max_time, input_depth])
    memory_sequence_length = tf.constant(encoder_sequence_length,
                                         dtype=tf.int32)
    final_outputs, final_state, final_sequence_lengths = _beam_decode_from(
        memory, memory_sequence_length)

    def _t(shape):
        if time_major:
            return (shape[1], shape[0]) + shape[2:]
        return shape

    assert isinstance(final_outputs,
                      beam_search_decoder.FinalBeamSearchDecoderOutput)
    assert isinstance(final_state, beam_search_decoder.BeamSearchDecoderState)

    beam_search_decoder_output = final_outputs.beam_search_decoder_output
    max_sequence_length = np.max(final_sequence_lengths.numpy())
    assert _t((batch_size, max_sequence_length, beam_width)) == tuple(
        beam_search_decoder_output.scores.shape.as_list())
    assert _t((batch_size, max_sequence_length, beam_width)) == tuple(
        final_outputs.predicted_ids.shape.as_list())

def test_dynamic_decode_rnn(cell_class, time_major, has_attention,
                            with_alignment_history):
    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
    batch_size = 5
    decoder_max_time = 4
    input_depth = 7
    cell_depth = 9
    attention_depth = 6
    vocab_size = 20
    end_token = vocab_size - 1
    start_token = 0
    embedding_dim = 50
    max_out = max(decoder_sequence_length)
    output_layer = tf.keras.layers.Dense(vocab_size, use_bias=True,
                                         activation=None)
    beam_width = 3

    batch_size_tensor = tf.constant(batch_size)
    embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
    cell = cell_class(cell_depth)
    initial_state = cell.get_initial_state(batch_size=batch_size,
                                           dtype=tf.float32)
    coverage_penalty_weight = 0.0
    if has_attention:
        coverage_penalty_weight = 0.2
        inputs = tf.compat.v1.placeholder_with_default(
            np.random.randn(batch_size, decoder_max_time,
                            input_depth).astype(np.float32),
            shape=(None, None, input_depth),
        )
        tiled_inputs = beam_search_decoder.tile_batch(inputs,
                                                      multiplier=beam_width)
        tiled_sequence_length = beam_search_decoder.tile_batch(
            encoder_sequence_length, multiplier=beam_width)
        attention_mechanism = attention_wrapper.BahdanauAttention(
            units=attention_depth,
            memory=tiled_inputs,
            memory_sequence_length=tiled_sequence_length,
        )
        initial_state = beam_search_decoder.tile_batch(initial_state,
                                                       multiplier=beam_width)
        cell = attention_wrapper.AttentionWrapper(
            cell=cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=attention_depth,
            alignment_history=with_alignment_history,
        )
    cell_state = cell.get_initial_state(
        batch_size=batch_size_tensor * beam_width, dtype=tf.float32)
    if has_attention:
        cell_state = cell_state.clone(cell_state=initial_state)
    bsd = beam_search_decoder.BeamSearchDecoder(
        cell=cell,
        beam_width=beam_width,
        output_layer=output_layer,
        length_penalty_weight=0.0,
        coverage_penalty_weight=coverage_penalty_weight,
        output_time_major=time_major,
        maximum_iterations=max_out,
    )
    final_outputs, final_state, final_sequence_lengths = bsd(
        embedding,
        start_tokens=tf.fill([batch_size_tensor], start_token),
        end_token=end_token,
        initial_state=cell_state,
    )

    def _t(shape):
        if time_major:
            return (shape[1], shape[0]) + shape[2:]
        return shape

    assert isinstance(final_outputs,
                      beam_search_decoder.FinalBeamSearchDecoderOutput)
    assert isinstance(final_state, beam_search_decoder.BeamSearchDecoderState)

    beam_search_decoder_output = final_outputs.beam_search_decoder_output
    expected_seq_length = 3 if tf.executing_eagerly() else None
    assert _t((batch_size, expected_seq_length, beam_width)) == tuple(
        beam_search_decoder_output.scores.shape.as_list())
    assert _t((batch_size, expected_seq_length, beam_width)) == tuple(
        final_outputs.predicted_ids.shape.as_list())

    eval_results = {
        "final_outputs": final_outputs,
        "final_sequence_lengths": final_sequence_lengths.numpy(),
    }

    max_sequence_length = np.max(eval_results["final_sequence_lengths"])
    # A smoke test
    assert (_t((batch_size, max_sequence_length, beam_width)) ==
            eval_results["final_outputs"].beam_search_decoder_output
            .scores.shape)
    assert (_t((batch_size, max_sequence_length, beam_width)) ==
            eval_results["final_outputs"].beam_search_decoder_output
            .predicted_ids.shape)