示例#1
0
  def testBeamSearchDecodeTgtPrefix(self, dtype=tf.float32):
    """Beam search seeded with predefined step ids matches golden outputs.

    Builds an MTDecoderV1 with `init_step_ids=True`, decodes inputs produced
    by `self._Inputs(..., init_step_ids=True)`, and checks both the output
    shapes and the exact golden ids, lengths and scores.

    Args:
      dtype: dtype forwarded to `_DecoderParams` and `_Inputs`.
    """
    with self.session(use_gpu=True) as sess, self.SetEval(True):
      # Statements that create graph ops keep their original relative order:
      # with a global seed set, TF derives per-op seeds deterministically, so
      # reordering could change the golden values checked below.
      tf.random.set_seed(_TF_RANDOM_SEED)
      batch_size = 2
      params = self._DecoderParams(dtype=dtype)
      params.init_step_ids = True  # Start beam search from predefined ids.
      params.beam_search.num_hyps_per_beam = 2
      params.rnn_cell_dim = 32
      mt_decoder = decoder.MTDecoderV1(params)
      enc_out, _ = self._Inputs(dtype=dtype, init_step_ids=True)
      beam_out = mt_decoder.BeamSearchDecode(enc_out)
      # The MT decoder leaves topk_decoded as None, which cannot be fetched
      # by sess.run; substitute a dummy scalar tensor.
      beam_out = beam_out._replace(topk_decoded=tf.constant(0, tf.float32))
      tf.global_variables_initializer().run()
      result = sess.run(beam_out)

    hyps_per_beam = params.beam_search.num_hyps_per_beam
    total_hyps = batch_size * hyps_per_beam
    max_len = params.target_seq_len
    shape_checks = (
        ('done_hyps', (max_len, total_hyps)),
        ('topk_hyps', (batch_size, hyps_per_beam)),
        ('topk_ids', (total_hyps, max_len)),
        ('topk_lens', (total_hyps,)),
        ('topk_scores', (batch_size, hyps_per_beam)),
    )
    for field, want_shape in shape_checks:
      self.assertTupleEqual(want_shape, getattr(result, field).shape)

    golden_ids = [[2, 0, 0, 0, 0], [13, 2, 0, 0, 0], [0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0]]
    golden_lens = [1, 2, 0, 0]
    golden_scores = [[-3.783162, -5.767723], [0., 0.]]
    self.assertAllEqual(golden_ids, result.topk_ids)
    self.assertAllEqual(golden_lens, result.topk_lens)
    self.assertAllClose(golden_scores, result.topk_scores)
示例#2
0
    def testDecoderFPropFixedAttentionSeed(self, dtype=tf.float64):
        """FProp loss with a fixed-seed attention initializer matches golden.

        The attention context is not fed to the softmax, and the attention
        params use an explicitly seeded Gaussian initializer.

        Args:
          dtype: dtype forwarded to `_DecoderParams` and `_Inputs`.
        """
        with self.session(use_gpu=True):
            # Graph-building statements keep their original order so the
            # seed-derived initial values (and thus the golden loss) are
            # unchanged.
            tf.set_random_seed(_TF_RANDOM_SEED)
            params = self._DecoderParams(dtype=dtype)
            params.feed_attention_context_vec_to_softmax = False
            # Arguments are presumably (scale, seed); 12345 is the fixed
            # attention seed this test is named after.
            params.attention.params_init = py_utils.WeightInit.Gaussian(
                0.1, 12345)
            mt_decoder = decoder.MTDecoderV1(params)
            enc_out, tgts = self._Inputs(dtype=dtype)
            fprop_out = mt_decoder.FPropDefaultTheta(enc_out, tgts)
            # The 'loss' entry is a pair; only its first element is checked.
            loss_tensor, _ = fprop_out['loss']

            tf.global_variables_initializer().run()
            loss_value = loss_tensor.eval()
            print('actual loss = ', loss_value)
            CompareToGoldenSingleFloat(self, 7.624183, loss_value)
示例#3
0
    def testBeamSearchDecodeFeedingAttContext(self, dtype=tf.float32):
        """Beam search feeding the attention context to softmax matches golden.

        Args:
          dtype: dtype forwarded to `_DecoderParams` and `_Inputs`.
        """
        # Statements that create graph ops keep their original relative
        # order: with a graph-level seed set, TF derives per-op seeds
        # deterministically, so reordering could change the golden values.
        tf.set_random_seed(_TF_RANDOM_SEED)
        batch_size = 2
        params = self._DecoderParams(dtype=dtype)
        params.is_eval = True
        # Several output shapes below are sized by target_seq_len.
        max_len = params.target_seq_len
        params.beam_search.num_hyps_per_beam = 2
        params.rnn_cell_dim = 32
        params.feed_attention_context_vec_to_softmax = True
        mt_decoder = decoder.MTDecoderV1(params)
        enc_out, _ = self._Inputs(dtype=dtype)
        beam_out = mt_decoder.BeamSearchDecode(enc_out.encoded,
                                               enc_out.padding)
        # The MT decoder leaves topk_decoded as None, which cannot be fetched
        # by sess.run; substitute a dummy scalar tensor.
        beam_out = beam_out._replace(topk_decoded=tf.constant(0, tf.float32))

        with self.session(use_gpu=True) as sess:
            tf.global_variables_initializer().run()
            result = sess.run(beam_out)

        hyps_per_beam = params.beam_search.num_hyps_per_beam
        total_hyps = batch_size * hyps_per_beam
        shape_checks = (
            ('done_hyps', (max_len, total_hyps)),
            ('topk_hyps', (batch_size, hyps_per_beam)),
            ('topk_ids', (total_hyps, max_len)),
            ('topk_lens', (total_hyps, )),
            ('topk_scores', (batch_size, hyps_per_beam)),
        )
        for field, want_shape in shape_checks:
            self.assertTupleEqual(want_shape, getattr(result, field).shape)

        golden_ids = [[2, 0, 0, 0, 0], [12, 2, 0, 0, 0],
                      [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]
        golden_lens = [1, 2, 0, 0]
        golden_scores = [[-3.7437, -5.654146], [0., 0.]]

        self.assertAllEqual(golden_ids, result.topk_ids)
        self.assertAllEqual(golden_lens, result.topk_lens)
        self.assertAllClose(golden_scores, result.topk_scores)
示例#4
0
  def testBeamSearchDecode(self, dtype=tf.float32):
    """Plain beam search over `_testInputs` matches golden ids/lens/scores.

    Args:
      dtype: dtype forwarded to `_DecoderParams` and `_testInputs`.
    """
    # Statements that create graph ops keep their original relative order:
    # with a graph-level seed set, TF derives per-op seeds deterministically,
    # so reordering could change the golden values checked below.
    tf.set_random_seed(_TF_RANDOM_SEED)
    batch_size = 2
    params = self._DecoderParams(dtype=dtype)
    params.is_eval = True
    # Several output shapes below are sized by target_seq_len.
    max_len = params.target_seq_len
    params.beam_search.num_hyps_per_beam = 2
    params.rnn_cell_dim = 32
    mt_decoder = decoder.MTDecoderV1(params)
    enc, enc_padding, _ = self._testInputs(dtype=dtype)
    beam_out = mt_decoder.BeamSearchDecode(enc, enc_padding)
    # The MT decoder leaves topk_decoded as None, which cannot be fetched by
    # sess.run; substitute a dummy scalar tensor.
    beam_out = beam_out._replace(topk_decoded=tf.constant(0, tf.float32))

    with self.session(use_gpu=True) as sess:
      tf.global_variables_initializer().run()
      result = sess.run(beam_out)

    hyps_per_beam = params.beam_search.num_hyps_per_beam
    total_hyps = batch_size * hyps_per_beam
    shape_checks = (
        ('done_hyps', (max_len, total_hyps)),
        ('topk_hyps', (batch_size, hyps_per_beam)),
        ('topk_ids', (total_hyps, max_len)),
        ('topk_lens', (total_hyps,)),
        ('topk_scores', (batch_size, hyps_per_beam)),
    )
    for field, want_shape in shape_checks:
      self.assertTupleEqual(want_shape, getattr(result, field).shape)

    golden_ids = [[2, 0, 0, 0, 0], [11, 2, 0, 0, 0], [2, 0, 0, 0, 0],
                  [6, 2, 0, 0, 0]]
    golden_lens = [1, 2, 1, 2]
    golden_scores = [[-3.781308, -5.741293], [-3.332158, -5.597181]]

    self.assertAllEqual(golden_ids, result.topk_ids)
    self.assertAllEqual(golden_lens, result.topk_lens)
    self.assertAllClose(golden_scores, result.topk_scores)
示例#5
0
    def testBeamSearchDecodeUseZeroAttenState(self, dtype=tf.float32):
        """Beam search with `use_zero_atten_state=True` matches golden outputs.

        Args:
          dtype: dtype forwarded to `_DecoderParams` and `_Inputs`.
        """
        with self.session(use_gpu=True) as sess, self.SetEval(True):
            # Statements that create graph ops keep their original relative
            # order: with a graph-level seed set, TF derives per-op seeds
            # deterministically, so reordering could change the goldens.
            tf.set_random_seed(_TF_RANDOM_SEED)
            batch_size = 2
            params = self._DecoderParams(dtype=dtype)
            # Several output shapes below are sized by target_seq_len.
            max_len = params.target_seq_len
            params.beam_search.num_hyps_per_beam = 2
            params.use_zero_atten_state = True
            params.rnn_cell_dim = 32
            mt_decoder = decoder.MTDecoderV1(params)
            enc_out, _ = self._Inputs(dtype=dtype)
            beam_out = mt_decoder.BeamSearchDecode(enc_out)
            # The MT decoder leaves topk_decoded as None, which cannot be
            # fetched by sess.run; substitute a dummy scalar tensor.
            beam_out = beam_out._replace(
                topk_decoded=tf.constant(0, tf.float32))

            tf.global_variables_initializer().run()
            result = sess.run(beam_out)

        hyps_per_beam = params.beam_search.num_hyps_per_beam
        total_hyps = batch_size * hyps_per_beam
        shape_checks = (
            ('done_hyps', (max_len, total_hyps)),
            ('topk_hyps', (batch_size, hyps_per_beam)),
            ('topk_ids', (total_hyps, max_len)),
            ('topk_lens', (total_hyps, )),
            ('topk_scores', (batch_size, hyps_per_beam)),
        )
        for field, want_shape in shape_checks:
            self.assertTupleEqual(want_shape, getattr(result, field).shape)

        golden_ids = [[2, 0, 0, 0, 0], [13, 2, 0, 0, 0],
                      [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]
        golden_lens = [1, 2, 0, 0]
        golden_scores = [[-3.783176, -5.767704], [0., 0.]]

        self.assertAllEqual(golden_ids, result.topk_ids)
        self.assertAllEqual(golden_lens, result.topk_lens)
        self.assertAllClose(golden_scores, result.topk_scores)
示例#6
0
 def testDecoderConstruction(self):
     """MTDecoderV1 builds from the params returned by `_DecoderParams()`."""
     # Passing means construction raised nothing; the instance is discarded.
     decoder.MTDecoderV1(self._DecoderParams())