def test_step(self):
    """Advance a beam in which no hypothesis is finished by one step.

    Every beam starts with the same log probability and length 2. The test
    checks which (parent beam, token) pairs survive, that every surviving
    hypothesis grows to length 3, and that each new score is the parent's
    score plus the chosen token's step log probability.
    """
    shape = [self.batch_size, self.beam_width]
    cell_state = tf.zeros(shape)
    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=cell_state,
        log_probs=tf.nn.log_softmax(tf.ones(shape)),
        lengths=tf.constant(2, shape=shape, dtype=tf.int64),
        finished=tf.zeros(shape, dtype=tf.bool),
        accumulated_attention_probs=(),
    )

    raw_logits = np.full(shape + [self.vocab_size], 0.0001)
    for position, value in (
            ((0, 0, 2), 1.9),
            ((0, 0, 3), 2.1),
            ((0, 1, 3), 3.1),
            ((0, 1, 4), 0.9),
            ((1, 0, 1), 0.5),
            ((1, 1, 2), 2.7),
            ((1, 2, 2), 10.0),
            ((1, 2, 3), 0.2),
    ):
        raw_logits[position] = value
    logits = tf.convert_to_tensor(raw_logits, dtype=tf.float32)
    step_log_probs = tf.nn.log_softmax(logits)

    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        next_cell_state=cell_state,
        beam_state=beam_state,
        batch_size=tf.convert_to_tensor(self.batch_size),
        beam_width=self.beam_width,
        end_token=self.end_token,
        length_penalty_weight=self.length_penalty_weight,
        coverage_penalty_weight=self.coverage_penalty_weight,
    )

    with self.cached_session() as sess:
        outputs_, next_state_, state_, step_log_probs_ = sess.run(
            [outputs, next_beam_state, beam_state, step_log_probs])

    predicted_ids = [[3, 3, 2], [2, 2, 1]]
    parent_ids = [[1, 0, 0], [2, 1, 0]]
    self.assertAllEqual(outputs_.predicted_ids, predicted_ids)
    self.assertAllEqual(outputs_.parent_ids, parent_ids)
    self.assertAllEqual(next_state_.lengths, [[3, 3, 3], [3, 3, 3]])
    self.assertAllEqual(next_state_.finished,
                        [[False, False, False], [False, False, False]])

    # Slot `slot` of batch `batch` continues beam `parent` with `token`, so its
    # expected score is the parent's previous score plus that token's step
    # log probability under the parent beam.
    expected_log_probs = [
        state_.log_probs[batch][parents]
        for batch, parents in enumerate(parent_ids)
    ]
    for batch, (parents, tokens) in enumerate(zip(parent_ids, predicted_ids)):
        for slot, (parent, token) in enumerate(zip(parents, tokens)):
            expected_log_probs[batch][slot] += step_log_probs_[batch, parent,
                                                               token]
    self.assertAllEqual(next_state_.log_probs, expected_log_probs)
def test_step_with_eos(self):
    """Advance a beam in which some hypotheses were finished before the step.

    Beam 1 of batch 0 and beam 2 of batch 1 start out finished. The test
    expects slots continuing a finished beam to keep that beam's length and
    score and to stay finished. Slots continuing a live beam grow by one token
    and add the token's step log probability.
    """
    dims = [self.batch_size, self.beam_width]
    was_finished = [[False, True, False], [False, False, True]]
    cell_state = array_ops.zeros(dims)
    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=cell_state,
        log_probs=nn_ops.log_softmax(array_ops.ones(dims)),
        lengths=ops.convert_to_tensor([[2, 1, 2], [2, 2, 1]],
                                      dtype=dtypes.int64),
        finished=ops.convert_to_tensor(was_finished, dtype=dtypes.bool),
        accumulated_attention_probs=())

    raw_logits = np.full(dims + [self.vocab_size], 0.0001)
    overrides = {
        (0, 0, 2): 1.9,
        (0, 0, 3): 2.1,
        (0, 1, 3): 3.1,
        (0, 1, 4): 0.9,
        (1, 0, 1): 0.5,
        # Must be large enough for beam 1 -> token 2 to outrank the finished
        # beam 2 of batch 1, whose score stays log(1/3) ~= -1.10 at length 1.
        # At 2.7 the extension has log prob ~= -1.34 at length 3. That scores
        # ~= -1.12 under a ((5 + len) / 6) ** w length penalty with w = 0.6.
        # NOTE(review): assumes self.length_penalty_weight == 0.6 and that
        # formula -- confirm. The two slots would then swap.
        (1, 1, 2): 5.7,
        (1, 2, 2): 1.0,
        (1, 2, 3): 0.2,
    }
    for position, value in overrides.items():
        raw_logits[position] = value
    logits = ops.convert_to_tensor(raw_logits, dtype=dtypes.float32)
    step_log_probs = nn_ops.log_softmax(logits)

    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        next_cell_state=cell_state,
        beam_state=beam_state,
        batch_size=ops.convert_to_tensor(self.batch_size),
        beam_width=self.beam_width,
        end_token=self.end_token,
        length_penalty_weight=self.length_penalty_weight,
        coverage_penalty_weight=self.coverage_penalty_weight)

    with self.cached_session() as sess:
        outputs_, next_state_, state_, step_log_probs_ = sess.run(
            [outputs, next_beam_state, beam_state, step_log_probs])

    parent_ids = [[1, 0, 0], [1, 2, 0]]
    predicted_ids = [[0, 3, 2], [2, 0, 1]]
    self.assertAllEqual(outputs_.parent_ids, parent_ids)
    self.assertAllEqual(outputs_.predicted_ids, predicted_ids)
    self.assertAllEqual(next_state_.lengths, [[1, 3, 3], [3, 1, 3]])
    self.assertAllEqual(next_state_.finished,
                        [[True, False, False], [False, True, False]])

    expected_log_probs = []
    for batch, (parents, tokens) in enumerate(zip(parent_ids, predicted_ids)):
        row = state_.log_probs[batch][parents]
        for slot, (parent, token) in enumerate(zip(parents, tokens)):
            # A slot that continues an already-finished beam keeps its score.
            if not was_finished[batch][parent]:
                row[slot] += step_log_probs_[batch, parent, token]
        expected_log_probs.append(row)
    self.assertAllEqual(next_state_.log_probs, expected_log_probs)
def step(
    self,
    time: tf.Tensor,
    inputs: list[tf.Tensor],
    state: "BeamSearchDecoderState",
    training: Optional[bool] = None,
    name: Optional[str] = None,
) -> list:
    """Perform one beam-search decoding step.

    The step proceeds as follows:

    - Run the wrapped cell on ``inputs`` and ``state.cell_state``.
    - Pass every cell output through ``self._split_batch_beams``.
    - Apply the optional output layer.
    - Let ``_beam_search_step`` choose the next beams.

    Args:
      time: scalar `int32` tensor.
      inputs: A (structure of) input tensors fed to the wrapped cell.
      state: The current beam search state. Its ``cell_state`` is fed to
        the wrapped cell, and the whole state is passed to
        ``_beam_search_step`` as ``beam_state``.
      training: Python boolean forwarded unchanged to the wrapped cell
        (training vs. inference mode).
      name: Name scope for any created operations.

    Returns:
      A 4-element list ``[outputs, next_state, next_inputs, finished]``:

      - ``outputs`` and ``next_state`` are the two values returned by
        ``_beam_search_step``.
      - ``next_inputs`` is produced by ``self._next_inputs`` from the current
        ``inputs`` and the predicted ids passed through
        ``self._merge_batch_beams``.
      - ``finished`` is ``next_state.finished``.

      A list is returned; unpacking it works the same as for a tuple.
    """
    with tf.name_scope(name or "BeamSearchDecoderStep"):
        cell_state = state.cell_state
        cell_outputs, next_cell_state = self._cell(
            inputs, cell_state, training=training)
        # Everything after the leading dimension is kept as the per-item shape
        # handed to _split_batch_beams.
        cell_outputs = tf.nest.map_structure(
            lambda out: self._split_batch_beams(out, out.shape[1:]),
            cell_outputs)
        if self._output_layer is not None:
            cell_outputs = self._output_layer(cell_outputs)

        beam_search_output, beam_search_state = _beam_search_step(
            time=time,
            logits=cell_outputs,
            next_cell_state=next_cell_state,
            beam_state=state,
            batch_size=self._batch_size,
            beam_width=self._beam_width,
            end_token=self._end_token,
            length_penalty_weight=self._length_penalty_weight,
            coverage_penalty_weight=self._coverage_penalty_weight,
            output_all_scores=self._output_all_scores,
        )

        finished = beam_search_state.finished
        sample_ids = beam_search_output.predicted_ids
        # Dimensions from index 2 onward are preserved when merging the
        # predicted ids before they are turned into the next inputs.
        next_inputs = self._next_inputs(
            inputs, self._merge_batch_beams(sample_ids, s=sample_ids.shape[2:]))

        return [beam_search_output, beam_search_state, next_inputs, finished]
def test_step(self):
    """Step a wide beam in which only beam 0 is still live.

    ``_finished`` marks every beam except beam 0 as finished, and
    ``get_probs`` gives those beams a log probability of -inf. Only beam 0
    can therefore produce finite candidates, at most ``self.vocab_size`` of
    them. The trailing assertions presume
    ``self.beam_width - self.vocab_size == 3`` (e.g. width 8, vocabulary 5).
    NOTE(review): confirm against the test fixture.
    """

    def get_probs():
        """Simulate BeamSearchDecoder's initialize: beam 0 -> 0, others -> -inf."""
        log_prob_mask = tf.one_hot(
            tf.zeros([self.batch_size], dtype=tf.int32),
            depth=self.beam_width,
            on_value=True,
            off_value=False,
            dtype=tf.bool)
        log_prob_zeros = tf.zeros([self.batch_size, self.beam_width],
                                  dtype=tf.float32)
        # np.inf, not np.Inf: the capitalised alias was removed in NumPy 2.0.
        log_prob_neg_inf = tf.ones([self.batch_size, self.beam_width],
                                   dtype=tf.float32) * -np.inf
        return tf.where(log_prob_mask, log_prob_zeros, log_prob_neg_inf)

    log_probs = get_probs()
    dummy_cell_state = tf.zeros([self.batch_size, self.beam_width])

    # pylint: disable=invalid-name
    # Beam 0 is the only unfinished beam; all others start out finished.
    _finished = tf.one_hot(
        tf.zeros([self.batch_size], dtype=tf.int32),
        depth=self.beam_width,
        on_value=False,
        off_value=True,
        dtype=tf.bool)
    _lengths = np.zeros([self.batch_size, self.beam_width], dtype=np.int64)
    _lengths[:, 0] = 2
    _lengths = tf.constant(_lengths, dtype=tf.int64)

    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=dummy_cell_state,
        log_probs=log_probs,
        lengths=_lengths,
        finished=_finished,
        accumulated_attention_probs=())

    logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
                      0.0001)
    logits_[0, 0, 2] = 1.9
    logits_[0, 0, 3] = 2.1
    logits_[0, 1, 3] = 3.1
    logits_[0, 1, 4] = 0.9
    logits_[1, 0, 1] = 0.5
    logits_[1, 1, 2] = 2.7
    logits_[1, 2, 2] = 10.0
    logits_[1, 2, 3] = 0.2
    logits = tf.constant(logits_, dtype=tf.float32)

    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        next_cell_state=dummy_cell_state,
        beam_state=beam_state,
        batch_size=tf.convert_to_tensor(self.batch_size),
        beam_width=self.beam_width,
        end_token=self.end_token,
        length_penalty_weight=self.length_penalty_weight,
        coverage_penalty_weight=self.coverage_penalty_weight)

    with self.cached_session() as sess:
        outputs_, next_state_ = sess.run([outputs, next_beam_state])

    # Beam 0 is the only live parent, so its best tokens come first.
    self.assertEqual(outputs_.predicted_ids[0, 0], 3)
    self.assertEqual(outputs_.predicted_ids[0, 1], 2)
    self.assertEqual(outputs_.predicted_ids[1, 0], 1)
    neg_inf = -np.inf
    # Slots with no finite candidate keep -inf scores and zero lengths.
    self.assertAllEqual(
        next_state_.log_probs[:, -3:],
        [[neg_inf, neg_inf, neg_inf], [neg_inf, neg_inf, neg_inf]])
    self.assertTrue((next_state_.log_probs[:, :-3] > neg_inf).all())
    self.assertTrue((next_state_.lengths[:, :-3] > 0).all())
    self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0], [0, 0, 0]])
def test_large_beam_step():
    """Step a beam of width 8 in which only beam 0 is still live.

    Every beam except beam 0 starts finished with log probability -inf. Only
    beam 0's ``vocab_size`` (= 5) extensions can therefore receive a finite
    score. The test expects the last ``beam_width - vocab_size`` (= 3) slots
    of the next state to stay at -inf with length 0.
    """
    batch_size = 2
    beam_width = 8
    vocab_size = 5
    end_token = 0
    length_penalty_weight = 0.6
    coverage_penalty_weight = 0.0
    # np.inf, not np.Inf: the capitalised alias was removed in NumPy 2.0.
    neg_inf = -np.inf
    # Number of trailing beam slots that cannot receive a finite candidate.
    num_dead = beam_width - vocab_size

    def get_probs():
        """Simulate BeamSearchDecoder's initialize: beam 0 -> 0, others -> -inf."""
        log_prob_mask = tf.one_hot(
            tf.zeros([batch_size], dtype=tf.int32),
            depth=beam_width,
            on_value=True,
            off_value=False,
            dtype=tf.bool,
        )
        log_prob_zeros = tf.zeros([batch_size, beam_width], dtype=tf.float32)
        log_prob_neg_inf = tf.ones([batch_size, beam_width],
                                   dtype=tf.float32) * neg_inf
        return tf.where(log_prob_mask, log_prob_zeros, log_prob_neg_inf)

    log_probs = get_probs()
    dummy_cell_state = tf.zeros([batch_size, beam_width])

    # Beam 0 is the only unfinished beam; all others start out finished.
    _finished = tf.one_hot(
        tf.zeros([batch_size], dtype=tf.int32),
        depth=beam_width,
        on_value=False,
        off_value=True,
        dtype=tf.bool,
    )
    _lengths = np.zeros([batch_size, beam_width], dtype=np.int64)
    _lengths[:, 0] = 2
    _lengths = tf.constant(_lengths, dtype=tf.int64)

    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=dummy_cell_state,
        log_probs=log_probs,
        lengths=_lengths,
        finished=_finished,
        accumulated_attention_probs=(),
    )

    logits_ = np.full([batch_size, beam_width, vocab_size], 0.0001)
    logits_[0, 0, 2] = 1.9
    logits_[0, 0, 3] = 2.1
    logits_[0, 1, 3] = 3.1
    logits_[0, 1, 4] = 0.9
    logits_[1, 0, 1] = 0.5
    logits_[1, 1, 2] = 2.7
    logits_[1, 2, 2] = 10.0
    logits_[1, 2, 3] = 0.2
    logits = tf.constant(logits_, dtype=tf.float32)

    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        next_cell_state=dummy_cell_state,
        beam_state=beam_state,
        batch_size=tf.convert_to_tensor(batch_size),
        beam_width=beam_width,
        end_token=end_token,
        length_penalty_weight=length_penalty_weight,
        coverage_penalty_weight=coverage_penalty_weight,
    )

    predicted_ids = outputs.predicted_ids.numpy()
    next_log_probs = next_beam_state.log_probs.numpy()
    next_lengths = next_beam_state.lengths.numpy()

    # Beam 0 is the only live parent, so its best tokens come first.
    assert predicted_ids[0, 0] == 3
    assert predicted_ids[0, 1] == 2
    assert predicted_ids[1, 0] == 1

    np.testing.assert_equal(next_log_probs[:, -num_dead:],
                            np.full([batch_size, num_dead], neg_inf))
    assert np.all(next_log_probs[:, :-num_dead] > neg_inf)
    assert np.all(next_lengths[:, :-num_dead] > 0)
    np.testing.assert_equal(next_lengths[:, -num_dead:],
                            np.zeros([batch_size, num_dead], dtype=np.int64))
def test_step_with_eos():
    """Check one beam step when some beams have already emitted ``end_token``.

    Beam 1 of batch 0 and beam 2 of batch 1 start out finished. The test
    expects a slot continuing a finished beam to predict ``end_token``, keep
    the parent's length and log probability, and stay finished. Slots
    continuing a live beam grow by one token and add its step log probability.
    """
    batch_size = 2
    beam_width = 3
    vocab_size = 5
    end_token = 0
    length_penalty_weight = 0.6
    coverage_penalty_weight = 0.0

    dummy_cell_state = tf.zeros([batch_size, beam_width])
    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=dummy_cell_state,
        log_probs=tf.nn.log_softmax(tf.ones([batch_size, beam_width])),
        lengths=tf.convert_to_tensor([[2, 1, 2], [2, 2, 1]], dtype=tf.int64),
        finished=tf.convert_to_tensor(
            [[False, True, False], [False, False, True]], dtype=tf.bool),
        accumulated_attention_probs=(),
    )

    logits_ = np.full([batch_size, beam_width, vocab_size], 0.0001)
    logits_[0, 0, 2] = 1.9
    logits_[0, 0, 3] = 2.1
    logits_[0, 1, 3] = 3.1
    logits_[0, 1, 4] = 0.9
    logits_[1, 0, 1] = 0.5
    # Must be large enough for beam 1 -> token 2 to outrank the finished beam 2
    # of batch 1, which keeps the score log(1/3) ~= -1.10 at length 1.
    # With 2.7 the extension has log prob log(1/3) - 0.24 ~= -1.34 at length 3.
    # That scores ~= -1.12 under the ((5 + length) / 6) ** 0.6 length penalty.
    # NOTE(review): assumes this is the penalty _beam_search_step applies --
    # verify. The two would then swap and parent_ids[1] would become [2, 1, 0].
    logits_[1, 1, 2] = 5.7
    logits_[1, 2, 2] = 1.0
    logits_[1, 2, 3] = 0.2
    logits = tf.convert_to_tensor(logits_, dtype=tf.float32)
    log_probs = tf.nn.log_softmax(logits)

    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        next_cell_state=dummy_cell_state,
        beam_state=beam_state,
        batch_size=tf.convert_to_tensor(batch_size),
        beam_width=beam_width,
        end_token=end_token,
        length_penalty_weight=length_penalty_weight,
        coverage_penalty_weight=coverage_penalty_weight,
    )

    parent_ids = np.asanyarray([[1, 0, 0], [1, 2, 0]])
    np.testing.assert_equal(outputs.parent_ids.numpy(), parent_ids)
    np.testing.assert_equal(outputs.predicted_ids.numpy(),
                            np.asanyarray([[0, 3, 2], [2, 0, 1]]))
    np.testing.assert_equal(next_beam_state.lengths.numpy(),
                            np.asanyarray([[1, 3, 3], [3, 1, 3]]))
    np.testing.assert_equal(
        next_beam_state.finished.numpy(),
        np.asanyarray([[True, False, False], [False, True, False]]),
    )

    prev_log_probs = beam_state.log_probs.numpy()
    step_log_probs = log_probs.numpy()
    # Each slot starts from its *parent* beam's previous score. Reorder by
    # parent_ids rather than relying on all initial scores being equal.
    expected_log_probs = np.take_along_axis(prev_log_probs, parent_ids, axis=1)
    # Live parents add the chosen token's step log prob. Slots [0, 0] and
    # [1, 1] continue finished beams and keep the parent's score unchanged.
    expected_log_probs[0, 1] += step_log_probs[0, 0, 3]
    expected_log_probs[0, 2] += step_log_probs[0, 0, 2]
    expected_log_probs[1, 0] += step_log_probs[1, 1, 2]
    expected_log_probs[1, 2] += step_log_probs[1, 0, 1]
    np.testing.assert_equal(next_beam_state.log_probs.numpy(),
                            expected_log_probs)