Example #1
    def testTopKTerminatedHypsOp(self):
        with self.session(use_gpu=False):
            b_size = 8
            num_beams = 2
            num_hyps_per_beam = b_size // num_beams
            seq_len = 6
            scores = tf.random.uniform([b_size, 5], seed=12345)
            atten_probs = tf.random.uniform([b_size, 3], seed=12345)
            src_seq_lengths = [3, 3]
            best_scores = tf.zeros([num_beams])
            cumulative_scores = tf.zeros([b_size])
            in_scores = tf.zeros([seq_len, b_size])
            in_hyps = tf.zeros([seq_len, b_size], dtype=tf.int32)
            in_prev_hyps = tf.zeros([seq_len, b_size], dtype=tf.int32)
            in_done_hyps = tf.as_string(
                tf.zeros([seq_len, b_size], dtype=tf.int32))
            in_atten_probs = tf.zeros([seq_len, b_size, 3])

            (out_best_scores_0, out_cumulative_scores_0, out_scores_0,
             out_hyps_0, out_prev_hyps_0, out_done_hyps_0, out_atten_probs_0,
             _) = ops.beam_search_step(scores,
                                       atten_probs,
                                       best_scores,
                                       cumulative_scores,
                                       in_scores,
                                       in_hyps,
                                       in_prev_hyps,
                                       in_done_hyps,
                                       in_atten_probs, [],
                                       0,
                                       eos_id=2,
                                       beam_size=3.0,
                                       num_hyps_per_beam=num_hyps_per_beam)

            outputs = ops.beam_search_step(scores,
                                           atten_probs,
                                           out_best_scores_0,
                                           out_cumulative_scores_0,
                                           out_scores_0,
                                           out_hyps_0,
                                           out_prev_hyps_0,
                                           out_done_hyps_0,
                                           out_atten_probs_0, [],
                                           1,
                                           eos_id=2,
                                           beam_size=3.0,
                                           num_hyps_per_beam=num_hyps_per_beam)

            # Get the topk terminated hyps.
            in_done_hyps = outputs[5]
            topk_hyps = ops.top_k_terminated_hyps(
                in_done_hyps,
                src_seq_lengths,
                k=2,
                num_hyps_per_beam=num_hyps_per_beam,
                length_normalization=0.2,
                coverage_penalty=0.2,
                target_seq_length_ratio=1.0)
            seq_ids, seq_lens, seq_scores = ops.unpack_hyp(
                tf.reshape(topk_hyps, [-1]), max_seq_length=5)

            k1, k2, k3, k4 = self.evaluate(
                [topk_hyps, seq_ids, seq_lens, seq_scores])
            print(np.array_repr(k1))
            self.assertEqual(k1.size, 4)

            expected_top1_for_beam_0 = """
      beam_id: 0
      ids: 1
      ids: 2
      scores: 0.86230338
      scores: 0.65504861
      atten_vecs {
        prob: 0.45372832
        prob: 0.86230338
        prob: 0.65504861
      }
      atten_vecs {
        prob: 0.45372832
        prob: 0.86230338
        prob: 0.65504861
      }
      normalized_score: 1.002714
      """
            expected_top2_for_beam_1 = """
      beam_id: 1
      ids: 3
      ids: 2
      scores: 0.38127339
      scores: 0.57700801
      atten_vecs {
        prob: 0.38612545
        prob: 0.42067075
        prob: 0.84442794
      }
      atten_vecs {
        prob: 0.18693292
        prob: 0.17821217
        prob: 0.66380036
      }
      normalized_score: 0.480028
      """
            self._SameHyp(expected_top1_for_beam_0, k1[0, 0])
            self._SameHyp(expected_top2_for_beam_1, k1[1, 1])

            self.assertAllClose(k2, [[1, 2, 0, 0, 0], [4, 2, 0, 0, 0],
                                     [4, 2, 0, 0, 0], [3, 2, 0, 0, 0]])
            self.assertAllClose(k3, [2, 2, 2, 2])
            self.assertAllClose(k4, [1.002714, 0.684296, 0.522484, 0.480028])
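
The `_SameHyp` helper used above is not part of this snippet. Below is a minimal sketch of what such a comparison could look like, assuming the serialized hypotheses are `Hypothesis` protos as in Lingvo's `hyps.proto` (the `hyps_pb2` module path and the exact set of compared fields are assumptions here, not taken from the snippet):

from google.protobuf import text_format
from lingvo.core.ops import hyps_pb2  # assumed location of the Hypothesis proto

    def _SameHyp(self, expected_hyp_str, real_serialized_hyp):
        """Compares a text-format expected hyp against a serialized Hypothesis."""
        expected_hyp = hyps_pb2.Hypothesis()
        text_format.Parse(expected_hyp_str, expected_hyp)
        real_hyp = hyps_pb2.Hypothesis()
        real_hyp.ParseFromString(real_serialized_hyp)

        self.assertEqual(expected_hyp.beam_id, real_hyp.beam_id)
        self.assertEqual(list(expected_hyp.ids), list(real_hyp.ids))
        self.assertAllClose(expected_hyp.scores, real_hyp.scores)
        self.assertEqual(len(expected_hyp.atten_vecs), len(real_hyp.atten_vecs))
        for expected_av, real_av in zip(expected_hyp.atten_vecs, real_hyp.atten_vecs):
            self.assertAllClose(expected_av.prob, real_av.prob)
        self.assertAllClose(expected_hyp.normalized_score, real_hyp.normalized_score)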
Example #2
  def BeamSearchDecode(self,
                       theta,
                       encoder_outputs,
                       num_hyps_per_beam_override=0,
                       init_beam_search_state=None,
                       pre_beam_search_step_callback=None,
                       post_beam_search_step_callback=None,
                       max_steps=None):
    """Performs beam-search based decoding.

    Args:
      theta: A NestedMap object containing weights' values of the decoder layer
        and its children layers.
      encoder_outputs: A NestedMap containing encoder outputs to be passed to
        the callbacks.
      num_hyps_per_beam_override: If set to a value <= 0, this parameter is
        ignored. If set to a value > 0, then this value will be used to override
        `p.num_hyps_per_beam`.
      init_beam_search_state: The `InitBeamSearchState` callback. Please refer
        to the class header comments for more details.
      pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
        Please refer to the class header comments for more details.
      post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
        Please refer to the class header comments for more details.
      max_steps: maximum beam search steps. If None, use
        self.params.target_seq_len.

    Returns:
      A `BeamSearchDecodeOutput`.
    """
    p = self.params
    num_hyps_per_beam = p.num_hyps_per_beam
    if num_hyps_per_beam_override > 0:
      num_hyps_per_beam = num_hyps_per_beam_override
    if max_steps is None:
      max_steps = p.target_seq_len

    initial_results, other_states = init_beam_search_state(
        theta, encoder_outputs, num_hyps_per_beam)

    num_hyps = tf.shape(initial_results.log_probs)[0]
    num_beams = num_hyps // num_hyps_per_beam

    if 'step_ids' in initial_results:
      # [num_hyps, 1]
      step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
    else:
      step_ids = tf.fill([num_hyps, 1],
                         tf.constant(p.target_sos_id, dtype=tf.int32))

    min_score = -1e36
    best_scores = (tf.zeros(shape=[num_beams], dtype=p.dtype) + min_score)
    cumulative_scores = tf.zeros(shape=[num_hyps], dtype=p.dtype)
    in_scores = tf.zeros([max_steps, num_hyps], dtype=p.dtype)
    in_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
    in_prev_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
    in_done_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.string)
    bs_atten_probs = tf.zeros(
        [max_steps, num_hyps,
         tf.shape(initial_results.atten_probs)[1]],
        dtype=p.dtype)
    cur_step = tf.constant(0, dtype=tf.int32)
    all_done = tf.constant(False, dtype=tf.bool)
    core_bs_states = (best_scores, cumulative_scores, in_scores, in_hyps,
                      in_prev_hyps, in_done_hyps, bs_atten_probs)

    def LoopContinue(cur_step, all_done, unused_step_ids, unused_core_bs_states,
                     unused_other_states_list):
      return tf.logical_and(cur_step < max_steps, tf.logical_not(all_done))

    def LoopBody(cur_step, unused_all_done, step_ids, core_bs_states,
                 other_states_list):
      (cur_step, all_done, new_step_ids, new_bs_states,
       new_other_states) = self._BeamSearchStep(
           theta, encoder_outputs, cur_step, step_ids, core_bs_states,
           other_states.Pack(other_states_list), num_hyps_per_beam,
           pre_beam_search_step_callback, post_beam_search_step_callback)
      return (cur_step, all_done, new_step_ids, new_bs_states,
              new_other_states.Flatten())

    flat_other_states = other_states.Flatten()
    _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
        LoopContinue,
        LoopBody,
        loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                   flat_other_states),
        parallel_iterations=10,
        back_prop=False,
        swap_memory=False,
        shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                          tf.TensorShape(all_done.get_shape()),
                          tf.TensorShape(step_ids.get_shape()),
                          _GetShapes(core_bs_states),
                          _GetShapes(flat_other_states, none_shapes=True)))
    # [target_seq_len, num_beams * num_hyps_per_beam].
    final_done_hyps = final_bs_states[5]
    final_other_states = other_states.Pack(flat_final_other_states)

    # TODO(rpang): avoid inspecting 'encoder_outputs'.
    source_paddings = encoder_outputs.padding
    if isinstance(source_paddings, py_utils.NestedMap):
      source_seq_lengths = tf.cast(
          tf.round(
              tf.reduce_sum(1.0 - tf.transpose(source_paddings.Flatten()[0]),
                            1)), tf.int32)
    else:
      source_seq_lengths = tf.cast(
          tf.round(tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
          tf.int32)

    # [num_beams, num_hyps_per_beam].
    topk_hyps = ops.top_k_terminated_hyps(
        final_done_hyps,
        source_seq_lengths,
        k=num_hyps_per_beam,
        num_hyps_per_beam=num_hyps_per_beam,
        length_normalization=p.length_normalization,
        coverage_penalty=p.coverage_penalty,
        target_seq_length_ratio=p.target_seq_length_ratio,
        eoc_id=p.target_eoc_id,
        merge_paths=p.merge_paths)
    # [num_beams * num_hyps_per_beam, ...].
    max_seq_length = 0 if isinstance(max_steps, tf.Tensor) else max_steps
    topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
        tf.reshape(topk_hyps, [-1]), max_seq_length=max_seq_length)
    # [num_beams, num_hyps_per_beam].
    topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

    return BeamSearchDecodeOutput(final_done_hyps, topk_hyps, topk_ids,
                                  topk_lens, topk_scores, None,
                                  final_other_states)
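
For context, a hypothetical call site for `BeamSearchDecode` could look like the sketch below. The `decoder` object, its callback method names, and the `BeamSearchDecodeOutput` field names are assumptions inferred from the return statement above, not confirmed by this snippet:

  # Hypothetical usage sketch: `decoder` is assumed to expose the three
  # beam-search callbacks as methods and to hold a beam_search helper.
  decode_out = decoder.beam_search.BeamSearchDecode(
      theta=decoder.theta,
      encoder_outputs=encoder_outputs,
      num_hyps_per_beam_override=0,  # <= 0 keeps p.num_hyps_per_beam
      init_beam_search_state=decoder._InitBeamSearchStateCallback,
      pre_beam_search_step_callback=decoder._PreBeamSearchStepCallback,
      post_beam_search_step_callback=decoder._PostBeamSearchStepCallback)
  # topk_hyps:   [num_beams, num_hyps_per_beam] serialized hypotheses.
  # topk_ids/topk_lens: unpacked token ids and lengths per hypothesis.
  # topk_scores: [num_beams, num_hyps_per_beam] normalized scores.
  topk_ids = decode_out.topk_ids
  topk_scores = decode_out.topk_scores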
Example #3
    def testTopKEquivalent(self, length_normalization, coverage_penalty,
                           length_ratio, populate_hyps):
        """Tests that top_k_from_beam_search_outs is indeed equivalent."""
        with self.session(use_gpu=False) as sess:
            hyp_size = 32
            num_beams = 8
            num_hyps_per_beam = hyp_size // num_beams
            seq_len = 10

            hyps = np.random.randint(3, 100, size=[seq_len, hyp_size])
            # We align all the hyps to make cumulative_score easy to compute.
            prev_hyps = np.tile(np.arange(hyp_size), [seq_len, 1])
            done_hyps = np.ones([seq_len, hyp_size], dtype=bool)
            scores = np.random.uniform(-0.5, 1, size=[seq_len, hyp_size])
            cumulative_scores = np.cumsum(scores, axis=0)
            eos_scores = np.random.uniform(-0.5, 1, size=[seq_len, hyp_size])
            atten_probs = np.random.uniform(
                0, .05, size=[seq_len, hyp_size, seq_len])
            eos_atten_probs = np.random.uniform(
                0, .05, size=[seq_len, hyp_size, seq_len])
            cum_atten_probs = np.cumsum(atten_probs, axis=0)
            cum_atten_probs = np.pad(cum_atten_probs, ((1, 0), (0, 0), (0, 0)),
                                     'constant')[:seq_len, :, :]
            cum_atten_probs = cum_atten_probs + eos_atten_probs
            results = ops.top_k_from_beam_search_outs(
                hyps,
                prev_hyps,
                done_hyps,
                cumulative_scores,
                eos_scores,
                scores=scores if populate_hyps else 0,
                atten_probs=atten_probs if populate_hyps else 0,
                eos_atten_probs=eos_atten_probs if populate_hyps else 0,
                cumulative_atten_probs=cum_atten_probs
                if coverage_penalty > 0 else 0,
                length_normalization=length_normalization,
                coverage_penalty=coverage_penalty,
                num_hyps_per_beam=num_hyps_per_beam,
                max_seq_length=seq_len,
                target_seq_length_ratio=length_ratio,
                populate_topk_hyps=populate_hyps,
            )
            outs = sess.run(results)

            final_done_hyps = ops.hyps_from_beam_search_outs(
                hyps,
                prev_hyps,
                done_hyps,
                scores,
                atten_probs,
                eos_scores,
                eos_atten_probs,
                eos_id=2,
                num_hyps_per_beam=num_hyps_per_beam,
            )
            src_seq_lengths = np.ones([num_beams], dtype=np.int32) * seq_len
            topk_hyps = ops.top_k_terminated_hyps(
                final_done_hyps,
                src_seq_lengths,
                k=num_hyps_per_beam,
                num_hyps_per_beam=num_hyps_per_beam,
                length_normalization=length_normalization,
                coverage_penalty=coverage_penalty,
                target_seq_length_ratio=length_ratio)
            topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
                topk_hyps, max_seq_length=seq_len)
            topk_hyps, topk_ids, topk_lens, topk_scores = sess.run(
                [topk_hyps, topk_ids, topk_lens, topk_scores])

        self.assertAllEqual(outs[0], topk_ids)
        self.assertAllEqual(outs[1], topk_lens)
        self.assertAllClose(outs[2], topk_scores)
        if populate_hyps:
            self.assertAllEqual(outs[3].shape, topk_hyps.shape)
            self.assertAllEqual(outs[3].shape, [num_beams, num_hyps_per_beam])
            for i in range(num_beams):
                for j in range(num_hyps_per_beam):
                    self._SameHyp(outs[3][i, j], topk_hyps[i, j])
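
The test above checks that the fused `top_k_from_beam_search_outs` op matches the three-op pipeline (`hyps_from_beam_search_outs` followed by `top_k_terminated_hyps` and `unpack_hyp`). A short sketch of how the fused outputs might be regrouped per beam follows; the beam-major, best-first ordering is an assumption based on the element-wise comparisons made in this test:

# Hypothetical post-processing sketch: regroup the flat per-hypothesis outputs
# of top_k_from_beam_search_outs into [num_beams, num_hyps_per_beam, ...],
# assuming hypotheses are ordered beam-major and best-first within each beam.
topk_ids, topk_lens, topk_scores = outs[0], outs[1], outs[2]
ids_per_beam = np.reshape(topk_ids, [num_beams, num_hyps_per_beam, seq_len])
lens_per_beam = np.reshape(topk_lens, [num_beams, num_hyps_per_beam])
scores_per_beam = np.reshape(topk_scores, [num_beams, num_hyps_per_beam])
# Best hypothesis per beam (under the best-first assumption).
best_ids = ids_per_beam[:, 0, :]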