def test_step(self):
  """Checks one beam search step when no beam has finished yet.

  Every beam starts unfinished with length 2 and the same log probability.
  A handful of logits are raised above the 0.0001 baseline, and the test
  asserts which (parent beam, token) pairs are selected, that every length
  becomes 3, that nothing is marked finished, and that each new log
  probability equals the parent's log probability plus the log-softmax of
  the selected token.
  """
  beam_shape = [self.batch_size, self.beam_width]
  dummy_cell_state = array_ops.zeros(beam_shape)
  initial_log_probs = nn_ops.log_softmax(array_ops.ones(beam_shape))
  initial_lengths = constant_op.constant(
      2, shape=beam_shape, dtype=dtypes.int32)
  initial_finished = array_ops.zeros(beam_shape, dtype=dtypes.bool)
  beam_state = beam_search_decoder.BeamSearchDecoderState(
      cell_state=dummy_cell_state,
      log_probs=initial_log_probs,
      lengths=initial_lengths,
      finished=initial_finished)

  # (batch, beam, token) -> logit; every other entry keeps the baseline.
  raised_logits = (
      ((0, 0, 2), 1.9),
      ((0, 0, 3), 2.1),
      ((0, 1, 3), 3.1),
      ((0, 1, 4), 0.9),
      ((1, 0, 1), 0.5),
      ((1, 1, 2), 2.7),
      ((1, 2, 2), 10.0),
      ((1, 2, 3), 0.2),
  )
  logits_ = np.full(
      [self.batch_size, self.beam_width, self.vocab_size], 0.0001)
  for position, value in raised_logits:
    logits_[position] = value
  logits = ops.convert_to_tensor(logits_, dtype=dtypes.float32)
  log_probs = nn_ops.log_softmax(logits)

  outputs, next_beam_state = beam_search_decoder._beam_search_step(
      time=2,
      logits=logits,
      beam_state=beam_state,
      batch_size=ops.convert_to_tensor(self.batch_size),
      beam_width=self.beam_width,
      end_token=self.end_token,
      length_penalty_weight=self.length_penalty_weight)

  with self.test_session() as sess:
    outputs_, next_state_, state_, log_probs_ = sess.run(
        [outputs, next_beam_state, beam_state, log_probs])

  np.testing.assert_array_equal(
      outputs_.predicted_ids, [[3, 3, 2], [2, 2, 1]])
  np.testing.assert_array_equal(
      outputs_.parent_ids, [[1, 0, 0], [2, 1, 0]])
  np.testing.assert_array_equal(
      next_state_.lengths, [[3, 3, 3], [3, 3, 3]])
  np.testing.assert_array_equal(
      next_state_.finished,
      [[False, False, False], [False, False, False]])

  # Start from the parents' log probs (fancy indexing yields copies, so the
  # in-place additions below never touch `state_`).
  expected_parents = (
      [1, 0, 0],
      [2, 1, 0],  # 0 --> 1
  )
  expected_log_probs = [
      state_.log_probs[batch][parents]
      for batch, parents in enumerate(expected_parents)
  ]
  # (batch, output slot, (batch, parent beam, token)) for each selection.
  selections = (
      (0, 0, (0, 1, 3)),
      (0, 1, (0, 0, 3)),
      (0, 2, (0, 0, 2)),
      (1, 0, (1, 2, 2)),
      (1, 1, (1, 1, 2)),
      (1, 2, (1, 0, 1)),
  )
  for batch, slot, source in selections:
    expected_log_probs[batch][slot] += log_probs_[source]

  np.testing.assert_array_equal(next_state_.log_probs, expected_log_probs)
def test_step_with_eos(self):
  """Checks one beam search step when some beams have already finished.

  Batch entry 0 starts with beam 1 finished and batch entry 1 with beam 2
  finished. The test asserts the selected parents and tokens, that output
  slots descending from finished beams keep their length and stay finished,
  and that only the slots descending from unfinished beams have the step's
  log probability added to the parent's log probability.
  """
  beam_shape = [self.batch_size, self.beam_width]
  dummy_cell_state = array_ops.zeros(beam_shape)
  initial_log_probs = nn_ops.log_softmax(array_ops.ones(beam_shape))
  initial_lengths = ops.convert_to_tensor(
      [[2, 1, 2], [2, 2, 1]], dtype=dtypes.int64)
  initial_finished = ops.convert_to_tensor(
      [[False, True, False], [False, False, True]], dtype=dtypes.bool)
  beam_state = beam_search_decoder.BeamSearchDecoderState(
      cell_state=dummy_cell_state,
      log_probs=initial_log_probs,
      lengths=initial_lengths,
      finished=initial_finished,
      accumulated_attention_probs=())

  # (batch, beam, token) -> logit; every other entry keeps the baseline.
  raised_logits = (
      ((0, 0, 2), 1.9),
      ((0, 0, 3), 2.1),
      ((0, 1, 3), 3.1),
      ((0, 1, 4), 0.9),
      ((1, 0, 1), 0.5),
      ((1, 1, 2), 5.7),  # why does this not work when it's 2.7?
      ((1, 2, 2), 1.0),
      ((1, 2, 3), 0.2),
  )
  logits_ = np.full(
      [self.batch_size, self.beam_width, self.vocab_size], 0.0001)
  for position, value in raised_logits:
    logits_[position] = value
  logits = ops.convert_to_tensor(logits_, dtype=dtypes.float32)
  log_probs = nn_ops.log_softmax(logits)

  outputs, next_beam_state = beam_search_decoder._beam_search_step(
      time=2,
      logits=logits,
      next_cell_state=dummy_cell_state,
      beam_state=beam_state,
      batch_size=ops.convert_to_tensor(self.batch_size),
      beam_width=self.beam_width,
      end_token=self.end_token,
      length_penalty_weight=self.length_penalty_weight,
      coverage_penalty_weight=self.coverage_penalty_weight)

  with self.cached_session() as sess:
    outputs_, next_state_, state_, log_probs_ = sess.run(
        [outputs, next_beam_state, beam_state, log_probs])

  self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [1, 2, 0]])
  self.assertAllEqual(outputs_.predicted_ids, [[0, 3, 2], [2, 0, 1]])
  self.assertAllEqual(next_state_.lengths, [[1, 3, 3], [3, 1, 3]])
  self.assertAllEqual(
      next_state_.finished,
      [[True, False, False], [False, True, False]])

  # Start from the parents' log probs (fancy indexing yields copies, so the
  # in-place additions below never touch `state_`).
  expected_parents = ([1, 0, 0], [1, 2, 0])
  expected_log_probs = [
      state_.log_probs[batch][parents]
      for batch, parents in enumerate(expected_parents)
  ]
  # (batch, output slot, (batch, parent beam, token)). Slots [0][0] and
  # [1][1] are the ones asserted as finished above and are expected to keep
  # the parent's log prob unchanged, so they have no entry here.
  selections = (
      (0, 1, (0, 0, 3)),
      (0, 2, (0, 0, 2)),
      (1, 0, (1, 1, 2)),
      (1, 2, (1, 0, 1)),
  )
  for batch, slot, source in selections:
    expected_log_probs[batch][slot] += log_probs_[source]

  self.assertAllEqual(next_state_.log_probs, expected_log_probs)
def initialize(self, name=None):
  """Initialize the decoder.

  Builds the initial `BeamSearchDecoderState` from the decoder's stored
  attributes:

  * `log_probs` is `self._start_token_logits` when that is set. Otherwise it
    is a one-hot style tensor holding 0.0 for the first beam of every batch
    entry and -inf for the remaining beams.
  * `lengths` starts at zero. When `self._use_go_tokens` is false, `finished`
    is recomputed as `start_tokens == raw_end_token` and every unfinished
    beam starts with length 1 instead.
  * `accumulated_attention_probs` comes from
    `beam_search_decoder.get_attention_probs`; a `None` result is stored as
    an empty tuple.

  Args:
    name: Name scope for any created operations. Not used by this
      implementation.

  Returns:
    `(finished, start_inputs, initial_state)`.
  """
  finished, start_inputs = self._finished, self._start_inputs

  dtype = nest.flatten(self._initial_cell_state)[0].dtype
  if self._start_token_logits is None:
    # `np.inf` rather than `np.Inf`: the capitalised alias was removed in
    # NumPy 2.0, while the lowercase name exists in every NumPy release.
    log_probs = array_ops.one_hot(  # shape(batch_sz, beam_sz)
        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
        depth=self._beam_width,
        on_value=ops.convert_to_tensor(0.0, dtype=dtype),
        off_value=ops.convert_to_tensor(-np.inf, dtype=dtype),
        dtype=dtype)
  else:
    log_probs = self._start_token_logits

  sequence_lengths = array_ops.zeros(
      [self._batch_size, self._beam_width], dtype=dtypes.int64)

  # Start tokens are part of output if no _GO token used. Make changes
  # accordingly.
  if not self._use_go_tokens:
    finished = math_ops.equal(self._start_tokens, self._raw_end_token)
    # Unfinished beams already contain one token (the start token); beams
    # whose start token is the end token keep length 0.
    sequence_lengths = array_ops.where(
        math_ops.logical_not(finished),
        array_ops.ones_like(sequence_lengths),
        sequence_lengths)

  init_attention_probs = beam_search_decoder.get_attention_probs(
      self._initial_cell_state, self._coverage_penalty_weight)
  if init_attention_probs is None:
    init_attention_probs = ()

  initial_state = beam_search_decoder.BeamSearchDecoderState(
      cell_state=self._initial_cell_state,
      log_probs=log_probs,
      finished=finished,
      lengths=sequence_lengths,
      accumulated_attention_probs=init_attention_probs)

  return (finished, start_inputs, initial_state)
def test_step(self):
  """Checks a beam search step taken from an initialize-style beam state.

  Only beam 0 of every batch entry starts alive: it has log probability 0,
  length 2 and is unfinished, while all other beams have log probability
  -inf, length 0 and are marked finished. The test asserts the leading
  predicted ids, that the last three output slots still carry -inf log
  probability and zero length, and that every earlier slot has a finite log
  probability and a positive length.
  """

  def get_probs():
    """this simulates the initialize method in BeamSearchDecoder."""
    log_prob_mask = array_ops.one_hot(
        array_ops.zeros([self.batch_size], dtype=dtypes.int32),
        depth=self.beam_width,
        on_value=True,
        off_value=False,
        dtype=dtypes.bool)
    log_prob_zeros = array_ops.zeros(
        [self.batch_size, self.beam_width], dtype=dtypes.float32)
    # `np.inf` rather than `np.Inf`: the capitalised alias was removed in
    # NumPy 2.0.
    log_prob_neg_inf = array_ops.ones(
        [self.batch_size, self.beam_width], dtype=dtypes.float32) * -np.inf
    return array_ops.where(log_prob_mask, log_prob_zeros, log_prob_neg_inf)

  log_probs = get_probs()
  dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])

  # pylint: disable=invalid-name
  # Inverted one-hot: beam 0 is unfinished, every other beam is finished.
  _finished = array_ops.one_hot(
      array_ops.zeros([self.batch_size], dtype=dtypes.int32),
      depth=self.beam_width,
      on_value=False,
      off_value=True,
      dtype=dtypes.bool)
  _lengths = np.zeros([self.batch_size, self.beam_width], dtype=np.int64)
  _lengths[:, 0] = 2
  _lengths = constant_op.constant(_lengths, dtype=dtypes.int64)

  beam_state = beam_search_decoder.BeamSearchDecoderState(
      cell_state=dummy_cell_state,
      log_probs=log_probs,
      lengths=_lengths,
      finished=_finished)

  logits_ = np.full(
      [self.batch_size, self.beam_width, self.vocab_size], 0.0001)
  logits_[0, 0, 2] = 1.9
  logits_[0, 0, 3] = 2.1
  logits_[0, 1, 3] = 3.1
  logits_[0, 1, 4] = 0.9
  logits_[1, 0, 1] = 0.5
  logits_[1, 1, 2] = 2.7
  logits_[1, 2, 2] = 10.0
  logits_[1, 2, 3] = 0.2
  logits = constant_op.constant(logits_, dtype=dtypes.float32)

  outputs, next_beam_state = beam_search_decoder._beam_search_step(
      time=2,
      logits=logits,
      next_cell_state=dummy_cell_state,
      beam_state=beam_state,
      batch_size=ops.convert_to_tensor(self.batch_size),
      beam_width=self.beam_width,
      end_token=self.end_token,
      length_penalty_weight=self.length_penalty_weight)

  # Only the step outputs are asserted on, so only they are fetched.
  with self.test_session() as sess:
    outputs_, next_state_ = sess.run([outputs, next_beam_state])

  self.assertEqual(outputs_.predicted_ids[0, 0], 3)
  self.assertEqual(outputs_.predicted_ids[0, 1], 2)
  self.assertEqual(outputs_.predicted_ids[1, 0], 1)

  neg_inf = -np.inf
  self.assertAllEqual(
      next_state_.log_probs[:, -3:],
      [[neg_inf, neg_inf, neg_inf], [neg_inf, neg_inf, neg_inf]])
  self.assertTrue((next_state_.log_probs[:, :-3] > neg_inf).all())
  self.assertTrue((next_state_.lengths[:, :-3] > 0).all())
  self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0], [0, 0, 0]])
def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
                      beam_width, end_token, length_penalty_weight,
                      coverage_penalty_weight):
  """Performs a single step of Beam Search Decoding.

  The step scores every (beam, token) continuation, keeps the `beam_width`
  best per batch entry via `top_k` over the flattened scores, and gathers the
  log probs, finished flags, lengths, accumulated attention probs and cell
  state of the selected parent beams.

  Args:
    time: Beam search time step, should start at 0. At time 0 we assume that
      all beams are equal and consider only the first beam for continuations.
      NOTE(review): in this function `time` is only converted to a tensor and
      never read afterwards; no time-0 special-casing is visible here.
    logits: Logits at the current time step. A tensor of shape
      `[batch_size, beam_width, vocab_size]`
    next_cell_state: The next state from the cell, e.g. an instance of
      AttentionWrapperState if the cell is attentional.
    beam_state: Current state of the beam search.
      An instance of `BeamSearchDecoderState`.
    batch_size: The batch size for this input.
    beam_width: Python int.  The size of the beams.
    end_token: The int32 end token.
    length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
    coverage_penalty_weight: Float weight to penalize the coverage of source
      sentence. Disabled with 0.0.

  Returns:
    A tuple `(output, next_state)`: a `BeamSearchDecoderOutput` holding the
    selected scores, predicted token ids and parent beam ids, and the
    `BeamSearchDecoderState` for the next step.
  """
  # Static value of batch_size if it can be determined at graph-construction
  # time; used only for the set_shape calls below.
  static_batch_size = tensor_util.constant_value(batch_size)

  # Calculate the current lengths of the predictions
  prediction_lengths = beam_state.lengths
  previously_finished = beam_state.finished
  not_finished = math_ops.logical_not(previously_finished)

  # Calculate the total log probs for the new hypotheses
  # Final Shape: [batch_size, beam_width, vocab_size]
  step_log_probs = nn_ops.log_softmax(logits)
  step_log_probs = _mask_probs(step_log_probs, end_token, previously_finished)
  total_probs = array_ops.expand_dims(beam_state.log_probs, 2) + step_log_probs

  # Calculate the continuation lengths by adding to all continuing beams.
  # Use the static vocab size when known, otherwise fall back to the dynamic
  # shape.
  vocab_size = logits.shape[-1].value or array_ops.shape(logits)[-1]
  # 0 at the end_token position and 1 everywhere else, then zeroed for beams
  # that were already finished.
  lengths_to_add = array_ops.one_hot(
      indices=array_ops.fill([batch_size, beam_width], end_token),
      depth=vocab_size,
      on_value=np.int64(0),
      off_value=np.int64(1),
      dtype=dtypes.int64)
  add_mask = math_ops.to_int64(not_finished)
  lengths_to_add *= array_ops.expand_dims(add_mask, 2)
  new_prediction_lengths = (
      lengths_to_add + array_ops.expand_dims(prediction_lengths, 2))

  # Calculate the accumulated attention probabilities if coverage penalty is
  # enabled.
  accumulated_attention_probs = None
  attention_probs = get_attention_probs(
      next_cell_state, coverage_penalty_weight)
  if attention_probs is not None:
    # Finished beams contribute no further attention mass.
    attention_probs *= array_ops.expand_dims(
        math_ops.to_float(not_finished), 2)
    accumulated_attention_probs = (
        beam_state.accumulated_attention_probs + attention_probs)

  # Calculate the scores for each beam
  scores = _get_scores(
      log_probs=total_probs,
      sequence_lengths=new_prediction_lengths,
      length_penalty_weight=length_penalty_weight,
      coverage_penalty_weight=coverage_penalty_weight,
      finished=previously_finished,
      accumulated_attention_probs=accumulated_attention_probs)

  # NOTE(review): the resulting `time` tensor is not used below.
  time = ops.convert_to_tensor(time, name="time")
  # During the first time step we only consider the initial beam
  # NOTE(review): the comment above looks stale; the reshape below is applied
  # unconditionally and does not depend on `time`.
  scores_flat = array_ops.reshape(scores, [batch_size, -1])

  # Pick the next beams according to the specified successors function
  next_beam_size = ops.convert_to_tensor(
      beam_width, dtype=dtypes.int32, name="beam_width")
  next_beam_scores, word_indices = nn_ops.top_k(scores_flat, k=next_beam_size)

  next_beam_scores.set_shape([static_batch_size, beam_width])
  word_indices.set_shape([static_batch_size, beam_width])

  # Pick out the probs, beam_ids, and states according to the chosen predictions
  next_beam_probs = _tensor_gather_helper(
      gather_indices=word_indices,
      gather_from=total_probs,
      batch_size=batch_size,
      range_size=beam_width * vocab_size,
      gather_shape=[-1],
      name="next_beam_probs")
  # Note: just doing the following
  #   math_ops.to_int32(word_indices % vocab_size,
  #       name="next_beam_word_ids")
  # would be a lot cleaner but for reasons unclear, that hides the results of
  # the op which prevents capturing it with tfdbg debug ops.
  raw_next_word_ids = math_ops.mod(
      word_indices, vocab_size, name="next_beam_word_ids")
  next_word_ids = math_ops.to_int32(raw_next_word_ids)
  # word_indices presumably index the flattened beam_width * vocab_size axis,
  # so dividing by vocab_size recovers the parent beam.
  # NOTE(review): if `/` performs true division here, this relies on to_int32
  # truncating the float result - confirm.
  next_beam_ids = math_ops.to_int32(
      word_indices / vocab_size, name="next_beam_parent_ids")

  # Append new ids to current predictions
  # From here on `previously_finished` is indexed by output slot (it is
  # re-gathered by parent beam id), no longer by input beam.
  previously_finished = _tensor_gather_helper(
      gather_indices=next_beam_ids,
      gather_from=previously_finished,
      batch_size=batch_size,
      range_size=beam_width,
      gather_shape=[-1])
  next_finished = math_ops.logical_or(
      previously_finished,
      math_ops.equal(next_word_ids, end_token),
      name="next_beam_finished")

  # Calculate the length of the next predictions.
  # 1. Finished beams remain unchanged.
  # 2. Beams that are now finished (EOS predicted) have their length
  #    increased by 1.
  # 3. Beams that are not yet finished have their length increased by 1.
  lengths_to_add = math_ops.to_int64(math_ops.logical_not(previously_finished))
  next_prediction_len = _tensor_gather_helper(
      gather_indices=next_beam_ids,
      gather_from=beam_state.lengths,
      batch_size=batch_size,
      range_size=beam_width,
      gather_shape=[-1])
  next_prediction_len += lengths_to_add
  # Stays an empty tuple when no attention probabilities were accumulated.
  next_accumulated_attention_probs = ()
  if accumulated_attention_probs is not None:
    next_accumulated_attention_probs = _tensor_gather_helper(
        gather_indices=next_beam_ids,
        gather_from=accumulated_attention_probs,
        batch_size=batch_size,
        range_size=beam_width,
        gather_shape=[batch_size * beam_width, -1],
        name="next_accumulated_attention_probs")

  # Pick out the cell_states according to the next_beam_ids. We use a
  # different gather_shape here because the cell_state tensors, i.e.
  # the tensors that would be gathered from, all have dimension
  # greater than two and we need to preserve those dimensions.
  # pylint: disable=g-long-lambda
  next_cell_state = nest.map_structure(
      lambda gather_from: _maybe_tensor_gather_helper(
          gather_indices=next_beam_ids,
          gather_from=gather_from,
          batch_size=batch_size,
          range_size=beam_width,
          gather_shape=[batch_size * beam_width, -1]),
      next_cell_state)
  # pylint: enable=g-long-lambda

  next_state = beam_search_decoder.BeamSearchDecoderState(
      cell_state=next_cell_state,
      log_probs=next_beam_probs,
      lengths=next_prediction_len,
      finished=next_finished,
      accumulated_attention_probs=next_accumulated_attention_probs)

  output = beam_search_decoder.BeamSearchDecoderOutput(
      scores=next_beam_scores,
      predicted_ids=next_word_ids,
      parent_ids=next_beam_ids)

  return output, next_state