def test_step_with_eos(self):
    """Beam-search step where beam 1 has already emitted EOS.

    Checks that the finished beam is carried forward unchanged while the
    live beams extend with token 2 and accumulate its log-probability.
    """
    # Beam 1 is already finished (length 1); beams 0 and 2 are live at length 2.
    beam_state = beam_search.BeamSearchState(
        log_probs=tf.nn.log_softmax(tf.ones(self.config.beam_width)),
        lengths=tf.convert_to_tensor([2, 1, 2], dtype=tf.int32),
        finished=tf.constant([False, True, False], dtype=tf.bool))

    # Make token 2 clearly dominant in every beam; beam 0 slightly more so.
    raw_logits = np.full(
        [self.config.beam_width, self.config.vocab_size], 0.0001)
    raw_logits[0, 2] = 1.1
    raw_logits[1, 2] = 1.0
    raw_logits[2, 2] = 1.0
    logits = tf.convert_to_tensor(raw_logits, dtype=tf.float32)
    log_probs = tf.nn.log_softmax(logits)

    outputs, next_beam_state = beam_search.beam_search_step(
        time_=2, logits=logits, beam_state=beam_state, config=self.config)

    with self.test_session() as sess:
        outputs_, next_state_, state_, log_probs_ = sess.run(
            [outputs, next_beam_state, beam_state, log_probs])

    np.testing.assert_array_equal(outputs_.predicted_ids, [0, 2, 2])
    np.testing.assert_array_equal(outputs_.beam_parent_ids, [1, 0, 2])
    np.testing.assert_array_equal(next_state_.lengths, [1, 3, 3])
    np.testing.assert_array_equal(next_state_.finished, [True, False, False])

    # Finished beam keeps its score; live beams add the chosen token's
    # log-probability from their respective parent rows.
    expected_log_probs = state_.log_probs[outputs_.beam_parent_ids]
    expected_log_probs[1] += log_probs_[0, 2]
    expected_log_probs[2] += log_probs_[2, 2]
    np.testing.assert_array_equal(next_state_.log_probs, expected_log_probs)
def test_step_with_eos(self):
    """Beam-search step where token 2 (EOS) dominates every beam's logits.

    Verifies the predicted-id history is reshuffled according to the
    surviving beams and that the top beam's log-probs are preserved.
    """
    beam_state = beam_search.BeamState(
        time=tf.constant(2),
        log_probs=tf.nn.log_softmax(tf.ones(self.config.beam_width)),
        scores=tf.nn.log_softmax(tf.ones(self.config.beam_width)),
        predicted_ids=tf.convert_to_tensor(
            [[1, 2, -1, -1, -1], [3, 0, -1, -1, -1], [5, 6, -1, -1, -1]]),
        beam_parent_ids=tf.zeros(self.config.beam_width))

    # Logit 1.0 at token 2 for each beam, everything else near zero.
    logits = tf.sparse_to_dense(
        [[0, 2], [1, 2], [2, 2]],
        output_shape=[self.config.beam_width, self.config.vocab_size],
        sparse_values=[1.0, 1.0, 1.0],
        default_value=0.0001)

    next_beam_state = beam_search.beam_search_step(
        logits=logits, beam_state=beam_state, config=self.config)

    with self.test_session() as sess:
        # Fetch the stepped state and the pre-step log-probs together.
        next_state_, previous_log_probs = sess.run(
            [next_beam_state, beam_state.log_probs])

    want_ids = np.array(
        [[3, 0, 0, -1, -1], [1, 2, 2, -1, -1], [5, 6, 2, -1, -1]])
    np.testing.assert_array_equal(next_state_.predicted_ids, want_ids)
    np.testing.assert_array_equal(
        next_state_.log_probs[0], previous_log_probs[0])
def step(self, time, inputs, state, name=None):
    """Run one beam-search decoding step of the convolutional decoder.

    Returns ``(outputs, next_state, next_inputs, finished)`` in the usual
    dynamic-decode step contract.
    """
    enc_output, beam_state = state
    feature_dim = inputs.get_shape().as_list()[-1]

    # Tokens decoded so far, plus the zero padding that follows the slot
    # we are about to fill (position time+1).
    decoded_prefix = inputs[:, 0:time + 1, :]
    trailing_pad = inputs[:, time + 2:, :]

    positioned = self.add_position_embedding(decoded_prefix, time)
    logits = self.infer_conv_block(enc_output, positioned)

    bs_output, beam_state = beam_search.beam_search_step(
        time_=time, logits=logits, beam_state=beam_state,
        config=self.config)

    finished, next_inputs = self.next_inputs(
        sample_ids=bs_output.predicted_ids)
    next_inputs = tf.reshape(
        next_inputs, [self.config.beam_width, 1, feature_dim])
    # Splice the freshly predicted embedding into the fixed-length buffer.
    next_inputs = tf.concat(
        [decoded_prefix, next_inputs, trailing_pad], axis=1)
    next_inputs.set_shape([
        self.config.beam_width, self.params['max_decode_length'],
        feature_dim
    ])

    outputs = BeamDecoderOutput(
        logits=tf.zeros([self.config.beam_width, self.config.vocab_size]),
        predicted_ids=bs_output.predicted_ids,
        log_probs=beam_state.log_probs,
        scores=bs_output.scores,
        beam_parent_ids=bs_output.beam_parent_ids)
    return outputs, (enc_output, beam_state), next_inputs, finished
def step(self, time, inputs, state, name=None):
    """Advance the convolutional beam-search decoder by one time step."""
    # Split the running input buffer: consumed prefix / untouched padding.
    prefix = inputs[:, :time + 1, :]
    padding = inputs[:, time + 2:, :]

    enc_output, beam_state = state
    logits = self.infer_conv_block(
        enc_output, self.add_position_embedding(prefix, time))

    bs_output, beam_state = beam_search.beam_search_step(
        time_=time, logits=logits, beam_state=beam_state,
        config=self.config)

    finished, sampled = self.next_inputs(
        sample_ids=bs_output.predicted_ids)
    embed_dim = inputs.get_shape().as_list()[-1]
    sampled = tf.reshape(sampled, [self.config.beam_width, 1, embed_dim])

    # Rebuild the fixed-length input: prefix, then the new embedding,
    # then the remaining zero padding.
    next_inputs = tf.concat([prefix, sampled], axis=1)
    next_inputs = tf.concat([next_inputs, padding], axis=1)
    next_inputs.set_shape([
        self.config.beam_width, self.params['max_decode_length'], embed_dim
    ])

    outputs = BeamDecoderOutput(
        logits=tf.zeros([self.config.beam_width, self.config.vocab_size]),
        predicted_ids=bs_output.predicted_ids,
        log_probs=beam_state.log_probs,
        scores=bs_output.scores,
        beam_parent_ids=bs_output.beam_parent_ids)
    return outputs, (enc_output, beam_state), next_inputs, finished
def step(self, time_, inputs, state, name=None):
    """Step the wrapped decoder, then run beam search on its logits.

    Decoder state and outputs are re-gathered along ``beam_parent_ids``
    so every tensor follows its surviving beam.
    """
    decoder_state, beam_state = state

    # Run one step of the underlying decoder.
    decoder_output, decoder_state, _, _ = self.decoder.step(
        time_, inputs, decoder_state)

    # Advance the beam search with the freshly produced logits.
    bs_output, beam_state = beam_search.beam_search_step(
        time_=time_,
        logits=decoder_output.logits,
        beam_state=beam_state,
        config=self.config)

    # Shuffle state/outputs so index i corresponds to surviving beam i.
    gather_beams = lambda x: tf.gather(x, bs_output.beam_parent_ids)
    if isinstance(decoder_state, AttentionWrapperState):
        # AttentionWrapperState must be rebuilt via clone(); only its
        # cell_state is per-beam.
        decoder_state = decoder_state.clone(
            cell_state=nest.map_structure(
                gather_beams, decoder_state.cell_state))
    else:
        decoder_state = nest.map_structure(gather_beams, decoder_state)
    decoder_output = nest.map_structure(gather_beams, decoder_output)

    next_state = (decoder_state, beam_state)

    outputs = BeamDecoderOutput(
        logits=tf.zeros([self.config.beam_width, self.config.vocab_size]),
        predicted_ids=bs_output.predicted_ids,
        log_probs=beam_state.log_probs,
        scores=bs_output.scores,
        beam_parent_ids=bs_output.beam_parent_ids,
        original_outputs=decoder_output)

    finished, next_inputs, next_state = self.decoder.helper.next_inputs(
        time=time_,
        outputs=decoder_output,
        state=next_state,
        sample_ids=bs_output.predicted_ids)
    next_inputs.set_shape([self.batch_size, None])
    return (outputs, next_state, next_inputs, finished)
def step(self, time_, cell_output, cell_state, loop_state):
    """raw_rnn-style loop function: one decoder step followed by beam search.

    On the very first call (``cell_output is None``) no real step has run
    yet: this branch only zero-fills the cell output, tiles the cell state
    across all beams, and builds the initial beam state. Every later call
    steps the wrapped decoder and then advances the beam search on its
    logits.

    Returns a ``DecoderStepOutput`` with the step outputs, the (beam-
    shuffled) next cell state, and the wrapped loop state.
    """
    initial_call = (cell_output is None)
    if initial_call:
        cell_output = tf.zeros(
            [self.config.beam_width, self.cell.output_size])
        # We start out with all beams being equal, so we tile the cell
        # state [beam_width] times.
        next_cell_state = beam_search.nest_map(
            cell_state, lambda x: tf.tile(x, [self.config.beam_width, 1]))
        # Call the original decoder.
        original_outputs = self.decoder.step(time_, None, cell_state,
                                             loop_state)
        # Create an initial beam state.
        beam_state = beam_search.create_initial_beam_state(
            config=self.config, max_time=self.decoder.max_decode_length)
        # Carry the beam state alongside the decoder's own loop state.
        next_loop_state = self._wrap_loop_state(
            beam_state, original_outputs.next_loop_state)
        # Dummy outputs whose shapes match the real per-step outputs.
        outputs = self.output_shapes()
    else:
        prev_beam_state, original_loop_state = self._unwrap_loop_state(
            loop_state)
        # Call the original decoder.
        original_outputs = self.decoder.step(time_, cell_output, cell_state,
                                             original_loop_state)
        # Perform a step of beam search.
        beam_state = beam_search.beam_search_step(
            logits=original_outputs.outputs.logits,
            beam_state=prev_beam_state,
            config=self.config)
        beam_state.predicted_ids.set_shape(
            [None, self.decoder.max_decode_length])
        next_loop_state = self._wrap_loop_state(
            beam_state, original_outputs.next_loop_state)
        outputs = BeamDecoderOutput(
            logits=tf.zeros(
                [self.config.beam_width, self.config.vocab_size]),
            # NOTE(review): indexing at time_ - 1 suggests time_ is one
            # past the step that produced these ids — confirm against the
            # raw_rnn loop-fn calling convention.
            predicted_ids=tf.to_int64(
                beam_state.predicted_ids[:, time_ - 1]),
            log_probs=beam_state.log_probs,
            scores=beam_state.scores,
            beam_parent_ids=beam_state.beam_parent_ids,
            original_outputs=original_outputs.outputs)
        # Cell states are shuffled around by beam search.
        next_cell_state = beam_search.nest_map(
            original_outputs.next_cell_state,
            lambda x: tf.gather(x, beam_state.beam_parent_ids))
    # The final step output.
    step_output = DecoderStepOutput(outputs=outputs,
                                    next_cell_state=next_cell_state,
                                    next_loop_state=next_loop_state)
    return step_output