def finalize(self, outputs, final_state, sequence_lengths):
    """Finalize and return the predicted_ids.

    Backtracks the beam tree recorded in `outputs.parent_ids` so each beam
    carries its full token history.

    Args:
      outputs: An instance of BeamSearchDecoderOutput.
      final_state: An instance of BeamSearchDecoderState. Passed through to the
        output.
      sequence_lengths: An `int64` tensor shaped `[batch_size, beam_width]`.
        The sequence lengths determined for each beam during decode.
        **NOTE** These are ignored; the updated sequence lengths are stored in
        `final_state.lengths`.

    Returns:
      outputs: An instance of `FinalBeamSearchDecoderOutput` where the
        predicted_ids are the result of calling _gather_tree.
      final_state: The same input instance of `BeamSearchDecoderState`.
    """
    del sequence_lengths  # Superseded by final_state.lengths.

    # The longest beam of each batch entry bounds the backtracking.
    longest_per_batch = math_ops.to_int32(
        math_ops.reduce_max(final_state.lengths, axis=1))
    traced_ids = beam_search_ops.gather_tree(
        outputs.predicted_ids,
        outputs.parent_ids,
        max_sequence_lengths=longest_per_batch,
        end_token=self._end_token)
    final_outputs = FinalBeamSearchDecoderOutput(
        beam_search_decoder_output=outputs, predicted_ids=traced_ids)
    return final_outputs, final_state
def test_gather_tree(self):
    """gather_tree backtracks parent pointers for a (3, 2, 3) input."""
    # (max_time = 3, batch_size = 2, beam_width = 3)
    # Create (batch_size, max_time, beam_width) matrices and transpose them
    # to the time-major (max_time, batch_size, beam_width) layout.
    predicted_ids = np.array(
        [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
         [[2, 3, 4], [5, 6, 7], [8, 9, 10]]],
        dtype=np.int32).transpose([1, 0, 2])
    parent_ids = np.array(
        [[[0, 0, 0], [0, 1, 1], [2, 1, 2]],
         [[0, 0, 0], [1, 2, 0], [2, 1, 1]]],
        dtype=np.int32).transpose([1, 0, 2])

    # max_sequence_lengths is shaped (batch_size,) = (2,).
    max_sequence_lengths = [3, 3]

    expected_result = np.array(
        [[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
         [[2, 4, 4], [7, 6, 6], [8, 9, 10]]]).transpose([1, 0, 2])

    res = beam_search_ops.gather_tree(
        predicted_ids,
        parent_ids,
        max_sequence_lengths=max_sequence_lengths,
        end_token=11)

    with self.cached_session() as sess:
        res_ = sess.run(res)

    self.assertAllEqual(expected_result, res_)
def gather_tree_from_array(t, parent_ids, sequence_length):
    """Calculates the full beams for `TensorArray`s.

    Each beam of `t` is replaced by the slices along its ancestry in
    `parent_ids`; steps at or beyond a beam's length keep that beam's own
    slot.

    Args:
      t: A stacked `TensorArray` of size `max_time` that contains `Tensor`s of
        shape `[batch_size, beam_width, s]` or `[batch_size * beam_width, s]`
        where `s` is the depth shape.
      parent_ids: The parent ids of shape `[max_time, batch_size, beam_width]`.
      sequence_length: The sequence length of shape `[batch_size, beam_width]`.

    Returns:
      A `Tensor` which is a stacked `TensorArray` of the same size and type as
      `t` and where beams are sorted in each `Tensor` according to `parent_ids`.
    """
    def _dim(axis):
        # Prefer the static size; only build a shape op when it is unknown.
        return parent_ids.shape[axis].value or array_ops.shape(parent_ids)[axis]

    max_time = _dim(0)
    batch_size = _dim(1)
    beam_width = _dim(2)

    # Identity beam index k at every (time, batch) position; gather_tree
    # reorders these along the ancestry.
    identity_beams = array_ops.tile(
        array_ops.expand_dims(
            array_ops.expand_dims(math_ops.range(beam_width), 0), 0),
        [max_time, batch_size, 1])

    # 1 where a step lies inside its beam's length, as [time, batch, beam].
    in_range = array_ops.transpose(
        array_ops.sequence_mask(
            sequence_length, maxlen=max_time, dtype=dtypes.int32),
        perm=[2, 0, 1])

    # Out-of-range slots get the sentinel beam_width + 1, which also serves
    # as gather_tree's end token.
    end_marker = beam_width + 1
    masked_beams = (identity_beams * in_range) + (1 - in_range) * end_marker

    longest = math_ops.to_int32(math_ops.reduce_max(sequence_length, axis=1))
    traced_beams = beam_search_ops.gather_tree(
        step_ids=masked_beams,
        parent_ids=parent_ids,
        max_sequence_lengths=longest,
        end_token=end_marker)
    # Out-of-range steps simply keep their own beam.
    traced_beams = array_ops.where(
        math_ops.cast(in_range, dtypes.bool), x=traced_beams, y=identity_beams)

    # (time, batch, beam) coordinates for gather_nd.
    time_coord = array_ops.tile(
        array_ops.reshape(math_ops.range(max_time), [-1, 1, 1]),
        [1, batch_size, beam_width])
    batch_coord = array_ops.tile(
        array_ops.reshape(math_ops.range(batch_size), [1, -1, 1]),
        [max_time, 1, beam_width])
    coords = array_ops.stack([time_coord, batch_coord, traced_beams], -1)

    # Collapse trailing depth dims, gather, then restore t's original shape.
    original_shape = array_ops.shape(t)
    flat = array_ops.reshape(t, [max_time, batch_size, beam_width, -1])
    reordered = array_ops.gather_nd(flat, coords)
    return array_ops.reshape(reordered, original_shape)
def testGatherTreeOne(self):
    """Single-batch gather_tree where every beam has length 3."""
    # (max_time = 4, batch_size = 1, beams = 3)
    sequence_length = [[3, 3, 3]]
    step_ids = _transpose_batch_time(
        [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
    parent_ids = _transpose_batch_time(
        [[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]])
    want = _transpose_batch_time(
        [[[2, 2, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])

    gathered = beam_search_ops.gather_tree(
        step_ids=step_ids,
        parent_ids=parent_ids,
        sequence_length=sequence_length)
    with self.test_session(use_gpu=True):
        self.assertAllEqual(want, gathered.eval())
def testBadParentValuesOnCPU(self):
    """A negative parent id makes the CPU kernel raise."""
    # (batch_size = 1, max_time = 4, beams = 3)
    # The bad parent sits in beam 1 at time 1.
    error_pattern = r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"
    step_ids = _transpose_batch_time(
        [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
    parent_ids = _transpose_batch_time(
        [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
    sequence_length = [[3, 3, 3]]

    with ops.device("/cpu:0"):
        beams = beam_search_ops.gather_tree(
            step_ids=step_ids,
            parent_ids=parent_ids,
            sequence_length=sequence_length)

    with self.test_session(), self.assertRaisesOpError(error_pattern):
        beams.eval()
def testGatherTreeOne(self):
    """Single-batch gather_tree; steps past the length become end_token."""
    # (max_time = 4, batch_size = 1, beams = 3)
    end_token = 10
    step_ids = _transpose_batch_time(
        [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
    parent_ids = _transpose_batch_time(
        [[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]])
    max_sequence_lengths = [3]
    want = _transpose_batch_time(
        [[[2, 2, 2], [6, 5, 6], [7, 8, 9], [10, 10, 10]]])

    gathered = beam_search_ops.gather_tree(
        step_ids=step_ids,
        parent_ids=parent_ids,
        max_sequence_lengths=max_sequence_lengths,
        end_token=end_token)
    with self.cached_session(use_gpu=True):
        actual = self.evaluate(gathered)
        self.assertAllEqual(want, actual)
def testGatherTreeBatch(self):
    """Randomized gather_tree check over a batch with varied lengths."""
    batch_size = 10
    beam_width = 15
    max_time = 8
    # Includes 0 and lengths beyond max_time to exercise both boundaries.
    max_sequence_lengths = [0, 1, 2, 4, 7, 8, 9, 10, 11, 0]
    end_token = 5
    # Seeded generator so any failure is reproducible.
    rng = np.random.RandomState(0)

    with self.cached_session(use_gpu=True):
        step_ids = rng.randint(
            0, high=end_token + 1, size=(max_time, batch_size, beam_width))
        # `high` is exclusive: draw parents from every beam in [0, beam_width).
        parent_ids = rng.randint(
            0, high=beam_width, size=(max_time, batch_size, beam_width))

        beams = beam_search_ops.gather_tree(
            step_ids=step_ids.astype(np.int32),
            parent_ids=parent_ids.astype(np.int32),
            max_sequence_lengths=max_sequence_lengths,
            end_token=end_token)

        self.assertEqual((max_time, batch_size, beam_width), beams.shape)
        beams_value = self.evaluate(beams)
        for b in range(batch_size):
            # Past max_sequence_lengths[b], we emit all end tokens.
            b_value = beams_value[max_sequence_lengths[b]:, b, :]
            self.assertAllClose(b_value, end_token * np.ones_like(b_value))
        for batch, beam in itertools.product(
                range(batch_size), range(beam_width)):
            v = np.squeeze(beams_value[:, batch, beam])
            if end_token in v:
                found_bad = np.where(v == -1)[0]
                self.assertEqual(0, len(found_bad))
                found = np.where(v == end_token)[0][0]  # First end_token.
                # Everything before the first end_token must be a valid id
                # (>= 0) and everything after it must be end_token.
                if found > 0:
                    self.assertAllEqual(
                        v[:found] >= 0, np.ones_like(v[:found], dtype=bool))
                self.assertAllClose(
                    v[found + 1:], end_token * np.ones_like(v[found + 1:]))
def testBadParentValuesOnCPU(self):
    """A negative parent id makes the CPU kernel raise."""
    # (batch_size = 1, max_time = 4, beams = 3)
    # The bad parent sits in beam 1 at time 1.
    error_pattern = r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"
    step_ids = _transpose_batch_time(
        [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
    parent_ids = _transpose_batch_time(
        [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
    max_sequence_lengths = [3]

    with ops.device("/cpu:0"):
        beams = beam_search_ops.gather_tree(
            step_ids=step_ids,
            parent_ids=parent_ids,
            max_sequence_lengths=max_sequence_lengths,
            end_token=10)

    with self.test_session(), self.assertRaisesOpError(error_pattern):
        beams.eval()
def testBadParentValuesOnCPU(self):
    """A negative parent id makes the CPU kernel raise on evaluation."""
    # (batch_size = 1, max_time = 4, beams = 3)
    # The bad parent sits in beam 1 at time 1.
    end_token = 10
    error_pattern = r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"
    step_ids = _transpose_batch_time(
        [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
    parent_ids = _transpose_batch_time(
        [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
    max_sequence_lengths = [3]

    with ops.device("/cpu:0"), self.assertRaisesOpError(error_pattern):
        gathered = beam_search_ops.gather_tree(
            step_ids=step_ids,
            parent_ids=parent_ids,
            max_sequence_lengths=max_sequence_lengths,
            end_token=end_token)
        self.evaluate(gathered)
def testBadParentValuesOnGPU(self):
    """On GPU a negative parent id surfaces as -1 instead of raising."""
    if not test.is_gpu_available():
        # Report a skip instead of silently passing, so a broken test body
        # cannot hide behind the missing GPU.
        self.skipTest("No GPU available.")
    # (max_time = 4, batch_size = 1, beams = 3)
    # bad parent in beam 1 time 1; appears as a negative index at time 0
    step_ids = _transpose_batch_time(
        [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
    parent_ids = _transpose_batch_time(
        [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
    # This gather_tree signature takes per-beam lengths shaped
    # [batch_size, beam_width]; a flat [3] does not match that shape.
    sequence_length = [[3, 3, 3]]
    expected_result = _transpose_batch_time(
        [[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
    with ops.device("/gpu:0"):
        beams = beam_search_ops.gather_tree(
            step_ids=step_ids,
            parent_ids=parent_ids,
            sequence_length=sequence_length)
    with self.test_session(use_gpu=True):
        self.assertAllEqual(expected_result, beams.eval())
def finalize(self, outputs, final_state, sequence_lengths):
    """Finalize and return the predicted_ids.

    Args:
      outputs: An instance of BeamSearchDecoderOutput.
      final_state: An instance of BeamSearchDecoderState. Passed through to the
        output.
      sequence_lengths: An `int32` tensor shaped `[batch_size, beam_width]`.
        The sequence lengths determined for each beam during decode.

    Returns:
      outputs: An instance of FinalBeamSearchDecoderOutput where the
        predicted_ids are the result of calling _gather_tree.
      final_state: The same input instance of BeamSearchDecoderState.
    """
    # Walk parent pointers back from each beam's final step.
    traced_ids = beam_search_ops.gather_tree(
        outputs.predicted_ids,
        outputs.parent_ids,
        sequence_length=sequence_lengths)
    return (
        FinalBeamSearchDecoderOutput(
            beam_search_decoder_output=outputs, predicted_ids=traced_ids),
        final_state,
    )
def testBadParentValuesOnGPU(self):
    """On GPU a negative parent id surfaces as -1 instead of raising."""
    if not test.is_gpu_available():
        return
    # (max_time = 4, batch_size = 1, beams = 3)
    # The bad parent in beam 1 at time 1 shows up as a negative index at
    # time 0.
    sequence_length = [[3, 3, 3]]
    step_ids = _transpose_batch_time(
        [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
    parent_ids = _transpose_batch_time(
        [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
    want = _transpose_batch_time(
        [[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])

    with ops.device("/gpu:0"):
        gathered = beam_search_ops.gather_tree(
            step_ids=step_ids,
            parent_ids=parent_ids,
            sequence_length=sequence_length)
    with self.test_session(use_gpu=True):
        self.assertAllEqual(want, gathered.eval())
def finalize(self, outputs, final_state, sequence_lengths):
    """Finalize, rerank beams by a coverage-aware score, return predicted_ids.

    Args:
      outputs: An instance of CoverageDecoderOutput.
      final_state: An instance of CoverageDecoderState. Passed through to the
        output.
      sequence_lengths: An `int32` tensor shaped `[batch_size, beam_width]`.
        The sequence lengths determined for each beam during decode.

    Returns:
      outputs: An instance of FinalCoverageDecoderOutput where the
        predicted_ids are the result of calling _gather_tree, with the beams
        of each batch entry reordered by descending combined score.
      final_state: The same input instance of CoverageDecoderState.
    """
    predicted_ids = beam_search_ops.gather_tree(
        outputs.predicted_ids,
        outputs.parent_ids,
        sequence_length=sequence_lengths)

    # Coverage reward: sum over the last axis of log(max(coverage, threshold)).
    coverage_scores = tf.reduce_sum(
        tf.log(tf.maximum(final_state.coverages, self._threshold)), axis=-1)
    # Coverage penalty: sum over the last axis of log(min(coverage, 1)) (<= 0).
    coverage_penalty = tf.reduce_sum(
        tf.log(tf.minimum(final_state.coverages, 1.)), axis=-1)
    if self._length_penalty_weight != 0.0:
        # Normalize both coverage terms by the size of the coverages' last axis.
        coverage_length = tf.to_float(tf.shape(final_state.coverages)[-1])
        coverage_scores = coverage_scores / coverage_length
        coverage_penalty = coverage_penalty / coverage_length

    scores = final_state.log_probs / _length_penalty(
        sequence_lengths=final_state.lengths,
        penalty_factor=self._length_penalty_weight)

    # Map w to w / (1 - w); the asserts below fail the run if the resulting
    # weight is inf or NaN.
    coverage_score_weight = self._coverage_score_weight / (
        1. - self._coverage_score_weight)
    with tf.control_dependencies([
            tf.Assert(
                tf.equal(
                    tf.reduce_sum(tf.to_int32(tf.is_inf(coverage_score_weight))),
                    0),
                [coverage_score_weight]),
            tf.Assert(
                tf.equal(
                    tf.reduce_sum(tf.to_int32(tf.is_nan(coverage_score_weight))),
                    0),
                [coverage_score_weight]),
    ]):
        scores = (scores
                  + coverage_score_weight * coverage_scores
                  + self._coverage_penalty_weight * coverage_penalty)

    # order[b, k] is the beam index of the k-th best beam of batch entry b.
    _, order = tf.nn.top_k(scores, self._beam_width)

    # Reorder beams per batch entry:
    #   out[t, b, k] = predicted_ids[t, b, order[b, k]].
    # A flat tf.gather along the last axis with a flattened `order` would mix
    # beams across batch entries and yield [time, batch, batch * beam] for
    # batch_size > 1.
    batch_major_ids = tf.transpose(predicted_ids, [1, 2, 0])
    batch_index = tf.tile(
        tf.expand_dims(tf.range(tf.shape(order)[0]), 1),
        [1, tf.shape(order)[1]])
    gather_indices = tf.stack([batch_index, order], axis=-1)
    predicted_ids = tf.transpose(
        tf.gather_nd(batch_major_ids, gather_indices), [2, 0, 1])

    outputs = FinalCoverageDecoderOutput(
        beam_search_decoder_output=outputs, predicted_ids=predicted_ids)
    return outputs, final_state
def finalize(self, outputs, final_state, sequence_lengths):
    """Finalize and return the predicted_ids.

    Args:
      outputs: An instance of BeamSearchDecoderOutput.
      final_state: An instance of BeamSearchDecoderState. Passed through to the
        output.
      sequence_lengths: An `int64` tensor shaped `[batch_size, beam_width]`.
        The sequence lengths determined for each beam during decode.

    Returns:
      outputs: An instance of FinalBeamSearchDecoderOutput where the
        predicted_ids are the result of calling _gather_tree.
      final_state: The same input instance of BeamSearchDecoderState.
    """
    step_ids = outputs.predicted_ids
    parents = outputs.parent_ids
    full_beams = beam_search_ops.gather_tree(
        step_ids, parents, sequence_length=sequence_lengths)
    wrapped = FinalBeamSearchDecoderOutput(
        beam_search_decoder_output=outputs, predicted_ids=full_beams)
    return wrapped, final_state
def testGatherTreeBatch(self):
    """Randomized gather_tree check over a batch with varied lengths."""
    batch_size = 10
    beam_width = 15
    max_time = 8
    # Includes 0 and lengths beyond max_time to exercise both boundaries.
    max_sequence_lengths = [0, 1, 2, 4, 7, 8, 9, 10, 11, 0]
    end_token = 5
    # Seeded generator so any failure is reproducible.
    rng = np.random.RandomState(0)

    with self.test_session(use_gpu=True):
        step_ids = rng.randint(
            0, high=end_token + 1, size=(max_time, batch_size, beam_width))
        # `high` is exclusive: draw parents from every beam in [0, beam_width).
        parent_ids = rng.randint(
            0, high=beam_width, size=(max_time, batch_size, beam_width))

        beams = beam_search_ops.gather_tree(
            step_ids=step_ids.astype(np.int32),
            parent_ids=parent_ids.astype(np.int32),
            max_sequence_lengths=max_sequence_lengths,
            end_token=end_token)

        self.assertEqual((max_time, batch_size, beam_width), beams.shape)
        beams_value = beams.eval()
        for b in range(batch_size):
            # Past max_sequence_lengths[b], we emit all end tokens.
            b_value = beams_value[max_sequence_lengths[b]:, b, :]
            self.assertAllClose(b_value, end_token * np.ones_like(b_value))
        for batch, beam in itertools.product(
                range(batch_size), range(beam_width)):
            v = np.squeeze(beams_value[:, batch, beam])
            if end_token in v:
                found_bad = np.where(v == -1)[0]
                self.assertEqual(0, len(found_bad))
                found = np.where(v == end_token)[0][0]  # First end_token.
                # Everything before the first end_token must be a valid id
                # (>= 0) and everything after it must be end_token.
                if found > 0:
                    self.assertAllEqual(
                        v[:found] >= 0, np.ones_like(v[:found], dtype=bool))
                self.assertAllClose(
                    v[found + 1:], end_token * np.ones_like(v[found + 1:]))
def testBadParentValuesOnGPU(self):
    """On a CUDA GPU a negative parent id surfaces as -1 instead of raising."""
    # gather_tree is not registered for SYCL devices, so require CUDA.
    if not test.is_gpu_available(cuda_only=True):
        return
    # (max_time = 4, batch_size = 1, beams = 3)
    # The bad parent in beam 1 at time 1 shows up as a negative index at
    # time 0.
    max_sequence_lengths = [3]
    step_ids = _transpose_batch_time(
        [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
    parent_ids = _transpose_batch_time(
        [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
    want = _transpose_batch_time(
        [[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])

    with ops.device("/device:GPU:0"):
        gathered = beam_search_ops.gather_tree(
            step_ids=step_ids,
            parent_ids=parent_ids,
            max_sequence_lengths=max_sequence_lengths,
            end_token=10)
    with self.test_session(use_gpu=True):
        self.assertAllEqual(want, gathered.eval())
def testBadParentValuesOnGPU(self):
    """CUDA-only: a bad parent id is propagated as -1 rather than an error."""
    # Only CUDA devices register gather_tree (not SYCL), so skip otherwise.
    if not test.is_gpu_available(cuda_only=True):
        return
    # (max_time = 4, batch_size = 1, beams = 3)
    # bad parent in beam 1 time 1; appears as a negative index at time 0
    bad_steps = _transpose_batch_time(
        [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
    bad_parents = _transpose_batch_time(
        [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
    lengths = [3]
    expected = _transpose_batch_time(
        [[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])

    with ops.device("/device:GPU:0"):
        result = beam_search_ops.gather_tree(
            step_ids=bad_steps,
            parent_ids=bad_parents,
            max_sequence_lengths=lengths,
            end_token=10)
    with self.test_session(use_gpu=True):
        observed = result.eval()
        self.assertAllEqual(expected, observed)
def finalize(self, outputs, final_state, sequence_lengths):
    """Finalize and return the predicted_ids together with the beam scores.

    Args:
      outputs: An instance of BeamSearchDecoderOutput.
      final_state: An instance of BeamSearchDecoderState.
      sequence_lengths: An `int64` tensor shaped `[batch_size, beam_width]`.
        The sequence lengths determined for each beam during decode.
        **NOTE** These are ignored; the updated sequence lengths are stored in
        `final_state.lengths`.

    Returns:
      outputs: An instance of `FinalConstrainedBeamSearchDecoderOutput` whose
        predicted_ids are the result of calling gather_tree and whose scores
        are `outputs.scores`.
      final_state: `final_state`, with its `cell_state` re-sorted along the
        beam tree when `self._reorder_tensor_arrays` is set; otherwise the
        input instance unchanged.
    """
    del sequence_lengths
    # Get max_sequence_length across all beams for each batch.
    max_sequence_lengths = math_ops.to_int32(
        math_ops.reduce_max(final_state.lengths, axis=1))
    predicted_ids = beam_search_ops.gather_tree(
        outputs.predicted_ids,
        outputs.parent_ids,
        max_sequence_lengths=max_sequence_lengths,
        end_token=self._end_token)
    if self._reorder_tensor_arrays:
        # Bind the lengths explicitly rather than relying on the lambda
        # reading `final_state` before it is reassigned.
        beam_lengths = final_state.lengths
        final_state = final_state._replace(cell_state=nest.map_structure(
            lambda t: self._maybe_sort_array_beams(
                t, outputs.parent_ids, beam_lengths),
            final_state.cell_state))
    outputs = FinalConstrainedBeamSearchDecoderOutput(
        beam_search_decoder_output=outputs,
        predicted_ids=predicted_ids,
        scores=outputs.scores)
    return outputs, final_state
def testGatherTreeBatch(self):
    """gather_tree over a fixed batch whose lengths are 0, 1, 2 and 3."""
    # sequence_length is [batch_size, beam_width] = [4, 5]; entry b has
    # length b on every beam.
    sequence_length = [[length] * 5 for length in range(4)]
    with self.test_session(use_gpu=True):
        # (max_time = 4, batch_size = 4, beam_width = 5)
        step_ids = _transpose_batch_time([
            [[3, 4, 0, 4, 0], [4, 2, 0, 3, 1], [1, 1, 3, 2, 2],
             [3, 1, 2, 3, 4]],
            [[3, 4, 0, 4, 0], [4, 2, 0, 3, 1], [1, 1, 3, 2, 2],
             [3, 1, 2, 3, 4]],
            [[1, 2, 3, 4, 2], [2, 1, 1, 3, 2], [3, 0, 1, 0, 0],
             [3, 4, 0, 2, 4]],
            [[0, 2, 2, 3, 1], [3, 2, 2, 2, 3], [3, 4, 3, 0, 3],
             [1, 2, 2, 2, 4]],
        ])
        parent_ids = _transpose_batch_time([
            [[4, 2, 4, 3, 4], [3, 4, 0, 2, 0], [3, 1, 3, 2, 2],
             [0, 2, 1, 4, 2]],
            [[4, 2, 4, 3, 4], [3, 4, 0, 2, 0], [3, 1, 3, 2, 2],
             [0, 2, 1, 4, 2]],
            [[3, 0, 0, 4, 0], [1, 2, 4, 2, 2], [4, 4, 0, 3, 0],
             [2, 4, 4, 3, 0]],
            [[3, 1, 4, 1, 3], [3, 2, 4, 0, 4], [1, 0, 1, 4, 2],
             [0, 3, 2, 0, 1]],
        ])
        padding = [-1, -1, -1, -1, -1]
        expected_beams = _transpose_batch_time([
            [padding, padding, padding, padding],
            [[3, 4, 0, 4, 0], padding, padding, padding],
            [[2, 3, 2, 3, 3], [2, 1, 1, 3, 2], padding, padding],
            [[2, 3, 2, 1, 1], [2, 3, 2, 3, 2], [3, 4, 3, 0, 3], padding],
        ])
        gathered = beam_search_ops.gather_tree(
            step_ids=step_ids,
            parent_ids=parent_ids,
            sequence_length=sequence_length)
        self.assertAllEqual(expected_beams, gathered.eval())
def gather_tree_from_array(t, parent_ids, sequence_length):
    """Calculates the full beams for `TensorArray`s.

    Reorders each step of `t` so beam k holds the slice belonging to beam k's
    ancestor at that step, following `parent_ids`. Steps at or beyond a beam's
    length are left in place.

    Args:
      t: A stacked `TensorArray` of size `max_time` that contains `Tensor`s of
        shape `[batch_size, beam_width, s]` or `[batch_size * beam_width, s]`
        where `s` is the depth shape.
      parent_ids: The parent ids of shape `[max_time, batch_size, beam_width]`.
      sequence_length: The sequence length of shape `[batch_size, beam_width]`.

    Returns:
      A `Tensor` which is a stacked `TensorArray` of the same size and type as
      `t` and where beams are sorted in each `Tensor` according to `parent_ids`.
    """
    parent_shape = parent_ids.shape
    # Static sizes when known, runtime sizes otherwise.
    n_steps = parent_shape[0].value or array_ops.shape(parent_ids)[0]
    n_batch = parent_shape[1].value or array_ops.shape(parent_ids)[1]
    n_beams = parent_shape[2].value or array_ops.shape(parent_ids)[2]

    # [n_steps, n_batch, n_beams] tensor holding the beam index k.
    own_beam = array_ops.expand_dims(
        array_ops.expand_dims(math_ops.range(n_beams), 0), 0)
    own_beam = array_ops.tile(own_beam, [n_steps, n_batch, 1])

    # [batch, beam, time] -> [time, batch, beam] validity mask (int32 0/1).
    valid = array_ops.sequence_mask(
        sequence_length, maxlen=n_steps, dtype=dtypes.int32)
    valid = array_ops.transpose(valid, perm=[2, 0, 1])

    # n_beams + 1 marks the end of a beam (and is gather_tree's end token).
    stop_id = n_beams + 1
    step_ids = (own_beam * valid) + (1 - valid) * stop_id

    per_batch_max = math_ops.to_int32(
        math_ops.reduce_max(sequence_length, axis=1))
    ancestry = beam_search_ops.gather_tree(
        step_ids=step_ids,
        parent_ids=parent_ids,
        max_sequence_lengths=per_batch_max,
        end_token=stop_id)
    # Out-of-range steps copy their own beam.
    ancestry = array_ops.where(
        math_ops.cast(valid, dtypes.bool), x=ancestry, y=own_beam)

    # Index grid for gather_nd: (time, batch, source beam).
    step_grid = array_ops.tile(
        array_ops.reshape(math_ops.range(n_steps), [-1, 1, 1]),
        [1, n_batch, n_beams])
    batch_grid = array_ops.tile(
        array_ops.reshape(math_ops.range(n_batch), [1, -1, 1]),
        [n_steps, 1, n_beams])
    gather_idx = array_ops.stack([step_grid, batch_grid, ancestry], -1)

    # Flatten extra depth dimensions, gather, and reshape back.
    shape_of_t = array_ops.shape(t)
    source = array_ops.reshape(t, [n_steps, n_batch, n_beams, -1])
    return array_ops.reshape(
        array_ops.gather_nd(source, gather_idx), shape_of_t)
def body(time, outputs_ta, state, inputs, finished, sequence_lengths,
         hypotheses, input_ids, scores, base_index):
    """Internal while_loop body.

    Runs one decoder step, writes its outputs, and appends every beam that
    finished on this step (its backtracked hypothesis, its batch index and its
    length-normalized score) to the `hypotheses` / `input_ids` / `scores`
    TensorArrays.

    Closes over `decoder`, `impute_finished`, `zero_outputs`, `repetition`,
    `maximum_iterations`, `parallel_iterations` and `swap_memory` from the
    enclosing scope.

    Args:
      time: scalar int32 tensor.
      outputs_ta: structure of TensorArray.
      state: (structure of) state tensors and TensorArrays.
      inputs: (structure of) input tensors.
      finished: bool tensor (keeping track of what's finished).
      sequence_lengths: int32 tensor (keeping track of time of finish).
      hypotheses: structure of TensorArray (stores hypotheses so far).
      input_ids: structure of TensorArray.
      scores: structure of TensorArray.
      base_index: int32 tensor (keeping track of size of the above 3
        TensorArrays)

    Returns:
      `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
        next_sequence_lengths, new_hypotheses, new_input_ids, new_scores,
        new_base)`.
    """
    (next_outputs, decoder_state, next_inputs,
     decoder_finished) = decoder.step(time, inputs, state)
    if decoder.tracks_own_finished:
        next_finished = decoder_finished
    else:
        next_finished = math_ops.logical_or(decoder_finished, finished)
    # Beams still unfinished after this step get length time + 1, plus one
    # more when decoder._use_go_tokens is False.
    # NOTE(review): upstream dynamic_decode tests the previous `finished`
    # here; using `next_finished` means the finishing step is not counted.
    # Confirm this is intended.
    next_sequence_lengths = array_ops.where(
        math_ops.logical_not(next_finished),
        array_ops.fill(array_ops.shape(sequence_lengths),
                       time + 1 + (not decoder._use_go_tokens)),
        sequence_lengths)

    nest.assert_same_structure(state, decoder_state)
    nest.assert_same_structure(outputs_ta, next_outputs)
    nest.assert_same_structure(inputs, next_inputs)

    # Zero out output values past finish
    # NOTE(review): this masks with `next_finished`, so a beam's finishing
    # step is also zeroed; upstream uses `finished` — confirm.
    if impute_finished:
        emit = nest.map_structure(
            lambda out, zero: array_ops.where(next_finished, zero, out),
            next_outputs,
            zero_outputs)
    else:
        emit = next_outputs

    # Copy through states past finish
    def _maybe_copy_state(new, cur):
        # TensorArrays and scalar states get passed through.
        if isinstance(cur, tensor_array_ops.TensorArray):
            pass_through = True
        else:
            new.set_shape(cur.shape)
            pass_through = (new.shape.ndims == 0)
        return new if pass_through else array_ops.where(finished, cur, new)

    # The write index is shifted by one when decoder._use_go_tokens is False
    # (presumably slot 0 is reserved for something else — TODO confirm).
    outputs_ta = nest.map_structure(
        lambda ta, out: ta.write(time + (not decoder._use_go_tokens), out),
        outputs_ta, emit)

    # Extract hypotheses, scores for reference
    outputs_so_far = nest.map_structure(lambda ta: ta.stack(), outputs_ta)
    parent_ids = outputs_so_far.parent_ids
    hypotheses_so_far = outputs_so_far.predicted_ids
    forward_scores = next_outputs.scores
    # Backtrack every batch entry over all steps written so far.
    sl = tf.ones([decoder.batch_size], tf.int32) * \
        (time + 1 + (not decoder._use_go_tokens))
    hypotheses_so_far = beam_search_ops.gather_tree(
        hypotheses_so_far, parent_ids, max_sequence_lengths=sl,
        end_token=decoder._end_token)

    # Add repetition penalty
    if repetition != 0:
        # Number of distinct tokens in x, as float32.
        def unique_counter(x):
            return tf.cast(tf.size(tf.unique(x)[0]), tf.float32)

        # Per-step penalty: repetition * log(#unique) on the first step, and
        # repetition * (log #unique now - log #unique without the last token)
        # afterwards, so the per-step terms telescope to
        # repetition * log(#unique).
        def wrapper_rep_penalty_function(sentence):
            def first_time_penalty():
                total_unique_words = unique_counter(sentence)
                return tf.log(total_unique_words)

            def generic_penalty():
                sentence_length = tf.shape(sentence)[0]
                total_unique_words = unique_counter(sentence)
                unique_words_before = unique_counter(
                    sentence[:sentence_length - 1])
                return (tf.log(total_unique_words) -
                        tf.log(unique_words_before))

            return repetition * tf.cond(math_ops.equal(
                tf.shape(sentence)[0], 1), true_fn=first_time_penalty,
                false_fn=generic_penalty)

        # Use reshaped hypotheses; calculate penalty per beam per batch
        # (assumes hypotheses_so_far is [time, batch, beam], so this is
        # [beam, batch, time] and the transposed penalty is [batch, beam] —
        # TODO confirm).
        transposed_hypotheses = tf.transpose(hypotheses_so_far, [2, 1, 0])
        repetition_penalty = tf.map_fn(lambda x: tf.map_fn(
            wrapper_rep_penalty_function, x, dtype=tf.float32),
            transposed_hypotheses, dtype=tf.float32)
        repetition_penalty = tf.transpose(repetition_penalty, [1, 0])
        forward_scores += repetition_penalty

        # Add repetition penalty hypothesis scores
        decoder_state = BeamSearchDecoderState(
            cell_state=decoder_state.cell_state,
            log_probs=decoder_state.log_probs + repetition_penalty,
            finished=decoder_state.finished,
            lengths=decoder_state.lengths,
            accumulated_attention_probs=decoder_state.accumulated_attention_probs)

    if impute_finished:
        next_state = nest.map_structure(
            _maybe_copy_state, decoder_state, state)
    else:
        next_state = decoder_state

    # Beams that were not finished before and are flagged finished on this step.
    finished_this_beam = math_ops.logical_and(
        math_ops.logical_not(finished), decoder_finished)

    # Make sure number of outputs is never zero: on the last iteration, if
    # nothing has been collected yet and no beam finished now, treat every
    # beam as finished.
    finished_this_beam = tf.cond(
        math_ops.logical_and(
            math_ops.logical_and(
                tf.equal(hypotheses.size(), 0),
                tf.equal(tf.size(tf.where(finished_this_beam)), 0)),
            tf.equal(time, maximum_iterations - 1)),
        true_fn=lambda: tf.cast(tf.ones_like(finished_this_beam),
                                dtype=tf.bool),
        false_fn=lambda: finished_this_beam)

    # Appends the beams in finished_this_beam to the three TensorArrays and
    # returns (new_base, new_hypotheses, new_input_ids, new_scores).
    def prepare_hypotheses_for_ta():
        finished_beams = tf.where(finished_this_beam)
        hypotheses_for_ta = tf.boolean_mask(
            tf.transpose(hypotheses_so_far, [1, 2, 0]), finished_this_beam)
        # Pad hypotheses with EOS token
        hypotheses_for_ta = tf.pad(hypotheses_for_ta, [[0, 0], [
            0, maximum_iterations + (not decoder._use_go_tokens) -
            tf.shape(hypotheses_for_ta)[-1]
        ]], constant_values=decoder._end_token)
        # Column 0 of tf.where's coordinates is the batch index.
        input_query_id = tf.expand_dims(finished_beams[:, 0], 1)
        scores_forward = tf.expand_dims(
            tf.boolean_mask(forward_scores, finished_this_beam), 1)

        def inner_cond(index, base, hyp_ta, ind_ta, score_ta, hypos,
                       input_ids, forward_scores):
            # Populate TA with given elements AND do not consider blank responses
            return math_ops.logical_and(
                math_ops.less(index, tf.shape(hypos)[0]),
                math_ops.greater(time, 0 - (not decoder._use_go_tokens)))

        def inner_body(index, base, hyp_ta, ind_ta, score_ta, hypos,
                       input_ids, forward_scores):
            new_hyp_ta = nest.map_structure(
                lambda ta, out: ta.write(base, out), hyp_ta, hypos[index])
            new_ind_ta = nest.map_structure(
                lambda ta, out: ta.write(base, out), ind_ta, input_ids[index])
            # Remove repetition penalty from stored score, use as a feature for later re-reranking
            # NOTE(review): the per-step penalties above sum to
            # repetition * log(#unique), but this subtracts
            # repetition * #unique (no log), and hypos[index] is padded with
            # end_token so its unique count may include it — confirm.
            forward_scores_store = forward_scores[index]
            if repetition != 0:
                forward_scores_store -= repetition * \
                    unique_counter(hypos[index])
            # Normalize finished scores by their length
            # (the count of non-end_token entries). NOTE(review): a hypothesis
            # made only of end_token divides by zero here.
            new_scores_ta = nest.map_structure(
                lambda ta, out: ta.write(base, out), score_ta,
                forward_scores_store / tf.cast(
                    tf.count_nonzero(hypos[index] - decoder._end_token),
                    tf.float32))
            return (index + 1, base + 1, new_hyp_ta, new_ind_ta,
                    new_scores_ta, hypos, input_ids, forward_scores)

        # Add multiple hypotheses (and related information) to TensorArray using a while_loop
        inner_result = tf.while_loop(
            inner_cond, inner_body,
            loop_vars=(tf.constant(0), base_index, hypotheses, input_ids,
                       scores, hypotheses_for_ta, input_query_id,
                       scores_forward),
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory)
        return inner_result[1], inner_result[2], inner_result[
            3], inner_result[4]

    # In case finished is not True for any beams
    new_base, new_hypotheses, new_input_ids, new_scores = tf.cond(
        math_ops.greater(tf.count_nonzero(finished_this_beam), 0),
        true_fn=prepare_hypotheses_for_ta,
        false_fn=lambda: (base_index, hypotheses, input_ids, scores))

    return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
            next_sequence_lengths, new_hypotheses, new_input_ids, new_scores,
            new_base)
def finalize(self, outputs, final_state, sequence_lengths):
    """Finalize and return the predicted_ids, optionally decoding again.

    While `self._decoding_iterations_remaining` is at least 1 after being
    decremented, the tokens predicted so far are appended to
    `self._forbidden_tokens` and `dynamic_decode` is run on `self` again. The
    resulting beams are then concatenated with the current ones along the beam
    axis.

    Args:
      outputs: An instance of BeamSearchDecoderOutput.
      final_state: An instance of BeamSearchDecoderState. Passed through to the
        output.
      sequence_lengths: An `int64` tensor shaped `[batch_size, beam_width]`.
        The sequence lengths determined for each beam during decode.
        **NOTE** These are ignored; the updated sequence lengths are stored in
        `final_state.lengths`.

    Returns:
      outputs: An instance of `FinalBeamSearchDecoderOutput` where the
        predicted_ids are the result of calling _gather_tree.
      final_state: `final_state`, with `cell_state` re-sorted when
        `self._reorder_tensor_arrays` is set.
    """
    del sequence_lengths
    # NOTE(review): graph-construction-time mutation of self; each nested
    # dynamic_decode below presumably calls this finalize again, so the
    # counter is what stops the recursion.
    self._decoding_iterations_remaining -= 1  # Decrease counter.

    # Get max_sequence_length across all beams for each batch.
    max_sequence_lengths = math_ops.to_int32(
        math_ops.reduce_max(final_state.lengths, axis=1))
    predicted_ids = beam_search_ops.gather_tree(
        outputs.predicted_ids,
        outputs.parent_ids,
        max_sequence_lengths=max_sequence_lengths,
        end_token=self._end_token)
    if self._reorder_tensor_arrays:
        # pylint: disable=g-long-lambda
        # pylint: disable=line-too-long
        final_state = final_state._replace(cell_state=nest.map_structure(
            lambda t: self._maybe_sort_array_beams(t, outputs.parent_ids, final_state.lengths),
            final_state.cell_state))
        # pylint: enable=g-long-lambda
        # pylint: enable=line-too-long

    if self._decoding_iterations_remaining >= 1:
        # Transpose to [batch_size, time, beam_width]
        new_forbidden_tokens = tf.transpose(predicted_ids, perm=[1, 0, 2])
        # Reshape to [batch_size, time * beam_width]
        new_forbidden_tokens = tf.reshape(
            new_forbidden_tokens,
            shape=[tf.shape(new_forbidden_tokens)[0], -1])
        # Accumulate forbidden tokens across decoding passes.
        if self._forbidden_tokens is not None:
            self._forbidden_tokens = tf.concat(
                [self._forbidden_tokens, new_forbidden_tokens], axis=1)
        else:
            self._forbidden_tokens = new_forbidden_tokens

        new_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
            self,
            maximum_iterations=self._maximum_iterations,
            output_time_major=True,
            swap_memory=True,
            scope=self._decoder_scope)

        # NOTE(review): only predicted_ids is padded to a common time length
        # below; these raw concats along axis 2 presumably need `outputs` and
        # `new_outputs` to match on axes 0 and 1 — confirm this holds when the
        # two passes stop at different steps. The second pass's parent_ids
        # are not offset by the first pass's beam count.
        all_scores = tf.concat([
            outputs.scores, new_outputs.beam_search_decoder_output.scores
        ], axis=2)
        all_predicted_ids = tf.concat([
            outputs.predicted_ids,
            new_outputs.beam_search_decoder_output.predicted_ids
        ], axis=2)
        all_parent_ids = tf.concat([
            outputs.parent_ids,
            new_outputs.beam_search_decoder_output.parent_ids
        ], axis=2)
        outputs = beam_search_decoder.BeamSearchDecoderOutput(
            scores=all_scores,
            predicted_ids=all_predicted_ids,
            parent_ids=all_parent_ids)

        # Append eos token ids in case predicted_ids is shorter than new
        # predicted_ids, and vice-versa. (Assumes `pad` extends axis 0 up to
        # max_size with `value` — TODO confirm.)
        predicted_ids = pad(
            x=predicted_ids,
            max_size=tf.shape(new_outputs.predicted_ids)[0],
            value=self._end_token)
        new_predicted_ids = pad(
            x=new_outputs.predicted_ids,
            max_size=tf.shape(predicted_ids)[0],
            value=self._end_token)
        predicted_ids = tf.concat([predicted_ids, new_predicted_ids], axis=2)

    outputs = beam_search_decoder.FinalBeamSearchDecoderOutput(
        beam_search_decoder_output=outputs,
        predicted_ids=predicted_ids)
    return outputs, final_state