# Module paths below assume the TensorFlow Addons test layout.
import numpy as np
import pytest
import tensorflow as tf

from tensorflow_addons.seq2seq import attention_wrapper
from tensorflow_addons.seq2seq import beam_search_decoder


def test_missing_embedding_fn():
    batch_size = 6
    beam_width = 4
    cell = tf.keras.layers.LSTMCell(5)
    decoder = beam_search_decoder.BeamSearchDecoder(cell, beam_width=beam_width)
    initial_state = cell.get_initial_state(
        batch_size=batch_size * beam_width, dtype=tf.float32)
    start_tokens = tf.ones([batch_size], dtype=tf.int32)
    end_token = tf.constant(2, dtype=tf.int32)
    # Calling the decoder without an embedding (matrix or callable) is invalid.
    with pytest.raises(ValueError):
        decoder(None, start_tokens, end_token, initial_state)
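
# For contrast with the failure case above, a minimal sketch of a valid call
# where the first argument is an embedding matrix. This is not part of the
# original suite; the vocabulary size, embedding width, and iteration cap are
# illustrative assumptions. With no output_layer, the cell output is used
# directly as logits, so the cell width is set to the vocabulary size.
def _example_valid_embedding_call():
    batch_size = 6
    beam_width = 4
    vocab_size = 10
    embedding_dim = 8
    cell = tf.keras.layers.LSTMCell(vocab_size)
    decoder = beam_search_decoder.BeamSearchDecoder(
        cell, beam_width=beam_width, maximum_iterations=4)
    initial_state = cell.get_initial_state(
        batch_size=batch_size * beam_width, dtype=tf.float32)
    start_tokens = tf.ones([batch_size], dtype=tf.int32)
    end_token = tf.constant(2, dtype=tf.int32)
    embedding = tf.random.normal([vocab_size, embedding_dim])
    # Returns (FinalBeamSearchDecoderOutput, BeamSearchDecoderState, lengths).
    return decoder(embedding, start_tokens, end_token, initial_state)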
def _testDynamicDecodeRNN(self, time_major, has_attention,
                          with_alignment_history=False):
    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
    batch_size = 5
    decoder_max_time = 4
    input_depth = 7
    cell_depth = 9
    attention_depth = 6
    vocab_size = 20
    end_token = vocab_size - 1
    start_token = 0
    embedding_dim = 50
    max_out = max(decoder_sequence_length)
    output_layer = tf.keras.layers.Dense(vocab_size, use_bias=True,
                                         activation=None)
    beam_width = 3
    with self.cached_session():
        batch_size_tensor = tf.constant(batch_size)
        embedding = np.random.randn(vocab_size,
                                    embedding_dim).astype(np.float32)
        cell = tf.keras.layers.LSTMCell(cell_depth)
        initial_state = cell.get_initial_state(batch_size=batch_size,
                                               dtype=tf.float32)
        coverage_penalty_weight = 0.0
        if has_attention:
            coverage_penalty_weight = 0.2
            inputs = tf.compat.v1.placeholder_with_default(
                np.random.randn(batch_size, decoder_max_time,
                                input_depth).astype(np.float32),
                shape=(None, None, input_depth))
            tiled_inputs = beam_search_decoder.tile_batch(
                inputs, multiplier=beam_width)
            tiled_sequence_length = beam_search_decoder.tile_batch(
                encoder_sequence_length, multiplier=beam_width)
            attention_mechanism = attention_wrapper.BahdanauAttention(
                units=attention_depth,
                memory=tiled_inputs,
                memory_sequence_length=tiled_sequence_length)
            initial_state = beam_search_decoder.tile_batch(
                initial_state, multiplier=beam_width)
            cell = attention_wrapper.AttentionWrapper(
                cell=cell,
                attention_mechanism=attention_mechanism,
                attention_layer_size=attention_depth,
                alignment_history=with_alignment_history)
        cell_state = cell.get_initial_state(
            batch_size=batch_size_tensor * beam_width, dtype=tf.float32)
        if has_attention:
            cell_state = cell_state.clone(cell_state=initial_state)
        bsd = beam_search_decoder.BeamSearchDecoder(
            cell=cell,
            beam_width=beam_width,
            output_layer=output_layer,
            length_penalty_weight=0.0,
            coverage_penalty_weight=coverage_penalty_weight,
            output_time_major=time_major,
            maximum_iterations=max_out)
        final_outputs, final_state, final_sequence_lengths = bsd(
            embedding,
            start_tokens=tf.fill([batch_size_tensor], start_token),
            end_token=end_token,
            initial_state=cell_state)

        def _t(shape):
            if time_major:
                return (shape[1], shape[0]) + shape[2:]
            return shape

        self.assertIsInstance(
            final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput)
        self.assertIsInstance(
            final_state, beam_search_decoder.BeamSearchDecoderState)

        beam_search_decoder_output = final_outputs.beam_search_decoder_output
        expected_seq_length = 3 if tf.executing_eagerly() else None
        self.assertEqual(
            _t((batch_size, expected_seq_length, beam_width)),
            tuple(beam_search_decoder_output.scores.get_shape().as_list()))
        self.assertEqual(
            _t((batch_size, expected_seq_length, beam_width)),
            tuple(final_outputs.predicted_ids.get_shape().as_list()))

        self.evaluate(tf.compat.v1.global_variables_initializer())
        eval_results = self.evaluate({
            'final_outputs': final_outputs,
            'final_sequence_lengths': final_sequence_lengths
        })

        max_sequence_length = np.max(eval_results['final_sequence_lengths'])

        # A smoke test
        self.assertEqual(
            _t((batch_size, max_sequence_length, beam_width)),
            eval_results['final_outputs'].beam_search_decoder_output
            .scores.shape)
        self.assertEqual(
            _t((batch_size, max_sequence_length, beam_width)),
            eval_results['final_outputs'].beam_search_decoder_output
            .predicted_ids.shape)
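
# A minimal sketch of the tile_batch helper used above (the values here are
# illustrative assumptions, not from the test): every batch entry is repeated
# `multiplier` times along dimension 0, which is how encoder memory and
# initial states are expanded from batch_size to batch_size * beam_width.
def _example_tile_batch():
    memory = tf.random.normal([5, 4, 7])  # [batch, max_time, depth]
    tiled = beam_search_decoder.tile_batch(memory, multiplier=3)
    assert tiled.shape.as_list() == [15, 4, 7]  # batch dim is now 5 * 3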
def test_beam_search_decoder(cell_class, time_major, has_attention,
                             with_alignment_history):
    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
    batch_size = 5
    decoder_max_time = 4
    input_depth = 7
    cell_depth = 9
    attention_depth = 6
    vocab_size = 20
    end_token = vocab_size - 1
    start_token = 0
    embedding_dim = 50
    maximum_iterations = 3
    output_layer = tf.keras.layers.Dense(vocab_size, use_bias=True,
                                         activation=None)
    beam_width = 3
    embedding = tf.random.normal([vocab_size, embedding_dim])
    cell = cell_class(cell_depth)
    if has_attention:
        attention_mechanism = attention_wrapper.BahdanauAttention(
            units=attention_depth)
        cell = attention_wrapper.AttentionWrapper(
            cell=cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=attention_depth,
            alignment_history=with_alignment_history,
        )
        coverage_penalty_weight = 0.2
    else:
        coverage_penalty_weight = 0.0
    bsd = beam_search_decoder.BeamSearchDecoder(
        cell=cell,
        beam_width=beam_width,
        output_layer=output_layer,
        length_penalty_weight=0.0,
        coverage_penalty_weight=coverage_penalty_weight,
        output_time_major=time_major,
        maximum_iterations=maximum_iterations,
    )

    @tf.function(input_signature=(
        tf.TensorSpec([None, None, input_depth], dtype=tf.float32),
        tf.TensorSpec([None], dtype=tf.int32),
    ))
    def _beam_decode_from(memory, memory_sequence_length):
        batch_size_tensor = tf.shape(memory)[0]
        if has_attention:
            tiled_memory = beam_search_decoder.tile_batch(
                memory, multiplier=beam_width)
            tiled_memory_sequence_length = beam_search_decoder.tile_batch(
                memory_sequence_length, multiplier=beam_width)
            attention_mechanism.setup_memory(
                tiled_memory,
                memory_sequence_length=tiled_memory_sequence_length)
        cell_state = cell.get_initial_state(
            batch_size=batch_size_tensor * beam_width, dtype=tf.float32)
        return bsd(
            embedding,
            start_tokens=tf.fill([batch_size_tensor], start_token),
            end_token=end_token,
            initial_state=cell_state,
        )

    memory = tf.random.normal([batch_size, decoder_max_time, input_depth])
    memory_sequence_length = tf.constant(encoder_sequence_length,
                                         dtype=tf.int32)
    final_outputs, final_state, final_sequence_lengths = _beam_decode_from(
        memory, memory_sequence_length)

    def _t(shape):
        if time_major:
            return (shape[1], shape[0]) + shape[2:]
        return shape

    assert isinstance(final_outputs,
                      beam_search_decoder.FinalBeamSearchDecoderOutput)
    assert isinstance(final_state, beam_search_decoder.BeamSearchDecoderState)

    beam_search_decoder_output = final_outputs.beam_search_decoder_output
    max_sequence_length = np.max(final_sequence_lengths.numpy())
    assert _t((batch_size, max_sequence_length, beam_width)) == tuple(
        beam_search_decoder_output.scores.shape.as_list())
    assert _t((batch_size, max_sequence_length, beam_width)) == tuple(
        final_outputs.predicted_ids.shape.as_list())
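
# The parameterized test above builds BahdanauAttention without a memory and
# attaches one later, inside the tf.function, via setup_memory(). A minimal
# standalone sketch of that deferred-memory pattern follows; the shapes and
# lengths here are illustrative assumptions, not taken from the test.
def _example_deferred_memory_setup():
    mechanism = attention_wrapper.BahdanauAttention(units=6)
    memory = tf.random.normal([5, 4, 7])  # [batch, max_time, depth]
    lengths = tf.constant([3, 2, 3, 1, 1], dtype=tf.int32)
    mechanism.setup_memory(memory, memory_sequence_length=lengths)
    return mechanism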
def test_dynamic_decode_rnn(cell_class, time_major, has_attention,
                            with_alignment_history):
    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
    batch_size = 5
    decoder_max_time = 4
    input_depth = 7
    cell_depth = 9
    attention_depth = 6
    vocab_size = 20
    end_token = vocab_size - 1
    start_token = 0
    embedding_dim = 50
    max_out = max(decoder_sequence_length)
    output_layer = tf.keras.layers.Dense(vocab_size, use_bias=True,
                                         activation=None)
    beam_width = 3

    batch_size_tensor = tf.constant(batch_size)
    embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
    cell = cell_class(cell_depth)
    initial_state = cell.get_initial_state(batch_size=batch_size,
                                           dtype=tf.float32)
    coverage_penalty_weight = 0.0
    if has_attention:
        coverage_penalty_weight = 0.2
        inputs = tf.compat.v1.placeholder_with_default(
            np.random.randn(batch_size, decoder_max_time,
                            input_depth).astype(np.float32),
            shape=(None, None, input_depth),
        )
        tiled_inputs = beam_search_decoder.tile_batch(inputs,
                                                      multiplier=beam_width)
        tiled_sequence_length = beam_search_decoder.tile_batch(
            encoder_sequence_length, multiplier=beam_width)
        attention_mechanism = attention_wrapper.BahdanauAttention(
            units=attention_depth,
            memory=tiled_inputs,
            memory_sequence_length=tiled_sequence_length,
        )
        initial_state = beam_search_decoder.tile_batch(initial_state,
                                                       multiplier=beam_width)
        cell = attention_wrapper.AttentionWrapper(
            cell=cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=attention_depth,
            alignment_history=with_alignment_history,
        )
    cell_state = cell.get_initial_state(
        batch_size=batch_size_tensor * beam_width, dtype=tf.float32)
    if has_attention:
        cell_state = cell_state.clone(cell_state=initial_state)
    bsd = beam_search_decoder.BeamSearchDecoder(
        cell=cell,
        beam_width=beam_width,
        output_layer=output_layer,
        length_penalty_weight=0.0,
        coverage_penalty_weight=coverage_penalty_weight,
        output_time_major=time_major,
        maximum_iterations=max_out,
    )
    final_outputs, final_state, final_sequence_lengths = bsd(
        embedding,
        start_tokens=tf.fill([batch_size_tensor], start_token),
        end_token=end_token,
        initial_state=cell_state,
    )

    def _t(shape):
        if time_major:
            return (shape[1], shape[0]) + shape[2:]
        return shape

    assert isinstance(final_outputs,
                      beam_search_decoder.FinalBeamSearchDecoderOutput)
    assert isinstance(final_state, beam_search_decoder.BeamSearchDecoderState)

    beam_search_decoder_output = final_outputs.beam_search_decoder_output
    expected_seq_length = 3 if tf.executing_eagerly() else None
    assert _t((batch_size, expected_seq_length, beam_width)) == tuple(
        beam_search_decoder_output.scores.shape.as_list())
    assert _t((batch_size, expected_seq_length, beam_width)) == tuple(
        final_outputs.predicted_ids.shape.as_list())

    eval_results = {
        "final_outputs": final_outputs,
        "final_sequence_lengths": final_sequence_lengths.numpy(),
    }

    max_sequence_length = np.max(eval_results["final_sequence_lengths"])

    # A smoke test
    assert _t((batch_size, max_sequence_length, beam_width)) == eval_results[
        "final_outputs"].beam_search_decoder_output.scores.shape
    assert _t((batch_size, max_sequence_length, beam_width)) == eval_results[
        "final_outputs"].beam_search_decoder_output.predicted_ids.shape
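
# A follow-up usage sketch: reading the best hypothesis out of a batch-major
# decode like the one above. This assumes beams in FinalBeamSearchDecoderOutput
# are ordered best-first, so beam index 0 is the highest-scoring hypothesis;
# treat that ordering as an assumption to check against the library docs.
def _example_top_beam(final_outputs):
    # predicted_ids: [batch_size, max_time, beam_width] when time_major=False.
    return final_outputs.predicted_ids[:, :, 0]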