def test_step(self):
    """Run one `_beam_search_step` on unfinished beams and check the result.

    Every beam starts unfinished with length 2 and the same log probability.
    A handful of logits are raised above the 0.0001 baseline so that the
    selected (parent beam, token) pairs are predictable.
    """
    state_shape = [self.batch_size, self.beam_width]
    dummy_cell_state = array_ops.zeros(state_shape)
    initial_log_probs = nn_ops.log_softmax(array_ops.ones(state_shape))
    initial_lengths = constant_op.constant(
        2, shape=state_shape, dtype=dtypes.int32)
    initial_finished = array_ops.zeros(state_shape, dtype=dtypes.bool)
    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=dummy_cell_state,
        log_probs=initial_log_probs,
        lengths=initial_lengths,
        finished=initial_finished)

    # Logits indexed as [batch, beam, token]; everything else stays at the
    # near-zero baseline.
    logits_ = np.full(
        [self.batch_size, self.beam_width, self.vocab_size], 0.0001)
    boosted_logits = {
        (0, 0, 2): 1.9,
        (0, 0, 3): 2.1,
        (0, 1, 3): 3.1,
        (0, 1, 4): 0.9,
        (1, 0, 1): 0.5,
        (1, 1, 2): 2.7,
        (1, 2, 2): 10.0,
        (1, 2, 3): 0.2,
    }
    for index, value in boosted_logits.items():
        logits_[index] = value
    logits = ops.convert_to_tensor(logits_, dtype=dtypes.float32)
    log_probs = nn_ops.log_softmax(logits)

    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        beam_state=beam_state,
        batch_size=ops.convert_to_tensor(self.batch_size),
        beam_width=self.beam_width,
        end_token=self.end_token,
        length_penalty_weight=self.length_penalty_weight)

    with self.test_session() as sess:
        outputs_, next_state_, state_, log_probs_ = sess.run(
            [outputs, next_beam_state, beam_state, log_probs])

    np.testing.assert_array_equal(
        outputs_.predicted_ids, [[3, 3, 2], [2, 2, 1]])
    np.testing.assert_array_equal(
        outputs_.parent_ids, [[1, 0, 0], [2, 1, 0]])
    np.testing.assert_array_equal(
        next_state_.lengths, [[3, 3, 3], [3, 3, 3]])
    np.testing.assert_array_equal(
        next_state_.finished,
        [[False, False, False], [False, False, False]])

    # (parent beam, token) expected in each output slot, per batch entry.
    # The expected score is the parent's previous log prob plus the log prob
    # of the chosen token.
    expected_picks = [
        [(1, 3), (0, 3), (0, 2)],
        [(2, 2), (1, 2), (0, 1)],
    ]
    expected_log_probs = []
    for batch, picks in enumerate(expected_picks):
        parents = [parent for parent, _ in picks]
        # Indexing with a list yields a copy, so the fetched state is not
        # modified by the in-place additions below.
        scores = state_.log_probs[batch][parents]
        for slot, (parent, token) in enumerate(picks):
            scores[slot] += log_probs_[batch, parent, token]
        expected_log_probs.append(scores)

    np.testing.assert_array_equal(next_state_.log_probs, expected_log_probs)
def step(self, time, inputs, state, name=None):
    """Perform a decoding step, adding an encoder context vector to the output.

    Args:
      time: value forwarded to `_beam_search_step` as `time`.
      inputs: A (structure of) input tensors.
      state: beam-search state; `state.cell_state` is fed to the cell.
      name: Name scope for any created operations.

    Returns:
      `(beam_search_output, beam_search_state, next_inputs, finished)`.
    """
    with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)):
        flat_inputs = nest.map_structure(
            lambda tensor: self._merge_batch_beams(tensor, s=tensor.shape[2:]),
            inputs)
        flat_state = nest.map_structure(
            self._maybe_merge_batch_beams, state.cell_state,
            self._cell.state_size)
        flat_outputs, flat_next_state = self._cell(flat_inputs, flat_state)
        cell_outputs = nest.map_structure(
            lambda tensor: self._split_batch_beams(tensor, tensor.shape[1:]),
            flat_outputs)
        next_cell_state = nest.map_structure(
            self._maybe_split_batch_beams, flat_next_state,
            self._cell.state_size)

        # Context vector: the new cell state acts as the key, and the encoder
        # outputs act as both query and value.  Shape notes below are the
        # original author's (b=batch, t=time, c=channels) -- TODO confirm.
        keys = next_cell_state
        encoder_outputs = self.encoder_ouputs
        scores = tf.matmul(
            encoder_outputs,
            tf.transpose(keys, [0, 2, 1]))  # bxtxc bxcxbeam => bxtxbeam
        attens = tf.nn.softmax(scores, axis=1)  # bxtxbeam
        weights = tf.expand_dims(attens, 3)  # bxtxbeamx1
        values = tf.expand_dims(encoder_outputs, 2)  # bxtx1xc
        weighted_values = weights * values  # bxtxbeamxc
        context_vec = tf.reduce_sum(weighted_values, axis=1)  # bxbeamxc

        # The context vector is appended before the optional output layer.
        cell_outputs = array_ops.concat([cell_outputs, context_vec], -1)
        if self._output_layer is not None:
            cell_outputs = self._output_layer(cell_outputs)

        beam_search_output, beam_search_state = _beam_search_step(
            time=time,
            logits=cell_outputs,
            next_cell_state=next_cell_state,
            beam_state=state,
            batch_size=self._batch_size,
            beam_width=self._beam_width,
            end_token=self._end_token,
            length_penalty_weight=self._length_penalty_weight)

        finished = beam_search_state.finished
        sample_ids = beam_search_output.predicted_ids
        embedded_inputs = control_flow_ops.cond(
            math_ops.reduce_all(finished),
            lambda: self._start_inputs,
            lambda: self._embedding_fn(sample_ids))
        # `self.z` is appended to whichever branch was taken.
        next_inputs = array_ops.concat([embedded_inputs, self.z], -1)

    return (beam_search_output, beam_search_state, next_inputs, finished)
def test_step(self):
    """Check `_beam_search_step` output when every beam is still unfinished.

    The starting state has length 2 for every beam, identical log
    probabilities and an empty `accumulated_attention_probs`.
    """
    beams_shape = [self.batch_size, self.beam_width]
    dummy_cell_state = array_ops.zeros(beams_shape)
    start_log_probs = nn_ops.log_softmax(array_ops.ones(beams_shape))
    start_lengths = constant_op.constant(
        2, shape=beams_shape, dtype=dtypes.int64)
    start_finished = array_ops.zeros(beams_shape, dtype=dtypes.bool)
    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=dummy_cell_state,
        log_probs=start_log_probs,
        lengths=start_lengths,
        finished=start_finished,
        accumulated_attention_probs=())

    # [batch, beam, token] -> logit; all remaining entries keep the baseline.
    logits_ = np.full(
        [self.batch_size, self.beam_width, self.vocab_size], 0.0001)
    for index, value in (
            ((0, 0, 2), 1.9),
            ((0, 0, 3), 2.1),
            ((0, 1, 3), 3.1),
            ((0, 1, 4), 0.9),
            ((1, 0, 1), 0.5),
            ((1, 1, 2), 2.7),
            ((1, 2, 2), 10.0),
            ((1, 2, 3), 0.2)):
        logits_[index] = value
    logits = ops.convert_to_tensor(logits_, dtype=dtypes.float32)
    log_probs = nn_ops.log_softmax(logits)

    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        next_cell_state=dummy_cell_state,
        beam_state=beam_state,
        batch_size=ops.convert_to_tensor(self.batch_size),
        beam_width=self.beam_width,
        end_token=self.end_token,
        length_penalty_weight=self.length_penalty_weight,
        coverage_penalty_weight=self.coverage_penalty_weight)

    with self.cached_session() as sess:
        outputs_, next_state_, state_, log_probs_ = sess.run(
            [outputs, next_beam_state, beam_state, log_probs])

    self.assertAllEqual(outputs_.predicted_ids, [[3, 3, 2], [2, 2, 1]])
    self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [2, 1, 0]])
    self.assertAllEqual(next_state_.lengths, [[3, 3, 3], [3, 3, 3]])
    self.assertAllEqual(
        next_state_.finished,
        [[False, False, False], [False, False, False]])

    # Each output slot should hold its parent's previous log prob plus the
    # log prob of the token picked from that parent.
    parent_ids = [[1, 0, 0], [2, 1, 0]]
    token_ids = [[3, 3, 2], [2, 2, 1]]
    expected_log_probs = []
    for batch, (parents, tokens) in enumerate(zip(parent_ids, token_ids)):
        row = state_.log_probs[batch][parents]  # list indexing copies
        for slot, (parent, token) in enumerate(zip(parents, tokens)):
            row[slot] += log_probs_[batch, parent, token]
        expected_log_probs.append(row)

    self.assertAllEqual(next_state_.log_probs, expected_log_probs)
def test_step_with_eos(self):
    """Check `_beam_search_step` when some input beams are already finished.

    Beam 1 of batch entry 0 and beam 2 of batch entry 1 start out finished.
    The test expects them to be carried forward with their length and log
    probability unchanged.
    """
    beams_shape = [self.batch_size, self.beam_width]
    dummy_cell_state = array_ops.zeros(beams_shape)
    start_log_probs = nn_ops.log_softmax(array_ops.ones(beams_shape))
    start_lengths = ops.convert_to_tensor(
        [[2, 1, 2], [2, 2, 1]], dtype=dtypes.int64)
    start_finished = ops.convert_to_tensor(
        [[False, True, False], [False, False, True]], dtype=dtypes.bool)
    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=dummy_cell_state,
        log_probs=start_log_probs,
        lengths=start_lengths,
        finished=start_finished,
        accumulated_attention_probs=())

    logits_ = np.full(
        [self.batch_size, self.beam_width, self.vocab_size], 0.0001)
    logit_overrides = [
        ((0, 0, 2), 1.9),
        ((0, 0, 3), 2.1),
        ((0, 1, 3), 3.1),
        ((0, 1, 4), 0.9),
        ((1, 0, 1), 0.5),
        # Open question from the original author: why does this not work
        # when it's 2.7?
        ((1, 1, 2), 5.7),
        ((1, 2, 2), 1.0),
        ((1, 2, 3), 0.2),
    ]
    for index, value in logit_overrides:
        logits_[index] = value
    logits = ops.convert_to_tensor(logits_, dtype=dtypes.float32)
    log_probs = nn_ops.log_softmax(logits)

    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        next_cell_state=dummy_cell_state,
        beam_state=beam_state,
        batch_size=ops.convert_to_tensor(self.batch_size),
        beam_width=self.beam_width,
        end_token=self.end_token,
        length_penalty_weight=self.length_penalty_weight,
        coverage_penalty_weight=self.coverage_penalty_weight)

    with self.cached_session() as sess:
        outputs_, next_state_, state_, log_probs_ = sess.run(
            [outputs, next_beam_state, beam_state, log_probs])

    self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [1, 2, 0]])
    self.assertAllEqual(outputs_.predicted_ids, [[0, 3, 2], [2, 0, 1]])
    self.assertAllEqual(next_state_.lengths, [[1, 3, 3], [3, 1, 3]])
    self.assertAllEqual(
        next_state_.finished,
        [[True, False, False], [False, True, False]])

    expected_log_probs = [
        state_.log_probs[0][[1, 0, 0]],
        state_.log_probs[1][[1, 2, 0]],
    ]
    # Only slots continued from an unfinished parent gain a token log prob;
    # slots (0, 0) and (1, 1) come from finished parents and stay as they are.
    # Keys are (batch, slot); values index log_probs_ as (batch, parent, token).
    token_gains = {
        (0, 1): (0, 0, 3),
        (0, 2): (0, 0, 2),
        (1, 0): (1, 1, 2),
        (1, 2): (1, 0, 1),
    }
    for (batch, slot), source in token_gains.items():
        expected_log_probs[batch][slot] += log_probs_[source]

    self.assertAllEqual(next_state_.log_probs, expected_log_probs)
def step(self, time, inputs, state, name=None):
    """Perform a decoding step.

    Args:
      time: scalar `int32` tensor.
      inputs: A (structure of) input tensors.
      state: A (structure of) state tensors and TensorArrays.
      name: Name scope for any created operations.

    Returns:
      `(outputs, next_state, next_inputs, finished)`.
    """
    with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)):
        # Merge inputs and cell state with the `_merge_batch_beams` helpers,
        # run the cell, then split its results back per beam.
        flat_inputs = nest.map_structure(
            lambda tensor: self._merge_batch_beams(tensor, s=tensor.shape[2:]),
            inputs)
        flat_state = nest.map_structure(
            self._maybe_merge_batch_beams, state.cell_state,
            self._cell.state_size)
        flat_outputs, flat_next_state = self._cell(flat_inputs, flat_state)
        cell_outputs = nest.map_structure(
            lambda tensor: self._split_batch_beams(tensor, tensor.shape[1:]),
            flat_outputs)
        next_cell_state = nest.map_structure(
            self._maybe_split_batch_beams, flat_next_state,
            self._cell.state_size)

        if self._output_layer is not None:
            cell_outputs = self._output_layer(cell_outputs)

        beam_search_output, beam_search_state = _beam_search_step(
            time=time,
            logits=cell_outputs,
            next_cell_state=next_cell_state,
            beam_state=state,
            batch_size=self._batch_size,
            beam_width=self._beam_width,
            end_token=self._end_token,
            length_penalty_weight=self._length_penalty_weight)

        finished = beam_search_state.finished
        sample_ids = beam_search_output.predicted_ids

        def _embed_with_cnn_inputs():
            # Embedded sample ids joined with `self.cnn_inputs` along axis 2.
            return tf.concat(
                [self._embedding_fn(sample_ids), self.cnn_inputs], 2)

        next_inputs = control_flow_ops.cond(
            math_ops.reduce_all(finished),
            lambda: self._start_inputs,
            _embed_with_cnn_inputs)

    return (beam_search_output, beam_search_state, next_inputs, finished)
def test_step_with_eos(self):
    """Check `_beam_search_step` when some input beams are already finished.

    Beam 1 of batch entry 0 and beam 2 of batch entry 1 start out finished.
    The test expects them to be carried forward with their length and log
    probability unchanged.
    """
    state_shape = [self.batch_size, self.beam_width]
    dummy_cell_state = array_ops.zeros(state_shape)
    initial_log_probs = nn_ops.log_softmax(array_ops.ones(state_shape))
    initial_lengths = ops.convert_to_tensor(
        [[2, 1, 2], [2, 2, 1]], dtype=dtypes.int32)
    initial_finished = ops.convert_to_tensor(
        [[False, True, False], [False, False, True]], dtype=dtypes.bool)
    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=dummy_cell_state,
        log_probs=initial_log_probs,
        lengths=initial_lengths,
        finished=initial_finished)

    logits_ = np.full(
        [self.batch_size, self.beam_width, self.vocab_size], 0.0001)
    logit_overrides = [
        ((0, 0, 2), 1.9),
        ((0, 0, 3), 2.1),
        ((0, 1, 3), 3.1),
        ((0, 1, 4), 0.9),
        ((1, 0, 1), 0.5),
        # Open question from the original author: why does this not work
        # when it's 2.7?
        ((1, 1, 2), 5.7),
        ((1, 2, 2), 1.0),
        ((1, 2, 3), 0.2),
    ]
    for index, value in logit_overrides:
        logits_[index] = value
    logits = ops.convert_to_tensor(logits_, dtype=dtypes.float32)
    log_probs = nn_ops.log_softmax(logits)

    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        beam_state=beam_state,
        batch_size=ops.convert_to_tensor(self.batch_size),
        beam_width=self.beam_width,
        end_token=self.end_token,
        length_penalty_weight=self.length_penalty_weight)

    with self.test_session() as sess:
        outputs_, next_state_, state_, log_probs_ = sess.run(
            [outputs, next_beam_state, beam_state, log_probs])

    np.testing.assert_array_equal(
        outputs_.parent_ids, [[1, 0, 0], [1, 2, 0]])
    np.testing.assert_array_equal(
        outputs_.predicted_ids, [[0, 3, 2], [2, 0, 1]])
    np.testing.assert_array_equal(
        next_state_.lengths, [[1, 3, 3], [3, 1, 3]])
    np.testing.assert_array_equal(
        next_state_.finished,
        [[True, False, False], [False, True, False]])

    expected_log_probs = [
        state_.log_probs[0][[1, 0, 0]],
        state_.log_probs[1][[1, 2, 0]],
    ]
    # Only slots continued from an unfinished parent gain a token log prob.
    # Keys are (batch, slot); values index log_probs_ as (batch, parent, token).
    token_gains = {
        (0, 1): (0, 0, 3),
        (0, 2): (0, 0, 2),
        (1, 0): (1, 1, 2),
        (1, 2): (1, 0, 1),
    }
    for (batch, slot), source in token_gains.items():
        expected_log_probs[batch][slot] += log_probs_[source]

    np.testing.assert_array_equal(next_state_.log_probs, expected_log_probs)
def step(self, time, inputs, state, name=None):
    """Perform a decoding step and append `self.z` to the next inputs.

    Args:
      time: value forwarded to `_beam_search_step` as `time`.
      inputs: A (structure of) input tensors.
      state: beam-search state; `state.cell_state` is fed to the cell.
      name: Name scope for any created operations.

    Returns:
      `(beam_search_output, beam_search_state, next_inputs, finished)`.
    """
    with tf.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)):
        flat_inputs = nest.map_structure(
            lambda tensor: self._merge_batch_beams(tensor, s=tensor.shape[2:]),
            inputs)
        flat_state = nest.map_structure(
            self._maybe_merge_batch_beams, state.cell_state,
            self._cell.state_size)
        flat_outputs, flat_next_state = self._cell(flat_inputs, flat_state)
        cell_outputs = nest.map_structure(
            lambda tensor: self._split_batch_beams(tensor, tensor.shape[1:]),
            flat_outputs)
        next_cell_state = nest.map_structure(
            self._maybe_split_batch_beams, flat_next_state,
            self._cell.state_size)

        if self._output_layer is not None:
            cell_outputs = self._output_layer(cell_outputs)

        beam_search_output, beam_search_state = _beam_search_step(
            time=time,
            logits=cell_outputs,
            next_cell_state=next_cell_state,
            beam_state=state,
            batch_size=self._batch_size,
            beam_width=self._beam_width,
            end_token=self._end_token,
            length_penalty_weight=self._length_penalty_weight)

        finished = beam_search_state.finished
        sample_ids = beam_search_output.predicted_ids
        all_finished = tf.reduce_all(finished)
        embedded_inputs = tf.cond(
            all_finished,
            lambda: self._start_inputs,
            lambda: self._embedding_fn(sample_ids))
        # `self.z` is appended on the last axis regardless of the branch.
        next_inputs = tf.concat([embedded_inputs, self.z], -1)

    return (beam_search_output, beam_search_state, next_inputs, finished)
def test_step(self):
    """Run `_beam_search_step` from a state where only beam 0 is live.

    The starting state gives beam 0 of every batch entry a log probability
    of 0, a length of 2 and `finished=False`. Every other beam gets a log
    probability of -inf, a length of 0 and `finished=True`.
    """

    def get_probs():
        """this simulates the initialize method in BeamSearchDecoder."""
        log_prob_mask = array_ops.one_hot(
            array_ops.zeros([self.batch_size], dtype=dtypes.int32),
            depth=self.beam_width,
            on_value=True,
            off_value=False,
            dtype=dtypes.bool)

        log_prob_zeros = array_ops.zeros(
            [self.batch_size, self.beam_width], dtype=dtypes.float32)
        # `np.inf` is the canonical spelling; the `np.Inf` alias was removed
        # in NumPy 2.0.
        log_prob_neg_inf = array_ops.ones(
            [self.batch_size, self.beam_width], dtype=dtypes.float32) * -np.inf

        # 0 where the mask is set (beam 0), -inf everywhere else.
        log_probs = array_ops.where(
            log_prob_mask, log_prob_zeros, log_prob_neg_inf)
        return log_probs

    log_probs = get_probs()
    dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])

    # pylint: disable=invalid-name
    # Inverted one-hot: False for beam 0, True for every other beam.
    _finished = array_ops.one_hot(
        array_ops.zeros([self.batch_size], dtype=dtypes.int32),
        depth=self.beam_width,
        on_value=False,
        off_value=True,
        dtype=dtypes.bool)
    _lengths = np.zeros([self.batch_size, self.beam_width], dtype=np.int64)
    _lengths[:, 0] = 2
    _lengths = constant_op.constant(_lengths, dtype=dtypes.int64)

    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=dummy_cell_state,
        log_probs=log_probs,
        lengths=_lengths,
        finished=_finished)

    # Logits indexed as [batch, beam, token].
    logits_ = np.full(
        [self.batch_size, self.beam_width, self.vocab_size], 0.0001)
    logits_[0, 0, 2] = 1.9
    logits_[0, 0, 3] = 2.1
    logits_[0, 1, 3] = 3.1
    logits_[0, 1, 4] = 0.9
    logits_[1, 0, 1] = 0.5
    logits_[1, 1, 2] = 2.7
    logits_[1, 2, 2] = 10.0
    logits_[1, 2, 3] = 0.2
    logits = constant_op.constant(logits_, dtype=dtypes.float32)
    log_probs = nn_ops.log_softmax(logits)

    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        next_cell_state=dummy_cell_state,
        beam_state=beam_state,
        batch_size=ops.convert_to_tensor(self.batch_size),
        beam_width=self.beam_width,
        end_token=self.end_token,
        length_penalty_weight=self.length_penalty_weight)

    with self.test_session() as sess:
        outputs_, next_state_, _, _ = sess.run(
            [outputs, next_beam_state, beam_state, log_probs])

    self.assertEqual(outputs_.predicted_ids[0, 0], 3)
    self.assertEqual(outputs_.predicted_ids[0, 1], 2)
    self.assertEqual(outputs_.predicted_ids[1, 0], 1)
    neg_inf = -np.inf
    # The last three beams must remain dead (-inf score, zero length);
    # every beam before them must be live.
    self.assertAllEqual(
        next_state_.log_probs[:, -3:],
        [[neg_inf, neg_inf, neg_inf], [neg_inf, neg_inf, neg_inf]])
    self.assertEqual((next_state_.log_probs[:, :-3] > neg_inf).all(), True)
    self.assertEqual((next_state_.lengths[:, :-3] > 0).all(), True)
    self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0], [0, 0, 0]])
def step(self, time, inputs, state, name=None):
    """Perform a decoding step, with special handling for `JointDenseLayer`.

    When the output layer is a `taware_layer.JointDenseLayer`, it also
    receives the reshaped step inputs. If a context was stored by a previous
    call, it receives the first half of that context as well. The context is
    refreshed at the end of every call.

    Args:
      time: value forwarded to `_beam_search_step` as `time`.
      inputs: A (structure of) input tensors.
      state: beam-search state; `state.cell_state` is fed to the cell.
      name: Name scope for any created operations.

    Returns:
      `(beam_search_output, beam_search_state, next_inputs, finished)`.
    """
    beam_width = self._beam_width
    with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)):
        flat_inputs = nest.map_structure(
            lambda tensor: self._merge_batch_beams(tensor, s=tensor.shape[2:]),
            inputs)
        flat_state = nest.map_structure(
            self._maybe_merge_batch_beams, state.cell_state,
            self._cell.state_size)
        flat_outputs, flat_next_state = self._cell(flat_inputs, flat_state)
        cell_outputs = nest.map_structure(
            lambda tensor: self._split_batch_beams(tensor, tensor.shape[1:]),
            flat_outputs)
        next_cell_state = nest.map_structure(
            self._maybe_split_batch_beams, flat_next_state,
            self._cell.state_size)

        output_layer = self._output_layer
        if output_layer is not None:
            if not isinstance(output_layer, taware_layer.JointDenseLayer):
                cell_outputs = output_layer(cell_outputs)
            else:
                # The joint layer takes extra keyword inputs next to the
                # cell outputs.
                layer_kwargs = {
                    "input": tf.reshape(
                        flat_inputs,
                        [-1, beam_width, flat_inputs.shape[-1]]),
                }
                if self._current_context is not None:
                    # Only the first half (split along axis 1) of the stored
                    # context is passed on.
                    msg_attention, _ = tf.split(
                        self._current_context, num_or_size_splits=2, axis=1)
                    layer_kwargs["context"] = tf.reshape(
                        msg_attention,
                        [-1, beam_width, msg_attention.shape[-1]])
                cell_outputs = output_layer(cell_outputs, **layer_kwargs)

        beam_search_output, beam_search_state = _beam_search_step(
            time=time,
            logits=cell_outputs,
            next_cell_state=next_cell_state,
            beam_state=state,
            batch_size=self._batch_size,
            beam_width=beam_width,
            end_token=self._end_token,
            length_penalty_weight=self._length_penalty_weight,
            coverage_penalty_weight=0.0)

        finished = beam_search_state.finished
        sample_ids = beam_search_output.predicted_ids
        next_inputs = control_flow_ops.cond(
            math_ops.reduce_all(finished),
            lambda: self._start_inputs,
            lambda: self._embedding_fn(sample_ids))

        # Store the attention of the state that was fed to the cell in this
        # step, for use by the next call.
        self._current_context = flat_state.attention

    return (beam_search_output, beam_search_state, next_inputs, finished)
def test_step(self):
    """Run `_beam_search_step` from a state where only beam 0 is live.

    The starting state gives beam 0 of every batch entry a log probability
    of 0, a length of 2 and `finished=False`. Every other beam gets a log
    probability of -inf, a length of 0 and `finished=True`.
    `accumulated_attention_probs` is left empty.
    """

    def get_probs():
        """this simulates the initialize method in BeamSearchDecoder."""
        log_prob_mask = array_ops.one_hot(
            array_ops.zeros([self.batch_size], dtype=dtypes.int32),
            depth=self.beam_width,
            on_value=True,
            off_value=False,
            dtype=dtypes.bool)

        log_prob_zeros = array_ops.zeros(
            [self.batch_size, self.beam_width], dtype=dtypes.float32)
        # `np.inf` is the canonical spelling; the `np.Inf` alias was removed
        # in NumPy 2.0.
        log_prob_neg_inf = array_ops.ones(
            [self.batch_size, self.beam_width], dtype=dtypes.float32) * -np.inf

        # 0 where the mask is set (beam 0), -inf everywhere else.
        log_probs = array_ops.where(
            log_prob_mask, log_prob_zeros, log_prob_neg_inf)
        return log_probs

    log_probs = get_probs()
    dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])

    # pylint: disable=invalid-name
    # Inverted one-hot: False for beam 0, True for every other beam.
    _finished = array_ops.one_hot(
        array_ops.zeros([self.batch_size], dtype=dtypes.int32),
        depth=self.beam_width,
        on_value=False,
        off_value=True,
        dtype=dtypes.bool)
    _lengths = np.zeros([self.batch_size, self.beam_width], dtype=np.int64)
    _lengths[:, 0] = 2
    _lengths = constant_op.constant(_lengths, dtype=dtypes.int64)

    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=dummy_cell_state,
        log_probs=log_probs,
        lengths=_lengths,
        finished=_finished,
        accumulated_attention_probs=())

    # Logits indexed as [batch, beam, token].
    logits_ = np.full(
        [self.batch_size, self.beam_width, self.vocab_size], 0.0001)
    logits_[0, 0, 2] = 1.9
    logits_[0, 0, 3] = 2.1
    logits_[0, 1, 3] = 3.1
    logits_[0, 1, 4] = 0.9
    logits_[1, 0, 1] = 0.5
    logits_[1, 1, 2] = 2.7
    logits_[1, 2, 2] = 10.0
    logits_[1, 2, 3] = 0.2
    logits = constant_op.constant(logits_, dtype=dtypes.float32)
    log_probs = nn_ops.log_softmax(logits)

    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        next_cell_state=dummy_cell_state,
        beam_state=beam_state,
        batch_size=ops.convert_to_tensor(self.batch_size),
        beam_width=self.beam_width,
        end_token=self.end_token,
        length_penalty_weight=self.length_penalty_weight,
        coverage_penalty_weight=self.coverage_penalty_weight)

    with self.cached_session() as sess:
        outputs_, next_state_, _, _ = sess.run(
            [outputs, next_beam_state, beam_state, log_probs])

    self.assertEqual(outputs_.predicted_ids[0, 0], 3)
    self.assertEqual(outputs_.predicted_ids[0, 1], 2)
    self.assertEqual(outputs_.predicted_ids[1, 0], 1)
    neg_inf = -np.inf
    # The last three beams must remain dead (-inf score, zero length);
    # every beam before them must be live.
    self.assertAllEqual(
        next_state_.log_probs[:, -3:],
        [[neg_inf, neg_inf, neg_inf], [neg_inf, neg_inf, neg_inf]])
    self.assertEqual((next_state_.log_probs[:, :-3] > neg_inf).all(), True)
    self.assertEqual((next_state_.lengths[:, :-3] > 0).all(), True)
    self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0], [0, 0, 0]])
def step(self, time, inputs, state, name=None):
    """Perform a decoding step.

    Besides the usual cell / output-layer / beam-search sequence, this
    variant looks up the choices that follow each beam's `last_choice` in
    `self.lookup_table`. For beams that are about to finish, it adds
    `self._step_method(last_choice)` to the cell outputs. After the beam
    search, the selected ids are stored as the new `last_choice`.

    Args:
      time: scalar `int32` tensor.
      inputs: A (structure of) input tensors.
      state: A (structure of) state tensors and TensorArrays.
      name: Name scope for any created operations.

    Returns:
      `(outputs, next_state, next_inputs, finished)`.
    """
    batch_size = self._batch_size
    beam_width = self._beam_width
    end_token = self._end_token
    length_penalty_weight = self._length_penalty_weight

    with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)):
        cell_state = state.cell_state
        inputs = nest.map_structure(
            lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs)
        cell_state = nest.map_structure(self._maybe_merge_batch_beams, cell_state, self._cell.state_size)
        cell_outputs, next_cell_state = self._cell(inputs, cell_state)

        # Original author's note: cell_state.last_choice shape =
        # [batch_size * beam_width] -- TODO confirm.
        # Column 0 of the table row for the last choice gives the next
        # choice; one more lookup gives the choice after that.
        next_choices = gen_array_ops.gather_v2(self.lookup_table, cell_state.last_choice, axis=0)
        not_finished = tf.not_equal(next_choices[:, 0], end_token)
        next_next_choices = gen_array_ops.gather_v2(self.lookup_table, next_choices[:, 0], axis=0)
        # True where the next choice is not the end token but the one after
        # it is.
        will_finish = tf.logical_and(
            not_finished, tf.equal(next_next_choices[:, 0], end_token))

        def move(will_finish, last_choice, cell_outputs):
            # Adds the `_step_method` score to the cell outputs, but only
            # where `will_finish` is set; other rows are left unchanged.
            attention_score = self._step_method(last_choice)
            attention_score = attention_score + cell_outputs
            return tf.where(will_finish, attention_score, cell_outputs)

        if self._output_layer is not None:
            cell_outputs = self._output_layer(cell_outputs)

        # NOTE(review): assumed to run whether or not an output layer is
        # set -- confirm against the original indentation.
        # The adjustment branch is only taken when at least one beam is
        # about to finish.
        cell_outputs = tf.cond(
            tf.reduce_any(will_finish), false_fn=lambda: cell_outputs, true_fn=lambda: move(will_finish, cell_state.last_choice, cell_outputs))

        if self.hie:
            cell_outputs = self._mask_outputs_by_lable(
                cell_outputs, cell_state.last_choice)

        # Unlike the other `step` variants, the split back per beam happens
        # only here, after the output layer and the adjustments above.
        # Original author's note: cell_state.last_choice shape =
        # [batch_size*beam_width,] -- TODO confirm.
        cell_outputs = nest.map_structure(
            lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs)
        next_cell_state = nest.map_structure(self._maybe_split_batch_beams, next_cell_state, self._cell.state_size)

        beam_search_output, beam_search_state = _beam_search_step(
            time=time, logits=cell_outputs, next_cell_state=next_cell_state, beam_state=state, batch_size=batch_size, beam_width=beam_width, end_token=end_token, length_penalty_weight=length_penalty_weight)

        finished = beam_search_state.finished

        # Replace the parent ids: the ids selected in this step become
        # `last_choice` in the cell state carried to the next step.
        sample_ids = beam_search_output.predicted_ids
        next_cell_state = beam_search_state.cell_state
        next_cell_state = next_cell_state._replace(last_choice=sample_ids)
        beam_search_state = beam_search_state._replace(
            cell_state=next_cell_state)

        # Original author's note: sample_ids shape =
        # [batch_size, beam_width] -- TODO confirm.
        next_inputs = control_flow_ops.cond(
            math_ops.reduce_all(finished), lambda: self._start_inputs, lambda: self._embedding_fn(sample_ids))

    return (beam_search_output, beam_search_state, next_inputs, finished)
def step(self, time, inputs, state, name=None):
    """Perform a decoding step.

    In addition to the usual cell / output-layer / beam-search sequence,
    this variant lowers the logits of every token id listed in
    `self._skip_tokens_decoding` before the beam search runs. When
    `self._shrink_vocab` is positive, all ids from `self._shrink_vocab` up
    are added to that list.

    Args:
      time: scalar `int32` tensor.
      inputs: A (structure of) input tensors.
      state: A (structure of) state tensors and TensorArrays.
      name: Name scope for any created operations.

    Returns:
      `(outputs, next_state, next_inputs, finished)`.
    """
    batch_size = self._batch_size
    beam_width = self._beam_width
    end_token = self._end_token
    length_penalty_weight = self._length_penalty_weight
    coverage_penalty_weight = self._coverage_penalty_weight

    with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)):
        cell_state = state.cell_state
        inputs = nest.map_structure(
            lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs)
        cell_state = nest.map_structure(self._maybe_merge_batch_beams, cell_state, self._cell.state_size)
        cell_outputs, next_cell_state = self._cell(inputs, cell_state)
        cell_outputs = nest.map_structure(
            lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs)
        next_cell_state = nest.map_structure(
            self._maybe_split_batch_beams, next_cell_state, self._cell.state_size)

        if self._output_layer is not None:
            cell_outputs = self._output_layer(cell_outputs)

        # NOTE(review): the nesting of the statements below relative to the
        # guard above is assumed -- confirm against the original indentation.
        # NOTE(review): this mutates `self._skip_tokens_decoding` on every
        # call; the sort/dedupe keeps the list from growing.
        if self._shrink_vocab > 0:
            self._skip_tokens_decoding += list(
                range(self._shrink_vocab, cell_outputs.get_shape()[2]))
            self._skip_tokens_decoding = sorted(
                set(self._skip_tokens_decoding))

        # Never skip _END token, no matter what
        if self._raw_end_token in self._skip_tokens_decoding:
            self._skip_tokens_decoding.remove(self._raw_end_token)

        # Assign least possible logit for given list of tokens to avoid those tokens while decoding
        if len(self._skip_tokens_decoding) > 0:
            # NOTE(review): `token_num` is assigned but never read.
            token_num = cell_outputs.get_shape()[2]
            # One below the smallest logit in the whole tensor.
            minimum_activation = tf.reduce_min(cell_outputs) - 1
            # Vector over axis 2 of the logits: 0.0 at skipped ids and 1.0
            # elsewhere.
            blacklist = tf.sparse_to_dense(
                self._skip_tokens_decoding, output_shape=[cell_outputs.get_shape()[2]], sparse_values=0.0, default_value=1.0)
            # Kept ids retain their logit; skipped ids are replaced by
            # `minimum_activation`.
            cell_outputs = tf.add(tf.multiply(
                cell_outputs, blacklist), minimum_activation * (1 - blacklist))

        beam_search_output, beam_search_state = beam_search_decoder._beam_search_step(
            time=time, logits=cell_outputs, next_cell_state=next_cell_state, beam_state=state, batch_size=batch_size, beam_width=beam_width, end_token=end_token, length_penalty_weight=length_penalty_weight, coverage_penalty_weight=coverage_penalty_weight)

        finished = beam_search_state.finished
        sample_ids = beam_search_output.predicted_ids
        # Once every beam has finished, feed the start inputs instead of
        # embedding the sampled ids.
        next_inputs = control_flow_ops.cond(
            math_ops.reduce_all(finished), lambda: self._start_inputs, lambda: self._embedding_fn(sample_ids))

    return (beam_search_output, beam_search_state, next_inputs, finished)