示例#1
0
    def _runBeamSearchOpHelper(self,
                               b_size,
                               num_beams,
                               seq_len,
                               init_best_score,
                               probs,
                               init_atten_probs,
                               atten_probs,
                               beam_size=3.0,
                               ensure_full_beam=False,
                               force_eos_in_last_step=False,
                               local_eos_threshold=-100.0):
        """Runs `ops.beam_search_step` once per entry of `probs` and evaluates.

        All beam search state tensors start as zeros (`done_hyps` as the string
        form of integer zeros) and are threaded from one step to the next.

        Args:
          b_size: Size of the hypothesis dimension of the state tensors.
          num_beams: Length of the `best_scores` vector.
          seq_len: Size of the leading (time) dimension of the state tensors.
          init_best_score: Value added to the zero-initialised `best_scores`.
          probs: Sequence of per-step inputs; element `i` is passed as the
            first argument of step `i`.
          init_atten_probs: Passed unchanged as the second argument of every
            step.
          atten_probs: Initial attention-probability state; replaced by the
            op's output after each step.
          beam_size: Forwarded to the op as `beam_size`.
          ensure_full_beam: Forwarded to the op as `ensure_full_beam`.
          force_eos_in_last_step: Forwarded to the op as
            `force_eos_in_last_step`.
          local_eos_threshold: Forwarded to the op as `local_eos_threshold`.

        Returns:
          A 10-tuple of evaluated values `(best_scores, cumulative_scores,
          scores, hyps, prev_hyps, done_hyps, atten_probs, done, scores,
          atten_probs)`. `scores` and `atten_probs` intentionally appear twice
          so that the tuple layout existing callers index into is unchanged.

        Raises:
          ValueError: If `probs` is empty, so no step was run.
        """
        eos_id = 2
        # Floor division keeps this an int: `num_hyps_per_beam` is forwarded as
        # an op attribute and true division would hand the op a float.
        num_hyps_per_beam = b_size // num_beams

        best_scores = tf.zeros([num_beams])
        cumulative_scores = tf.zeros([b_size])
        scores = tf.zeros([seq_len, b_size])
        hyps = tf.zeros([seq_len, b_size], dtype=tf.int32)
        prev_hyps = tf.zeros([seq_len, b_size], dtype=tf.int32)
        done_hyps = tf.as_string(tf.zeros([seq_len, b_size], dtype=tf.int32))
        best_scores += init_best_score

        # Sentinel so that an empty `probs` is reported explicitly instead of
        # surfacing later as an unbound local.
        done = None
        for i, prob in enumerate(probs):
            # The empty list fills the positional slot right after
            # `atten_probs` (presumably `is_last_chunk` -- confirm against the
            # op signature).
            (best_scores, cumulative_scores, scores, hyps, prev_hyps,
             done_hyps, atten_probs, done) = ops.beam_search_step(
                 prob,
                 init_atten_probs,
                 best_scores,
                 cumulative_scores,
                 scores,
                 hyps,
                 prev_hyps,
                 done_hyps,
                 atten_probs, [],
                 i,
                 eos_id=eos_id,
                 beam_size=beam_size,
                 ensure_full_beam=ensure_full_beam,
                 num_hyps_per_beam=num_hyps_per_beam,
                 valid_eos_max_logit_delta=0.1,
                 force_eos_in_last_step=force_eos_in_last_step,
                 local_eos_threshold=local_eos_threshold)

        if done is None:
            raise ValueError('probs must contain at least one step.')

        with self.session(use_gpu=False):
            (best_scores, cumulative_scores, scores, hyps, prev_hyps,
             done_hyps, atten_probs, done, scores,
             atten_probs) = self.evaluate([
                 best_scores, cumulative_scores, scores, hyps, prev_hyps,
                 done_hyps, atten_probs, done, scores, atten_probs
             ])

        return (best_scores, cumulative_scores, scores, hyps, prev_hyps,
                done_hyps, atten_probs, done, scores, atten_probs)
  def _BeamSearchStep(self, theta, encoder_outputs, cur_step, step_ids,
                      core_bs_states, other_states, num_hyps_per_beam,
                      pre_beam_search_step_callback,
                      post_beam_search_step_callback):
    """Advances all beam search hypotheses by a single decoding step.

    Notation:

      | num_beams = Number of source sequences to be decoded.
      | num_hyps_per_beam = Number of hyps to keep per source sequence.
      | num_hyps = num_beams * num_hyps_per_beam
      | src_seq_len = Number of time steps in the source sequence.
      | src_batch = Number of examples in the source sequence.
      | tgt_seq_len = Maximum allowed time steps in the target sequence.
      | tgt_batch = num_hyps_per_beam * src_batch

    The step runs the pre-step callback, feeds its results together with the
    core states to `ops.beam_search_step`, reorders the client-managed states
    so that each surviving hyp carries the state of the hyp it extends, and
    finally runs the post-step callback.

    Args:
      theta: A `.NestedMap` with the weights' values of the decoder layer and
        its children layers.
      encoder_outputs: A `.NestedMap` of encoder outputs, handed to both
        callbacks.
      cur_step: A scalar int tensor holding the 0-based current time step.
      step_ids: An int tensor of shape [num_hyps, 1] with the input ids of
        this search step.
      core_bs_states: A tuple of the core beam search states, maintained by
        this helper class.
      other_states: A `.NestedMap` of additional beam search states owned and
        updated by the client. Member tensors are expected to have
        rank >= 1, with t[i, ...] being the state of the i-th hyp at the
        start of this step.
      num_hyps_per_beam: Number of hyps kept per beam.
      pre_beam_search_step_callback: The `PreBeamSearchStepCallback`
        callback; see the class header comments.
      post_beam_search_step_callback: The `PostBeamSearchStepCallback`
        callback; see the class header comments.

    Returns:
      The inputs of the next beam search step as a tuple
      (next step, all_done, step_ids, core_bs_states, other_states).
    """
    params = self.params

    step_results, other_states = pre_beam_search_step_callback(
        theta, encoder_outputs, step_ids, other_states, num_hyps_per_beam)

    (cur_best_scores, cur_cumulative_scores, cur_scores, cur_hyps,
     cur_prev_hyps, cur_done_hyps, cur_atten_probs) = core_bs_states

    # Only models that use the eoc id provide a last-chunk indicator; all
    # others pass an empty list in that slot.
    if self._model_uses_eoc_id:
      is_last_chunk = step_results.is_last_chunk
    else:
      is_last_chunk = []

    (next_best_scores, next_cumulative_scores, next_scores, next_hyps,
     next_prev_hyps, next_done_hyps, next_atten_probs,
     all_done) = ops.beam_search_step(
         step_results.log_probs,
         step_results.atten_probs,
         cur_best_scores,
         cur_cumulative_scores,
         cur_scores,
         cur_hyps,
         cur_prev_hyps,
         cur_done_hyps,
         cur_atten_probs,
         is_last_chunk,
         cur_step,
         eoc_id=params.target_eoc_id,
         eos_id=params.target_eos_id,
         beam_size=params.beam_size,
         num_hyps_per_beam=num_hyps_per_beam,
         valid_eos_max_logit_delta=params.valid_eos_max_logit_delta,
         merge_paths=params.merge_paths,
         allow_empty_terminated_hyp=params.allow_empty_terminated_hyp,
         ensure_full_beam=params.ensure_full_beam,
         force_eos_in_last_step=params.force_eos_in_last_step,
         local_eos_threshold=params.local_eos_threshold)

    # Ids chosen at this step become the inputs of the next step; they keep
    # both the dynamic and the static shape of `step_ids`.
    next_step_ids = tf.reshape(next_hyps[cur_step, :], tf.shape(step_ids))
    next_step_ids.set_shape(step_ids.get_shape())

    # For every hyp, the index of the hyp it was extended from at this step.
    parent_row = tf.slice(next_prev_hyps, begin=[cur_step, 0], size=[1, -1])
    parent_ids = tf.reshape(parent_row, [-1])

    if params.batch_major_compute:
      # The key/value cache used for fast decoding (prefix_states in
      # other_states) lays out its num_hyps dimension as num_beams by
      # num_hyps_per_beam, whereas parent_ids assumes num_hyps_per_beam by
      # num_beams. Fixing the indices takes both a transpose of their
      # positions and a recomputation of their values.
      num_beams = tf.shape(cur_best_scores)[0]
      ids_by_beam = tf.transpose(
          tf.reshape(parent_ids, [num_hyps_per_beam, -1]))
      flat_ids = tf.reshape(ids_by_beam, [-1])
      beam_index = flat_ids % num_beams
      hyp_index = flat_ids // num_beams
      parent_ids_cache_order = beam_index * num_hyps_per_beam + hyp_index

    next_core_states = (next_best_scores, next_cumulative_scores, next_scores,
                        next_hyps, next_prev_hyps, next_done_hyps,
                        next_atten_probs)

    def _GatherByParent(state):
      """Returns `state` rearranged to follow the parent hyp ids."""
      # Non-tensors and tensors of rank 0 or unknown rank pass through as is.
      if not (isinstance(state, tf.Tensor) and bool(state.shape.ndims)):
        return state
      if state.shape.ndims > 2 and not params.batch_major_state:
        # The cache-order indices are only needed here: under batch major
        # compute the key/value caches are the states that are affected.
        if params.batch_major_compute:
          gather_ids = parent_ids_cache_order
        else:
          gather_ids = parent_ids
        reordered = tf.gather(state, gather_ids, axis=1)
      else:
        reordered = tf.gather(state, parent_ids)
      reordered.set_shape(state.get_shape())
      return reordered

    reordered_states = other_states.Transform(_GatherByParent)

    final_other_states = post_beam_search_step_callback(
        theta, encoder_outputs, next_step_ids, reordered_states)

    return (cur_step + 1, all_done, next_step_ids, next_core_states,
            final_other_states)
示例#3
0
    def testTopKTerminatedHypsOp(self):
        """Checks `top_k_terminated_hyps` / `unpack_hyp` after two search steps.

        Two consecutive `ops.beam_search_step` calls are run on seeded random
        inputs; the resulting done-hyps are then ranked with
        `ops.top_k_terminated_hyps` and unpacked with `ops.unpack_hyp`, and the
        outcome is compared against recorded expected values.
        """
        with self.session(use_gpu=False):
            b_size = 8
            num_beams = 2
            # Floor division keeps this an int: it is forwarded as an op
            # attribute and true division would hand the ops a float.
            num_hyps_per_beam = b_size // num_beams
            seq_len = 6
            scores = tf.random.uniform([b_size, 5], seed=12345)
            atten_probs = tf.random.uniform([b_size, 3], seed=12345)
            src_seq_lengths = [3, 3]
            best_scores = tf.zeros([num_beams])
            cumulative_scores = tf.zeros([b_size])
            in_scores = tf.zeros([seq_len, b_size])
            in_hyps = tf.zeros([seq_len, b_size], dtype=tf.int32)
            in_prev_hyps = tf.zeros([seq_len, b_size], dtype=tf.int32)
            in_done_hyps = tf.as_string(
                tf.zeros([seq_len, b_size], dtype=tf.int32))
            in_atten_probs = tf.zeros([seq_len, b_size, 3])

            # Settings shared by both search steps.
            step_kwargs = dict(
                eos_id=2,
                beam_size=3.0,
                num_hyps_per_beam=num_hyps_per_beam)

            # Step 0 starts from the all-zero states built above.
            (out_best_scores_0, out_cumulative_scores_0, out_scores_0,
             out_hyps_0, out_prev_hyps_0, out_done_hyps_0, out_atten_probs_0,
             _) = ops.beam_search_step(scores,
                                       atten_probs,
                                       best_scores,
                                       cumulative_scores,
                                       in_scores,
                                       in_hyps,
                                       in_prev_hyps,
                                       in_done_hyps,
                                       in_atten_probs, [],
                                       0,
                                       **step_kwargs)

            # Step 1 reuses the same inputs but continues from step 0's states.
            outputs = ops.beam_search_step(scores,
                                           atten_probs,
                                           out_best_scores_0,
                                           out_cumulative_scores_0,
                                           out_scores_0,
                                           out_hyps_0,
                                           out_prev_hyps_0,
                                           out_done_hyps_0,
                                           out_atten_probs_0, [],
                                           1,
                                           **step_kwargs)

            # Get the topk terminated hyps. Index 5 is the done-hyps output,
            # matching the position of `out_done_hyps_0` in the unpacking above.
            in_done_hyps = outputs[5]
            topk_hyps = ops.top_k_terminated_hyps(
                in_done_hyps,
                src_seq_lengths,
                k=2,
                num_hyps_per_beam=num_hyps_per_beam,
                length_normalization=0.2,
                coverage_penalty=0.2,
                target_seq_length_ratio=1.0)
            seq_ids, seq_lens, seq_scores = ops.unpack_hyp(
                tf.reshape(topk_hyps, [-1]), max_seq_length=5)

            k1, k2, k3, k4 = self.evaluate(
                [topk_hyps, seq_ids, seq_lens, seq_scores])
            print(np.array_repr(k1))
            # A framework assertion is not stripped under `python -O` and
            # reports both values on failure.
            self.assertEqual(k1.size, 4)

            expected_top1_for_beam_0 = """
      beam_id: 0
      ids: 1
      ids: 2
      scores: 0.86230338
      scores: 0.65504861
      atten_vecs {
        prob: 0.45372832
        prob: 0.86230338
        prob: 0.65504861
      }
      atten_vecs {
        prob: 0.45372832
        prob: 0.86230338
        prob: 0.65504861
      }
      normalized_score: 1.002714
      """
            expected_top2_for_beam_1 = """
      beam_id: 1
      ids: 3
      ids: 2
      scores: 0.38127339
      scores: 0.57700801
      atten_vecs {
        prob: 0.38612545
        prob: 0.42067075
        prob: 0.84442794
      }
      atten_vecs {
        prob: 0.18693292
        prob: 0.17821217
        prob: 0.66380036
      }
      normalized_score: 0.480028
      """
            self._SameHyp(expected_top1_for_beam_0, k1[0, 0])
            self._SameHyp(expected_top2_for_beam_1, k1[1, 1])

            self.assertAllClose(k2, [[1, 2, 0, 0, 0], [4, 2, 0, 0, 0],
                                     [4, 2, 0, 0, 0], [3, 2, 0, 0, 0]])
            self.assertAllClose(k3, [2, 2, 2, 2])
            self.assertAllClose(k4, [1.002714, 0.684296, 0.522484, 0.480028])
示例#4
0
    def _runBeamSearchOpHelper(self,
                               hyp_size,
                               num_beams,
                               seq_len,
                               init_best_score,
                               probs,
                               init_atten_probs,
                               atten_probs,
                               beam_size=3.0,
                               ensure_full_beam=False,
                               force_eos_in_last_step=False,
                               local_eos_threshold=-100.0,
                               independence=True,
                               use_v2=True):
        """Runs one beam search step per entry of `probs` and evaluates.

        Numeric state tensors start as zeros, `done_hyps` as empty strings and
        `beam_done` as all-False; every state is threaded from one step to the
        next.

        Args:
          hyp_size: Size of the hypothesis dimension of the state tensors.
          num_beams: Length of the `best_scores` and `beam_done` vectors.
          seq_len: Size of the leading (time) dimension of the state tensors.
          init_best_score: Value added to the zero-initialised `best_scores`.
          probs: Sequence of per-step inputs; element `i` is passed as the
            first argument of step `i`.
          init_atten_probs: Passed unchanged as the second argument of every
            step.
          atten_probs: Initial attention-probability state; replaced by the
            op's output after each step.
          beam_size: Forwarded to the op as `beam_size`.
          ensure_full_beam: Forwarded to the op as `ensure_full_beam`.
          force_eos_in_last_step: Forwarded to the op as
            `force_eos_in_last_step`.
          local_eos_threshold: Forwarded to the op as `local_eos_threshold`.
          independence: Forwarded as `beam_independence`; only used when
            `use_v2` is true.
          use_v2: If true, calls `ops.beam_search_step`, which also takes and
            returns `beam_done`; otherwise calls
            `ops.beam_search_step_deprecated`, in which case `beam_done` keeps
            its initial all-False value.

        Returns:
          A 9-tuple of evaluated values `(best_scores, cumulative_scores,
          scores, hyps, prev_hyps, done_hyps, atten_probs, done, beam_done)`.

        Raises:
          ValueError: If `probs` is empty, so no step was run.
        """
        eos_id = 2
        # Floor division keeps this an int: `num_hyps_per_beam` is forwarded as
        # an op attribute and true division would hand the op a float.
        num_hyps_per_beam = hyp_size // num_beams

        best_scores = tf.zeros([num_beams])
        cumulative_scores = tf.zeros([hyp_size])
        scores = tf.zeros([seq_len, hyp_size])
        hyps = tf.zeros([seq_len, hyp_size], dtype=tf.int32)
        prev_hyps = tf.zeros([seq_len, hyp_size], dtype=tf.int32)
        done_hyps = tf.constant('', shape=[seq_len, hyp_size], dtype=tf.string)
        best_scores += init_best_score
        beam_done = tf.zeros([num_beams], dtype=tf.bool)

        # Sentinel so that an empty `probs` is reported explicitly instead of
        # surfacing later as an unbound local.
        done = None
        for i, prob in enumerate(probs):
            # In both branches the empty list fills the positional slot right
            # before the step index (presumably `is_last_chunk` -- confirm
            # against the op signature).
            if use_v2:
                (best_scores, cumulative_scores, scores, hyps, prev_hyps,
                 done_hyps, atten_probs, beam_done,
                 done) = ops.beam_search_step(
                     prob,
                     init_atten_probs,
                     best_scores,
                     cumulative_scores,
                     scores,
                     hyps,
                     prev_hyps,
                     done_hyps,
                     atten_probs,
                     beam_done, [],
                     i,
                     eos_id=eos_id,
                     beam_size=beam_size,
                     ensure_full_beam=ensure_full_beam,
                     num_hyps_per_beam=num_hyps_per_beam,
                     valid_eos_max_logit_delta=0.1,
                     force_eos_in_last_step=force_eos_in_last_step,
                     local_eos_threshold=local_eos_threshold,
                     beam_independence=independence)
            else:
                (best_scores, cumulative_scores, scores, hyps, prev_hyps,
                 done_hyps, atten_probs,
                 done) = ops.beam_search_step_deprecated(
                     prob,
                     init_atten_probs,
                     best_scores,
                     cumulative_scores,
                     scores,
                     hyps,
                     prev_hyps,
                     done_hyps,
                     atten_probs, [],
                     i,
                     eos_id=eos_id,
                     beam_size=beam_size,
                     ensure_full_beam=ensure_full_beam,
                     num_hyps_per_beam=num_hyps_per_beam,
                     valid_eos_max_logit_delta=0.1,
                     force_eos_in_last_step=force_eos_in_last_step,
                     local_eos_threshold=local_eos_threshold)

        if done is None:
            raise ValueError('probs must contain at least one step.')

        with self.session(use_gpu=False):
            (best_scores, cumulative_scores, scores, hyps, prev_hyps,
             done_hyps, atten_probs, done, beam_done) = self.evaluate([
                 best_scores, cumulative_scores, scores, hyps, prev_hyps,
                 done_hyps, atten_probs, done, beam_done
             ])

        return (best_scores, cumulative_scores, scores, hyps, prev_hyps,
                done_hyps, atten_probs, done, beam_done)