def finalize(self, outputs, final_state, sequence_lengths):
    """Run the beam backtrace and wrap the result as the final output.

    Args:
      outputs: An instance of BeamSearchDecoderOutput.
      final_state: An instance of BeamSearchDecoderState. Passed through to the
        output.
      sequence_lengths: An `int64` tensor shaped `[batch_size, beam_width]`.
        Unused: the lengths used for the backtrace are read from
        `final_state.lengths` instead.

    Returns:
      outputs: An instance of `FinalBeamSearchDecoderOutput` whose
        `predicted_ids` come from `beam_search_ops.gather_tree`.
      final_state: The same input instance of `BeamSearchDecoderState`.
    """
    del sequence_lengths  # Superseded by final_state.lengths.
    # Reduce the per-beam lengths to one maximum per batch entry; gather_tree
    # receives this as `max_sequence_lengths`.
    longest_per_batch = math_ops.to_int32(
        math_ops.reduce_max(final_state.lengths, axis=1))
    backtraced_ids = beam_search_ops.gather_tree(
        outputs.predicted_ids,
        outputs.parent_ids,
        max_sequence_lengths=longest_per_batch,
        end_token=self._end_token)
    return (
        FinalBeamSearchDecoderOutput(
            beam_search_decoder_output=outputs, predicted_ids=backtraced_ids),
        final_state,
    )
  def test_gather_tree(self):
    """gather_tree backtraces two batch entries of three beams each."""
    # (max_time = 3, batch_size = 2, beam_width = 3)

    # Build (batch_size, max_time, beam_width) arrays, then transpose them to
    # the time-major (max_time, batch_size, beam_width) layout.
    predicted_ids = np.array(
        [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[2, 3, 4], [5, 6, 7], [8, 9, 10]]],
        dtype=np.int32).transpose([1, 0, 2])
    parent_ids = np.array(
        [[[0, 0, 0], [0, 1, 1], [2, 1, 2]], [[0, 0, 0], [1, 2, 0], [2, 1, 1]]],
        dtype=np.int32).transpose([1, 0, 2])

    # max_sequence_lengths is shaped (batch_size = 2); both batch entries use
    # the full max_time of 3.
    max_sequence_lengths = [3, 3]

    expected_result = np.array([[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
                                [[2, 4, 4], [7, 6, 6],
                                 [8, 9, 10]]]).transpose([1, 0, 2])

    # end_token=11 never appears in predicted_ids, so no end-token padding is
    # expected in the result.
    res = beam_search_ops.gather_tree(
        predicted_ids,
        parent_ids,
        max_sequence_lengths=max_sequence_lengths,
        end_token=11)

    with self.cached_session() as sess:
      res_ = sess.run(res)

    self.assertAllEqual(expected_result, res_)
def gather_tree_from_array(t, parent_ids, sequence_length):
  """Reorders a stacked `TensorArray` so each beam follows its backtrace.

  Args:
    t: A stacked `TensorArray` of size `max_time` that contains `Tensor`s of
      shape `[batch_size, beam_width, s]` or `[batch_size * beam_width, s]`
      where `s` is the depth shape.
    parent_ids: The parent ids of shape `[max_time, batch_size, beam_width]`.
    sequence_length: The sequence length of shape `[batch_size, beam_width]`.

  Returns:
    A `Tensor` which is a stacked `TensorArray` of the same size and type as
    `t` and where beams are sorted in each `Tensor` according to `parent_ids`.
  """
  def _dim(axis):
    # Static size when known, otherwise the dynamic size from the graph.
    return parent_ids.shape[axis].value or array_ops.shape(parent_ids)[axis]

  max_time = _dim(0)
  batch_size = _dim(1)
  beam_width = _dim(2)

  # Identity beam indices 0..beam_width-1 at every (time, batch) position;
  # gather_tree permutes these along the backtrace.
  identity = array_ops.tile(
      array_ops.expand_dims(
          array_ops.expand_dims(math_ops.range(beam_width), 0), 0),
      [max_time, batch_size, 1])

  # Time-major validity mask: 1 within a beam's length, 0 beyond it.
  valid = array_ops.transpose(
      array_ops.sequence_mask(
          sequence_length, maxlen=max_time, dtype=dtypes.int32),
      perm=[2, 0, 1])

  # Positions past a beam's end are tagged with beam_width + 1, which is also
  # the end_token handed to gather_tree below.
  tagged = (identity * valid) + (1 - valid) * (beam_width + 1)

  per_batch_max = math_ops.to_int32(
      math_ops.reduce_max(sequence_length, axis=1))
  backtraced = beam_search_ops.gather_tree(
      step_ids=tagged,
      parent_ids=parent_ids,
      max_sequence_lengths=per_batch_max,
      end_token=beam_width + 1)

  # Past the end there is nothing to follow, so each beam keeps its own index.
  backtraced = array_ops.where(
      math_ops.cast(valid, dtypes.bool), x=backtraced, y=identity)

  # [max_time, batch_size, beam_width] coordinate grids for gather_nd.
  time_grid = array_ops.tile(
      array_ops.reshape(math_ops.range(max_time), [-1, 1, 1]),
      [1, batch_size, beam_width])
  batch_grid = array_ops.transpose(
      array_ops.tile(
          array_ops.reshape(math_ops.range(batch_size), [-1, 1, 1]),
          [1, max_time, beam_width]),
      perm=[1, 0, 2])
  coords = array_ops.stack([time_grid, batch_grid, backtraced], -1)

  # Collapse trailing depth dims into one so a single gather_nd covers both
  # accepted layouts of `t`, then restore the original shape.
  restore_shape = array_ops.shape(t)
  flat = array_ops.reshape(t, [max_time, batch_size, beam_width, -1])
  return array_ops.reshape(array_ops.gather_nd(flat, coords), restore_shape)
 def testGatherTreeOne(self):
   """Backtrace a single batch entry with three beams (legacy API)."""
   # (max_time = 4, batch_size = 1, beams = 3); the last time step is padding.
   lengths = [[3, 3, 3]]
   step = _transpose_batch_time(
       [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
   parents = _transpose_batch_time(
       [[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]])
   want = _transpose_batch_time(
       [[[2, 2, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
   gathered = beam_search_ops.gather_tree(
       step_ids=step, parent_ids=parents, sequence_length=lengths)
   with self.test_session(use_gpu=True):
     self.assertAllEqual(want, gathered.eval())
 def testBadParentValuesOnCPU(self):
   """A negative parent id must raise an op error on the CPU kernel."""
   # (batch_size = 1, max_time = 4, beams = 3)
   # The parent of beam 1 at time 1 is -1, which is out of range.
   step = _transpose_batch_time(
       [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
   parents = _transpose_batch_time(
       [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
   lengths = [[3, 3, 3]]
   expected_error = r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"
   with ops.device("/cpu:0"):
     gathered = beam_search_ops.gather_tree(
         step_ids=step, parent_ids=parents, sequence_length=lengths)
   with self.test_session(), self.assertRaisesOpError(expected_error):
     gathered.eval()
 def testGatherTreeOne(self):
   """Backtrace one batch entry; steps past the max length become end_token."""
   # (max_time = 4, batch_size = 1, beams = 3)
   end_token = 10
   max_sequence_lengths = [3]
   step = _transpose_batch_time(
       [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
   parents = _transpose_batch_time(
       [[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]])
   want = _transpose_batch_time(
       [[[2, 2, 2], [6, 5, 6], [7, 8, 9], [10, 10, 10]]])
   gathered = beam_search_ops.gather_tree(
       step_ids=step,
       parent_ids=parents,
       max_sequence_lengths=max_sequence_lengths,
       end_token=end_token)
   with self.cached_session(use_gpu=True):
     self.assertAllEqual(want, self.evaluate(gathered))
    def testGatherTreeBatch(self):
        """Randomized gather_tree check over a batch with ragged max lengths.

        `max_sequence_lengths` contains zeros and values larger than
        `max_time`, so both extremes are exercised.
        """
        batch_size = 10
        beam_width = 15
        max_time = 8
        max_sequence_lengths = [0, 1, 2, 4, 7, 8, 9, 10, 11, 0]
        end_token = 5

        with self.cached_session(use_gpu=True):
            step_ids = np.random.randint(0,
                                         high=end_token + 1,
                                         size=(max_time, batch_size,
                                               beam_width))
            # NOTE(review): `high` is exclusive, so parent index
            # beam_width - 1 is never drawn; confirm this is intended.
            parent_ids = np.random.randint(0,
                                           high=beam_width - 1,
                                           size=(max_time, batch_size,
                                                 beam_width))

            beams = beam_search_ops.gather_tree(
                step_ids=step_ids.astype(np.int32),
                parent_ids=parent_ids.astype(np.int32),
                max_sequence_lengths=max_sequence_lengths,
                end_token=end_token)

            self.assertEqual((max_time, batch_size, beam_width), beams.shape)
            beams_value = self.evaluate(beams)
            for b in range(batch_size):
                # Past max_sequence_lengths[b], we emit all end tokens.
                b_value = beams_value[max_sequence_lengths[b]:, b, :]
                self.assertAllClose(b_value, end_token * np.ones_like(b_value))
            for batch, beam in itertools.product(range(batch_size),
                                                 range(beam_width)):
                v = np.squeeze(beams_value[:, batch, beam])
                if end_token not in v:
                    continue
                # A beam that terminates must not contain invalid (-1) ids.
                found_bad = np.where(v == -1)[0]
                self.assertEqual(0, len(found_bad))
                # First occurrence of end_token.
                found = np.where(v == end_token)[0][0]
                # Everything before the first end_token must be a valid id,
                # and everything after it must be end_token as well.
                if found > 0:
                    before = v[:found]
                    self.assertAllEqual(before >= 0,
                                        np.ones_like(before, dtype=bool))
                after = v[found + 1:]
                self.assertAllClose(after, end_token * np.ones_like(after))
Beispiel #8
0
 def testBadParentValuesOnCPU(self):
     """The CPU kernel rejects a negative parent id with an op error."""
     # (batch_size = 1, max_time = 4, beams = 3)
     # Beam 1 at time 1 points at parent -1.
     step = _transpose_batch_time([[[1, 2, 3], [4, 5, 6], [7, 8, 9],
                                    [-1, -1, -1]]])
     parents = _transpose_batch_time([[[0, 0, 0], [0, -1, 1], [2, 1, 2],
                                       [-1, -1, -1]]])
     with ops.device("/cpu:0"):
         gathered = beam_search_ops.gather_tree(
             step_ids=step,
             parent_ids=parents,
             max_sequence_lengths=[3],
             end_token=10)
     with self.test_session():
         with self.assertRaisesOpError(
                 r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"):
             gathered.eval()
 def testBadParentValuesOnCPU(self):
   """Building and evaluating gather_tree with a -1 parent raises on CPU."""
   # (batch_size = 1, max_time = 4, beams = 3)
   # The parent of beam 1 at time 1 is -1.
   end_token = 10
   lengths = [3]
   step = _transpose_batch_time(
       [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
   parents = _transpose_batch_time(
       [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
   message = r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"
   with ops.device("/cpu:0"), self.assertRaisesOpError(message):
     self.evaluate(
         beam_search_ops.gather_tree(
             step_ids=step,
             parent_ids=parents,
             max_sequence_lengths=lengths,
             end_token=end_token))
 def testBadParentValuesOnGPU(self):
   """On GPU a negative parent id does not raise; -1 appears in the output.

   Returns early when no GPU is available.
   """
   if not test.is_gpu_available():
     return
   # (max_time = 4, batch_size = 1, beams = 3)
   # bad parent in beam 1 time 1; appears as a negative index at time 0
   step_ids = _transpose_batch_time(
       [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
   parent_ids = _transpose_batch_time(
       [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
   # The legacy `sequence_length` argument is shaped [batch_size, beam_width],
   # consistent with the other legacy-API gather_tree tests.
   sequence_length = [[3, 3, 3]]
   expected_result = _transpose_batch_time(
       [[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
   with ops.device("/gpu:0"):
     beams = beam_search_ops.gather_tree(
         step_ids=step_ids, parent_ids=parent_ids,
         sequence_length=sequence_length)
   with self.test_session(use_gpu=True):
     self.assertAllEqual(expected_result, beams.eval())
Beispiel #11
0
 def finalize(self, outputs, final_state, sequence_lengths):
   """Backtrace the beams and package them as the final decoder output.

   Args:
     outputs: An instance of BeamSearchDecoderOutput.
     final_state: An instance of BeamSearchDecoderState. Passed through to the
       output.
     sequence_lengths: An `int32` tensor shaped `[batch_size, beam_width]`.
       The sequence lengths determined for each beam during decode.

   Returns:
     outputs: An instance of FinalBeamSearchDecoderOutput where the
       predicted_ids are the result of calling gather_tree.
     final_state: The same input instance of BeamSearchDecoderState.
   """
   final_outputs = FinalBeamSearchDecoderOutput(
       beam_search_decoder_output=outputs,
       predicted_ids=beam_search_ops.gather_tree(
           outputs.predicted_ids,
           outputs.parent_ids,
           sequence_length=sequence_lengths))
   return final_outputs, final_state
 def testBadParentValuesOnGPU(self):
     """A negative parent id on the GPU kernel surfaces as -1 in the output.

     Returns early when no GPU is available.
     """
     if not test.is_gpu_available():
         return
     # (max_time = 4, batch_size = 1, beams = 3)
     # Beam 1 at time 1 has parent -1, which shows up as -1 at time 0.
     lengths = [[3, 3, 3]]
     step = _transpose_batch_time([[[1, 2, 3], [4, 5, 6], [7, 8, 9],
                                    [-1, -1, -1]]])
     parents = _transpose_batch_time([[[0, 0, 0], [0, -1, 1], [2, 1, 2],
                                       [-1, -1, -1]]])
     want = _transpose_batch_time([[[2, -1, 2], [6, 5, 6], [7, 8, 9],
                                    [-1, -1, -1]]])
     with ops.device("/gpu:0"):
         gathered = beam_search_ops.gather_tree(
             step_ids=step, parent_ids=parents, sequence_length=lengths)
     with self.test_session(use_gpu=True):
         self.assertAllEqual(want, gathered.eval())
Beispiel #13
0
    def finalize(self, outputs, final_state, sequence_lengths):
        """Finalize, rescore and reorder the predicted_ids.

        Backtraces the beams with ``gather_tree``, rescores each beam as a
        length-penalized log probability plus weighted coverage terms, and
        sorts the beams of every batch entry by that score (best first).

        Args:
          outputs: An instance of CoverageDecoderOutput.
          final_state: An instance of CoverageDecoderState. Passed through to the
            output.
          sequence_lengths: An `int32` tensor shaped `[batch_size, beam_width]`.
            The sequence lengths determined for each beam during decode.

        Returns:
          outputs: An instance of FinalCoverageDecoderOutput where the
            predicted_ids are the result of calling gather_tree, with the
            beams of each batch entry sorted by the rescored value.
          final_state: The same input instance of CoverageDecoderState.
        """
        # NOTE(review): assumed time-major [max_time, batch_size, beam_width],
        # as produced by gather_tree elsewhere in this file -- confirm.
        predicted_ids = beam_search_ops.gather_tree(
            outputs.predicted_ids, outputs.parent_ids,
            sequence_length=sequence_lengths)

        # NOTE(review): assumes coverages is [batch_size, beam_width, src_len].
        coverages = final_state.coverages
        coverage_scores = tf.reduce_sum(
            tf.log(tf.maximum(coverages, self._threshold)), axis=-1)
        coverage_penalty = tf.reduce_sum(
            tf.log(tf.minimum(coverages, 1.)), axis=-1)
        if self._length_penalty_weight != 0.0:
            # Normalize both coverage terms by the size of the last coverage
            # axis. (Dividing by _length_penalty(sequence_lengths, ...) is an
            # alternative that was tried and left disabled.)
            coverage_len = tf.to_float(tf.shape(coverages)[-1])
            coverage_scores = coverage_scores / coverage_len
            coverage_penalty = coverage_penalty / coverage_len
        scores = final_state.log_probs / _length_penalty(
            sequence_lengths=final_state.lengths,
            penalty_factor=self._length_penalty_weight)
        # Convert the weight w into w / (1 - w). w == 1 makes this inf (or a
        # ZeroDivisionError if w is a Python float); the asserts below guard
        # against inf/nan when w is a tensor.
        coverage_score_weight = self._coverage_score_weight / (
            1. - self._coverage_score_weight)
        with tf.control_dependencies([
            tf.Assert(tf.equal(tf.reduce_sum(tf.to_int32(tf.is_inf(coverage_score_weight))), 0), [coverage_score_weight]),
            tf.Assert(tf.equal(tf.reduce_sum(tf.to_int32(tf.is_nan(coverage_score_weight))), 0), [coverage_score_weight])
        ]):
            scores = scores + \
                     coverage_score_weight * coverage_scores + \
                     self._coverage_penalty_weight * coverage_penalty

        # Per batch entry, beam indices sorted by descending score:
        # [batch_size, beam_width].
        _, order = tf.nn.top_k(scores, self._beam_width)
        # Reorder the beams independently for each batch entry. A flat
        # tf.gather along the beam axis would mix indices from different batch
        # entries whenever batch_size > 1.
        batch_size = tf.shape(order)[0]
        batch_index = tf.tile(
            tf.expand_dims(tf.range(batch_size), 1), [1, self._beam_width])
        gather_indices = tf.stack([batch_index, order], axis=-1)
        # gather_nd needs (batch, beam) leading: [batch, beam, time].
        ids_batch_major = tf.transpose(predicted_ids, [1, 2, 0])
        predicted_ids = tf.transpose(
            tf.gather_nd(ids_batch_major, gather_indices), [2, 0, 1])

        outputs = FinalCoverageDecoderOutput(
            beam_search_decoder_output=outputs, predicted_ids=predicted_ids)
        return outputs, final_state
  def finalize(self, outputs, final_state, sequence_lengths):
    """Assemble the final beam-search output from the per-step outputs.

    Args:
      outputs: An instance of BeamSearchDecoderOutput.
      final_state: An instance of BeamSearchDecoderState. Passed through to the
        output.
      sequence_lengths: An `int64` tensor shaped `[batch_size, beam_width]`.
        The sequence lengths determined for each beam during decode.

    Returns:
      outputs: An instance of FinalBeamSearchDecoderOutput whose predicted_ids
        are the full beams reconstructed by gather_tree.
      final_state: The same input instance of BeamSearchDecoderState.
    """
    step_ids, parents = outputs.predicted_ids, outputs.parent_ids
    full_beams = beam_search_ops.gather_tree(
        step_ids, parents, sequence_length=sequence_lengths)
    wrapped = FinalBeamSearchDecoderOutput(
        beam_search_decoder_output=outputs, predicted_ids=full_beams)
    return wrapped, final_state
  def testGatherTreeBatch(self):
    """Randomized gather_tree check over a batch with ragged max lengths.

    `max_sequence_lengths` contains zeros and values larger than `max_time`,
    so both extremes are exercised.
    """
    batch_size = 10
    beam_width = 15
    max_time = 8
    max_sequence_lengths = [0, 1, 2, 4, 7, 8, 9, 10, 11, 0]
    end_token = 5

    with self.test_session(use_gpu=True):
      step_ids = np.random.randint(
          0, high=end_token + 1, size=(max_time, batch_size, beam_width))
      # NOTE(review): `high` is exclusive, so parent index beam_width - 1 is
      # never drawn; confirm this is intended.
      parent_ids = np.random.randint(
          0, high=beam_width - 1, size=(max_time, batch_size, beam_width))

      beams = beam_search_ops.gather_tree(
          step_ids=step_ids.astype(np.int32),
          parent_ids=parent_ids.astype(np.int32),
          max_sequence_lengths=max_sequence_lengths,
          end_token=end_token)

      self.assertEqual((max_time, batch_size, beam_width), beams.shape)
      beams_value = beams.eval()
      for b in range(batch_size):
        # Past max_sequence_lengths[b], we emit all end tokens.
        b_value = beams_value[max_sequence_lengths[b]:, b, :]
        self.assertAllClose(b_value, end_token * np.ones_like(b_value))
      for batch, beam in itertools.product(
          range(batch_size), range(beam_width)):
        v = np.squeeze(beams_value[:, batch, beam])
        if end_token not in v:
          continue
        # A beam that terminates must not contain invalid (-1) ids.
        found_bad = np.where(v == -1)[0]
        self.assertEqual(0, len(found_bad))
        found = np.where(v == end_token)[0][0]  # First occurrence of end_token.
        # Everything before the first end_token must be a valid id, and
        # everything after it must be end_token as well.
        if found > 0:
          before = v[:found]
          self.assertAllEqual(before >= 0, np.ones_like(before, dtype=bool))
        after = v[found + 1:]
        self.assertAllClose(after, end_token * np.ones_like(after))
Beispiel #16
0
 def testBadParentValuesOnGPU(self):
     """GPU kernel: a -1 parent id shows up as -1 instead of raising.

     gather_tree is not registered for SYCL devices, so the test returns
     early unless a CUDA GPU is available.
     """
     if not test.is_gpu_available(cuda_only=True):
         return
     # (max_time = 4, batch_size = 1, beams = 3)
     # Beam 1 at time 1 has parent -1, which shows up as -1 at time 0.
     end_token = 10
     lengths = [3]
     step = _transpose_batch_time([[[1, 2, 3], [4, 5, 6], [7, 8, 9],
                                    [-1, -1, -1]]])
     parents = _transpose_batch_time([[[0, 0, 0], [0, -1, 1], [2, 1, 2],
                                       [-1, -1, -1]]])
     want = _transpose_batch_time([[[2, -1, 2], [6, 5, 6], [7, 8, 9],
                                    [-1, -1, -1]]])
     with ops.device("/device:GPU:0"):
         gathered = beam_search_ops.gather_tree(
             step_ids=step,
             parent_ids=parents,
             max_sequence_lengths=lengths,
             end_token=end_token)
     with self.test_session(use_gpu=True):
         self.assertAllEqual(want, gathered.eval())
 def testBadParentValuesOnGPU(self):
   """On CUDA, a -1 parent id yields -1 in the gathered beams (no error)."""
   # gather_tree is not registered for SYCL devices, so only run on CUDA.
   if not test.is_gpu_available(cuda_only=True):
     return
   # (max_time = 4, batch_size = 1, beams = 3)
   # Beam 1 at time 1 has parent -1; it shows up as -1 at time 0.
   step = _transpose_batch_time(
       [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
   parents = _transpose_batch_time(
       [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
   want = _transpose_batch_time(
       [[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
   with ops.device("/device:GPU:0"):
     gathered = beam_search_ops.gather_tree(
         step_ids=step,
         parent_ids=parents,
         max_sequence_lengths=[3],
         end_token=10)
   with self.test_session(use_gpu=True):
     self.assertAllEqual(want, gathered.eval())
    def finalize(self, outputs, final_state, sequence_lengths):
        """Finalize and return the predicted_ids.

        Args:
          outputs: An instance of BeamSearchDecoderOutput; its `scores` are
            copied into the returned output.
          final_state: An instance of BeamSearchDecoderState. Passed through to
            the output, with its `cell_state` beams re-sorted when
            `self._reorder_tensor_arrays` is set.
          sequence_lengths: An `int64` tensor shaped `[batch_size, beam_width]`.
            The sequence lengths determined for each beam during decode.
            **NOTE** These are ignored; the updated sequence lengths are stored
            in `final_state.lengths`.

        Returns:
          outputs: An instance of `FinalConstrainedBeamSearchDecoderOutput`
            whose `predicted_ids` are the result of calling gather_tree and
            whose `scores` are taken from `outputs.scores`.
          final_state: The input `final_state`, or, when
            `self._reorder_tensor_arrays` is set, a copy of it (via `_replace`)
            whose `cell_state` has been re-sorted by
            `self._maybe_sort_array_beams`.
        """
        del sequence_lengths
        # Get max_sequence_length across all beams for each batch.
        max_sequence_lengths = math_ops.to_int32(
            math_ops.reduce_max(final_state.lengths, axis=1))
        predicted_ids = beam_search_ops.gather_tree(
            outputs.predicted_ids,
            outputs.parent_ids,
            max_sequence_lengths=max_sequence_lengths,
            end_token=self._end_token)
        if self._reorder_tensor_arrays:
            # Bind the pre-update lengths explicitly so the lambda does not
            # depend on when `final_state` is rebound.
            lengths = final_state.lengths
            final_state = final_state._replace(cell_state=nest.map_structure(
                lambda t: self._maybe_sort_array_beams(t, outputs.parent_ids,
                                                       lengths),
                final_state.cell_state))
        outputs = FinalConstrainedBeamSearchDecoderOutput(
            beam_search_decoder_output=outputs,
            predicted_ids=predicted_ids,
            scores=outputs.scores)
        return outputs, final_state
  def testGatherTreeBatch(self):
    """Legacy-API gather_tree over four batch entries of lengths 0..3."""
    # Batch entry i has every beam of length i; the shape is
    # [batch_size, beam_width] = [4, 5].
    sequence_length = [[length] * 5 for length in range(4)]

    with self.test_session(use_gpu=True):
      # (max_time = 4, batch_size = 4, beam_width = 5)
      steps = _transpose_batch_time(
          [[[3, 4, 0, 4, 0],
            [4, 2, 0, 3, 1],
            [1, 1, 3, 2, 2],
            [3, 1, 2, 3, 4]],
           [[3, 4, 0, 4, 0],
            [4, 2, 0, 3, 1],
            [1, 1, 3, 2, 2],
            [3, 1, 2, 3, 4]],
           [[1, 2, 3, 4, 2],
            [2, 1, 1, 3, 2],
            [3, 0, 1, 0, 0],
            [3, 4, 0, 2, 4]],
           [[0, 2, 2, 3, 1],
            [3, 2, 2, 2, 3],
            [3, 4, 3, 0, 3],
            [1, 2, 2, 2, 4]]])
      parents = _transpose_batch_time(
          [[[4, 2, 4, 3, 4],
            [3, 4, 0, 2, 0],
            [3, 1, 3, 2, 2],
            [0, 2, 1, 4, 2]],
           [[4, 2, 4, 3, 4],
            [3, 4, 0, 2, 0],
            [3, 1, 3, 2, 2],
            [0, 2, 1, 4, 2]],
           [[3, 0, 0, 4, 0],
            [1, 2, 4, 2, 2],
            [4, 4, 0, 3, 0],
            [2, 4, 4, 3, 0]],
           [[3, 1, 4, 1, 3],
            [3, 2, 4, 0, 4],
            [1, 0, 1, 4, 2],
            [0, 3, 2, 0, 1]]])
      # Steps at or beyond each entry's length are expected to be -1.
      want = _transpose_batch_time(
          [[[-1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1]],
           [[3, 4, 0, 4, 0],
            [-1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1]],
           [[2, 3, 2, 3, 3],
            [2, 1, 1, 3, 2],
            [-1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1]],
           [[2, 3, 2, 1, 1],
            [2, 3, 2, 3, 2],
            [3, 4, 3, 0, 3],
            [-1, -1, -1, -1, -1]]])

      gathered = beam_search_ops.gather_tree(
          step_ids=steps,
          parent_ids=parents,
          sequence_length=sequence_length)
      self.assertAllEqual(want, gathered.eval())
def gather_tree_from_array(t, parent_ids, sequence_length):
    """Calculates the full beams for `TensorArray`s.

    Args:
      t: A stacked `TensorArray` of size `max_time` that contains `Tensor`s of
        shape `[batch_size, beam_width, s]` or `[batch_size * beam_width, s]`
        where `s` is the depth shape.
      parent_ids: The parent ids of shape `[max_time, batch_size, beam_width]`.
      sequence_length: The sequence length of shape `[batch_size, beam_width]`.

    Returns:
      A `Tensor` which is a stacked `TensorArray` of the same size and type as
      `t` and where beams are sorted in each `Tensor` according to `parent_ids`.
    """
    # Use static dimensions when known, else fall back to the dynamic shape.
    max_time = parent_ids.shape[0].value or array_ops.shape(parent_ids)[0]
    batch_size = parent_ids.shape[1].value or array_ops.shape(parent_ids)[1]
    beam_width = parent_ids.shape[2].value or array_ops.shape(parent_ids)[2]

    # Generate beam ids that will be reordered by gather_tree.
    beam_ids = array_ops.expand_dims(
        array_ops.expand_dims(math_ops.range(beam_width), 0), 0)
    beam_ids = array_ops.tile(beam_ids, [max_time, batch_size, 1])

    # 1 within each beam's length, 0 beyond it; transposed to time-major.
    mask = array_ops.sequence_mask(sequence_length,
                                   maxlen=max_time,
                                   dtype=dtypes.int32)
    mask = array_ops.transpose(mask, perm=[2, 0, 1])

    # Use beam_width + 1 to mark the end of beam.
    masked_beam_ids = (beam_ids * mask) + (1 - mask) * (beam_width + 1)

    max_sequence_lengths = math_ops.to_int32(
        math_ops.reduce_max(sequence_length, axis=1))
    sorted_beam_ids = beam_search_ops.gather_tree(
        step_ids=masked_beam_ids,
        parent_ids=parent_ids,
        max_sequence_lengths=max_sequence_lengths,
        end_token=beam_width + 1)

    # For out of range steps, simply copy the same beam.
    sorted_beam_ids = array_ops.where(math_ops.cast(mask, dtypes.bool),
                                      x=sorted_beam_ids,
                                      y=beam_ids)

    # Generate [time, batch, beam] indices for gather_nd.
    time_ind = array_ops.tile(
        array_ops.reshape(math_ops.range(max_time), [-1, 1, 1]),
        [1, batch_size, beam_width])
    batch_ind = array_ops.tile(
        array_ops.reshape(math_ops.range(batch_size), [-1, 1, 1]),
        [1, max_time, beam_width])
    batch_ind = array_ops.transpose(batch_ind, perm=[1, 0, 2])
    indices = array_ops.stack([time_ind, batch_ind, sorted_beam_ids], -1)

    # Gather from a tensor with collapsed additional dimensions, then restore
    # the caller's original shape.
    gather_from = t
    final_shape = array_ops.shape(gather_from)
    gather_from = array_ops.reshape(gather_from,
                                    [max_time, batch_size, beam_width, -1])
    ordered = array_ops.gather_nd(gather_from, indices)
    ordered = array_ops.reshape(ordered, final_shape)

    return ordered
Beispiel #21
0
        def body(time, outputs_ta, state, inputs, finished, sequence_lengths,
                 hypotheses, input_ids, scores, base_index):
            """Internal while_loop body.

            Args:
              time: scalar int32 tensor.
              outputs_ta: structure of TensorArray.
              state: (structure of) state tensors and TensorArrays.
              inputs: (structure of) input tensors.
              finished: bool tensor (keeping track of what's finished).
              sequence_lengths: int32 tensor (keeping track of time of finish).
              hypotheses: structure of TensorArray (stores hypotheses so far).
              input_ids: structure of TensorArray.
              scores: structure of TensorArray.
              base_index:  int32 tensor (keeping track of size of the above 3 TensorArrays)

            Returns:
              `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
                next_sequence_lengths, new_hypotheses, new_input_ids, new_scores, new_base)`.
              ```
            """
            (next_outputs, decoder_state, next_inputs,
             decoder_finished) = decoder.step(time, inputs, state)
            if decoder.tracks_own_finished:
                next_finished = decoder_finished
            else:
                next_finished = math_ops.logical_or(decoder_finished, finished)
            next_sequence_lengths = array_ops.where(
                math_ops.logical_not(next_finished),
                array_ops.fill(array_ops.shape(sequence_lengths),
                               time + 1 + (not decoder._use_go_tokens)),
                sequence_lengths)

            nest.assert_same_structure(state, decoder_state)
            nest.assert_same_structure(outputs_ta, next_outputs)
            nest.assert_same_structure(inputs, next_inputs)

            # Zero out output values past finish
            if impute_finished:
                emit = nest.map_structure(
                    lambda out, zero: array_ops.where(
                        next_finished, zero, out), next_outputs, zero_outputs)
            else:
                emit = next_outputs

            # Copy through states past finish
            def _maybe_copy_state(new, cur):
                # TensorArrays and scalar states get passed through.
                if isinstance(cur, tensor_array_ops.TensorArray):
                    pass_through = True
                else:
                    new.set_shape(cur.shape)
                    pass_through = (new.shape.ndims == 0)
                return new if pass_through else array_ops.where(
                    finished, cur, new)

            outputs_ta = nest.map_structure(
                lambda ta, out: ta.write(time +
                                         (not decoder._use_go_tokens), out),
                outputs_ta, emit)

            # Extract hypotheses, scores for reference
            outputs_so_far = nest.map_structure(lambda ta: ta.stack(),
                                                outputs_ta)
            parent_ids = outputs_so_far.parent_ids
            hypotheses_so_far = outputs_so_far.predicted_ids
            forward_scores = next_outputs.scores

            sl = tf.ones([decoder.batch_size], tf.int32) * \
                (time + 1 + (not decoder._use_go_tokens))
            hypotheses_so_far = beam_search_ops.gather_tree(
                hypotheses_so_far,
                parent_ids,
                max_sequence_lengths=sl,
                end_token=decoder._end_token)

            # Add repetition penalty
            if repetition != 0:

                def unique_counter(x):
                    return tf.cast(tf.size(tf.unique(x)[0]), tf.float32)

                def wrapper_rep_penalty_function(sentence):
                    def first_time_penalty():
                        total_unique_words = unique_counter(sentence)
                        return tf.log(total_unique_words)

                    def generic_penalty():
                        sentence_length = tf.shape(sentence)[0]
                        total_unique_words = unique_counter(sentence)
                        unique_words_before = unique_counter(
                            sentence[:sentence_length - 1])
                        return (tf.log(total_unique_words) -
                                tf.log(unique_words_before))

                    return repetition * tf.cond(math_ops.equal(
                        tf.shape(sentence)[0], 1),
                                                true_fn=first_time_penalty,
                                                false_fn=generic_penalty)

                # Use reshaped hypotheses; calculate penalty per beam per batch
                transposed_hypotheses = tf.transpose(hypotheses_so_far,
                                                     [2, 1, 0])
                repetition_penalty = tf.map_fn(lambda x: tf.map_fn(
                    wrapper_rep_penalty_function, x, dtype=tf.float32),
                                               transposed_hypotheses,
                                               dtype=tf.float32)
                repetition_penalty = tf.transpose(repetition_penalty, [1, 0])
                forward_scores += repetition_penalty

                # Add repetition penalty hypothesis scores
                decoder_state = BeamSearchDecoderState(
                    cell_state=decoder_state.cell_state,
                    log_probs=decoder_state.log_probs + repetition_penalty,
                    finished=decoder_state.finished,
                    lengths=decoder_state.lengths,
                    accumulated_attention_probs=decoder_state.
                    accumulated_attention_probs)

            if impute_finished:
                next_state = nest.map_structure(_maybe_copy_state,
                                                decoder_state, state)
            else:
                next_state = decoder_state

            finished_this_beam = math_ops.logical_and(
                math_ops.logical_not(finished), decoder_finished)
            # Make sure number of outputs is never zero
            finished_this_beam = tf.cond(
                math_ops.logical_and(
                    math_ops.logical_and(
                        tf.equal(hypotheses.size(), 0),
                        tf.equal(tf.size(tf.where(finished_this_beam)), 0)),
                    tf.equal(time, maximum_iterations - 1)),
                true_fn=lambda: tf.cast(tf.ones_like(finished_this_beam),
                                        dtype=tf.bool),
                false_fn=lambda: finished_this_beam)

            def prepare_hypotheses_for_ta():
                finished_beams = tf.where(finished_this_beam)

                hypotheses_for_ta = tf.boolean_mask(
                    tf.transpose(hypotheses_so_far, [1, 2, 0]),
                    finished_this_beam)

                # Pad hypotheses with EOS token
                hypotheses_for_ta = tf.pad(hypotheses_for_ta,
                                           [[0, 0],
                                            [
                                                0, maximum_iterations +
                                                (not decoder._use_go_tokens) -
                                                tf.shape(hypotheses_for_ta)[-1]
                                            ]],
                                           constant_values=decoder._end_token)

                input_query_id = tf.expand_dims(finished_beams[:, 0], 1)
                scores_forward = tf.expand_dims(
                    tf.boolean_mask(forward_scores, finished_this_beam), 1)

                def inner_cond(index, base, hyp_ta, ind_ta, score_ta, hypos,
                               input_ids, forward_scores):
                    # Populate TA with given elements AND do not consider blank responses
                    return math_ops.logical_and(
                        math_ops.less(index,
                                      tf.shape(hypos)[0]),
                        math_ops.greater(time,
                                         0 - (not decoder._use_go_tokens)))

                def inner_body(index, base, hyp_ta, ind_ta, score_ta, hypos,
                               input_ids, forward_scores):
                    new_hyp_ta = nest.map_structure(
                        lambda ta, out: ta.write(base, out), hyp_ta,
                        hypos[index])
                    new_ind_ta = nest.map_structure(
                        lambda ta, out: ta.write(base, out), ind_ta,
                        input_ids[index])

                    # Remove repetition penalty from stored score, use as a feature for later re-reranking
                    forward_scores_store = forward_scores[index]
                    if repetition != 0:
                        forward_scores_store -= repetition * \
                            unique_counter(hypos[index])

                    # Normalize finished scores by their length
                    new_scores_ta = nest.map_structure(
                        lambda ta, out: ta.write(base, out), score_ta,
                        forward_scores_store / tf.cast(
                            tf.count_nonzero(hypos[index] -
                                             decoder._end_token), tf.float32))

                    return (index + 1, base + 1, new_hyp_ta, new_ind_ta,
                            new_scores_ta, hypos, input_ids, forward_scores)

                # Add multiple hypotheses (and related information) to TensorArray using a while_loop
                inner_result = tf.while_loop(
                    inner_cond,
                    inner_body,
                    loop_vars=(tf.constant(0), base_index, hypotheses,
                               input_ids, scores, hypotheses_for_ta,
                               input_query_id, scores_forward),
                    parallel_iterations=parallel_iterations,
                    swap_memory=swap_memory)
                return inner_result[1], inner_result[2], inner_result[
                    3], inner_result[4]

            # In case finished is not True for any beams
            new_base, new_hypotheses, new_input_ids, new_scores = tf.cond(
                math_ops.greater(tf.count_nonzero(finished_this_beam), 0),
                true_fn=prepare_hypotheses_for_ta,
                false_fn=lambda: (base_index, hypotheses, input_ids, scores))

            return (time + 1, outputs_ta, next_state, next_inputs,
                    next_finished, next_sequence_lengths, new_hypotheses,
                    new_input_ids, new_scores, new_base)
Beispiel #22
0
    def finalize(self, outputs, final_state, sequence_lengths):
        """Finalize beam search, optionally re-decoding with forbidden tokens.

        Backtracks the beam-search output with `gather_tree` to obtain the
        final predicted ids. After decrementing
        `self._decoding_iterations_remaining`, if the counter is still at
        least 1, this method does three things:

        1. It appends the ids decoded in this pass to
           `self._forbidden_tokens`.
        2. It runs `dynamic_decode` again on this same decoder.
        3. It concatenates the new outputs with the current ones along
           axis 2.

        Args:
          outputs: An instance of BeamSearchDecoderOutput.
          final_state: An instance of BeamSearchDecoderState. Passed through to
            the output. Its `cell_state` is reordered when
            `self._reorder_tensor_arrays` is set.
          sequence_lengths: An `int64` tensor shaped `[batch_size, beam_width]`.
            The sequence lengths determined for each beam during decode.
            **NOTE** These are ignored; the updated sequence lengths are stored
            in `final_state.lengths`.

        Returns:
          outputs: An instance of `FinalBeamSearchDecoderOutput`. Its
            predicted_ids are the result of calling gather_tree, concatenated
            along axis 2 with the ids of any additional decoding pass.
          final_state: The input `BeamSearchDecoderState`, possibly with its
            `cell_state` reordered. The state returned by an additional
            decoding pass is discarded.

        Side effects:
          Decrements `self._decoding_iterations_remaining` and, when another
          pass is run, extends or sets `self._forbidden_tokens`.
        """
        del sequence_lengths

        # Decrease counter. This happens before any re-decode below, so a
        # nested finalize call (presumably issued by dynamic_decode on `self`)
        # sees the smaller value, and the recursion stops once it drops below 1.
        self._decoding_iterations_remaining -= 1  # Decrease counter.

        # Get max_sequence_length across all beams for each batch.
        max_sequence_lengths = math_ops.to_int32(
            math_ops.reduce_max(final_state.lengths, axis=1))
        # Backtrack through parent_ids to recover each beam's full id sequence.
        predicted_ids = beam_search_ops.gather_tree(
            outputs.predicted_ids,
            outputs.parent_ids,
            max_sequence_lengths=max_sequence_lengths,
            end_token=self._end_token)
        if self._reorder_tensor_arrays:
            # Reorder every tensor in cell_state using the beam ancestry in
            # parent_ids (via _maybe_sort_array_beams, defined elsewhere).
            # pylint: disable=g-long-lambda
            # pylint: disable=line-too-long
            final_state = final_state._replace(cell_state=nest.map_structure(
                lambda t: self._maybe_sort_array_beams(t, outputs.parent_ids,
                                                       final_state.lengths),
                final_state.cell_state))
            # pylint: enable=g-long-lambda
            # pylint: enable=line-too-long
        if self._decoding_iterations_remaining >= 1:
            # Another decoding pass is requested. Every id produced in this
            # pass is recorded as forbidden first. How _forbidden_tokens is
            # consumed is not visible here; presumably it is masked out in
            # the step logic (TODO confirm).

            # Transpose to [batch_size, time, beam_width]
            new_forbidden_tokens = tf.transpose(predicted_ids, perm=[1, 0, 2])
            # Reshape to [batch_size, time * beam_width]
            new_forbidden_tokens = tf.reshape(
                new_forbidden_tokens,
                shape=[tf.shape(new_forbidden_tokens)[0], -1])
            # Accumulate across passes: append to the existing forbidden ids
            # along axis 1, or start the list on the first pass.
            if self._forbidden_tokens is not None:
                self._forbidden_tokens = tf.concat(
                    [self._forbidden_tokens, new_forbidden_tokens], axis=1)
            else:
                self._forbidden_tokens = new_forbidden_tokens

            # Re-run decoding with this same decoder instance. Only the outputs
            # are kept; the new final state and sequence lengths are dropped.
            new_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                self,
                maximum_iterations=self._maximum_iterations,
                output_time_major=True,
                swap_memory=True,
                scope=self._decoder_scope)

            # Merge the raw per-step outputs of both passes along axis 2
            # (the beam axis).
            # NOTE(review): tf.concat on axis 2 requires axes 0 (time) and 1
            # (batch) to match. The two passes may have different time lengths,
            # which is why predicted_ids is padded below, but scores and
            # parent_ids are not padded here. Also, the second pass's
            # parent_ids index into its own beams and are not offset by
            # beam_width. TODO confirm that downstream code does not backtrack
            # through this merged parent_ids.
            all_scores = tf.concat([
                outputs.scores, new_outputs.beam_search_decoder_output.scores
            ],
                                   axis=2)
            all_predicted_ids = tf.concat([
                outputs.predicted_ids,
                new_outputs.beam_search_decoder_output.predicted_ids
            ],
                                          axis=2)
            all_parent_ids = tf.concat([
                outputs.parent_ids,
                new_outputs.beam_search_decoder_output.parent_ids
            ],
                                       axis=2)
            outputs = beam_search_decoder.BeamSearchDecoderOutput(
                scores=all_scores,
                predicted_ids=all_predicted_ids,
                parent_ids=all_parent_ids)

            # Append eos token ids in case predicted_ids is shorter than new
            # predicted_ids, and vice-versa.
            # `pad` is defined elsewhere. Presumably it pads axis 0 (time) of
            # `x` up to `max_size` with `value` and never truncates (TODO
            # confirm). The second call measures the already-padded
            # predicted_ids, so both tensors end up the same length.
            predicted_ids = pad(x=predicted_ids,
                                max_size=tf.shape(
                                    new_outputs.predicted_ids)[0],
                                value=self._end_token)
            new_predicted_ids = pad(x=new_outputs.predicted_ids,
                                    max_size=tf.shape(predicted_ids)[0],
                                    value=self._end_token)
            # Beam axis now holds this pass's beams followed by those of the
            # later pass(es).
            predicted_ids = tf.concat([predicted_ids, new_predicted_ids],
                                      axis=2)

        outputs = beam_search_decoder.FinalBeamSearchDecoderOutput(
            beam_search_decoder_output=outputs, predicted_ids=predicted_ids)

        return outputs, final_state