def testBeamSearchDecodeTgtPrefix(self, dtype=tf.float32):
    """Beam search on MTDecoderV1 started from predefined step ids.

    The decoder params set `init_step_ids = True` and the inputs are built
    with `init_step_ids=True`. BeamSearchDecode runs in eval mode, and the
    test checks the output shapes and golden topk ids / lens / scores.
    """
    with self.session(use_gpu=True) as sess, self.SetEval(True):
        tf.random.set_seed(_TF_RANDOM_SEED)
        batch_size = 2
        p = self._DecoderParams(dtype=dtype)
        # Initialize beam search with predefined ids.
        p.init_step_ids = True
        p.beam_search.num_hyps_per_beam = 2
        p.rnn_cell_dim = 32
        dec = decoder.MTDecoderV1(p)
        encoder_outputs, _ = self._Inputs(dtype=dtype, init_step_ids=True)
        decode = dec.BeamSearchDecode(encoder_outputs)
        # The MT decoder leaves topk_decoded as None, which sess.run cannot
        # fetch, so substitute a dummy constant tensor.
        decode = decode._replace(topk_decoded=tf.constant(0, tf.float32))
        tf.global_variables_initializer().run()
        result = sess.run(decode)

        per_beam = p.beam_search.num_hyps_per_beam
        total_hyps = batch_size * per_beam
        max_len = p.target_seq_len
        shape_checks = (
            ((max_len, total_hyps), result.done_hyps),
            ((batch_size, per_beam), result.topk_hyps),
            ((total_hyps, max_len), result.topk_ids),
            ((total_hyps,), result.topk_lens),
            ((batch_size, per_beam), result.topk_scores),
        )
        for want_shape, tensor in shape_checks:
            self.assertTupleEqual(want_shape, tensor.shape)

        self.assertAllEqual(
            [[2, 0, 0, 0, 0], [13, 2, 0, 0, 0], [0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0]],
            result.topk_ids)
        self.assertAllEqual([1, 2, 0, 0], result.topk_lens)
        self.assertAllClose([[-3.783162, -5.767723], [0., 0.]],
                            result.topk_scores)
def testDecoderFPropFixedAttentionSeed(self, dtype=tf.float64):
    """FProp loss of MTDecoderV1 matches a golden value.

    Attention params are initialized with
    `py_utils.WeightInit.Gaussian(0.1, 12345)`, and the attention context
    vector is not fed to the softmax.
    """
    with self.session(use_gpu=True):
        tf.set_random_seed(_TF_RANDOM_SEED)
        params = self._DecoderParams(dtype=dtype)
        params.feed_attention_context_vec_to_softmax = False
        params.attention.params_init = py_utils.WeightInit.Gaussian(0.1, 12345)
        dec = decoder.MTDecoderV1(params)
        encoder_outputs, targets = self._Inputs(dtype=dtype)
        metrics = dec.FPropDefaultTheta(encoder_outputs, targets)
        # The 'loss' entry unpacks into two values; only the first is used.
        loss, _ = metrics['loss']
        tf.global_variables_initializer().run()
        value = loss.eval()
        # Printed to help regenerate the golden value if it ever changes.
        print('actual loss = ', value)
        CompareToGoldenSingleFloat(self, 7.624183, value)
def testBeamSearchDecodeFeedingAttContext(self, dtype=tf.float32):
    """Beam search with the attention context vector fed to the softmax.

    Sets `feed_attention_context_vec_to_softmax = True` on MTDecoderV1 and
    checks the beam search output shapes and golden topk ids / lens / scores.
    The graph is built before the session is opened; the session only
    initializes variables and runs the decode.
    """
    tf.set_random_seed(_TF_RANDOM_SEED)
    batch_size = 2
    p = self._DecoderParams(dtype=dtype)
    p.is_eval = True
    # target_seq_len bounds the time axis of done_hyps and topk_ids below.
    max_steps = p.target_seq_len
    p.beam_search.num_hyps_per_beam = 2
    p.rnn_cell_dim = 32
    p.feed_attention_context_vec_to_softmax = True
    dec = decoder.MTDecoderV1(p)
    encoder_outputs, _ = self._Inputs(dtype=dtype)
    decode = dec.BeamSearchDecode(encoder_outputs.encoded,
                                  encoder_outputs.padding)
    # The MT decoder leaves topk_decoded as None, which sess.run cannot
    # fetch, so substitute a dummy constant tensor.
    decode = decode._replace(topk_decoded=tf.constant(0, tf.float32))
    with self.session(use_gpu=True) as sess:
        tf.global_variables_initializer().run()
        out = sess.run(decode)

        per_beam = p.beam_search.num_hyps_per_beam
        total_hyps = batch_size * per_beam
        shape_checks = (
            ((max_steps, total_hyps), out.done_hyps),
            ((batch_size, per_beam), out.topk_hyps),
            ((total_hyps, max_steps), out.topk_ids),
            ((total_hyps,), out.topk_lens),
            ((batch_size, per_beam), out.topk_scores),
        )
        for want_shape, tensor in shape_checks:
            self.assertTupleEqual(want_shape, tensor.shape)

        self.assertAllEqual(
            [[2, 0, 0, 0, 0], [12, 2, 0, 0, 0], [0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0]],
            out.topk_ids)
        self.assertAllEqual([1, 2, 0, 0], out.topk_lens)
        self.assertAllClose([[-3.7437, -5.654146], [0., 0.]],
                            out.topk_scores)
def testBeamSearchDecode(self, dtype=tf.float32):
    """Baseline beam search on MTDecoderV1 with 2 hyps per beam.

    Builds the decode graph from `_testInputs` (encodings plus paddings),
    then checks output shapes and golden topk ids / lens / scores.
    """
    tf.set_random_seed(_TF_RANDOM_SEED)
    batch_size = 2
    p = self._DecoderParams(dtype=dtype)
    p.is_eval = True
    # target_seq_len bounds the time axis of done_hyps and topk_ids below.
    max_steps = p.target_seq_len
    p.beam_search.num_hyps_per_beam = 2
    p.rnn_cell_dim = 32
    dec = decoder.MTDecoderV1(p)
    src_enc, src_enc_padding, _ = self._testInputs(dtype=dtype)
    decode = dec.BeamSearchDecode(src_enc, src_enc_padding)
    # The MT decoder leaves topk_decoded as None, which sess.run cannot
    # fetch, so substitute a dummy constant tensor.
    decode = decode._replace(topk_decoded=tf.constant(0, tf.float32))
    with self.session(use_gpu=True) as sess:
        tf.global_variables_initializer().run()
        out = sess.run(decode)

        per_beam = p.beam_search.num_hyps_per_beam
        total_hyps = batch_size * per_beam
        shape_checks = (
            ((max_steps, total_hyps), out.done_hyps),
            ((batch_size, per_beam), out.topk_hyps),
            ((total_hyps, max_steps), out.topk_ids),
            ((total_hyps,), out.topk_lens),
            ((batch_size, per_beam), out.topk_scores),
        )
        for want_shape, tensor in shape_checks:
            self.assertTupleEqual(want_shape, tensor.shape)

        self.assertAllEqual(
            [[2, 0, 0, 0, 0], [11, 2, 0, 0, 0], [2, 0, 0, 0, 0],
             [6, 2, 0, 0, 0]],
            out.topk_ids)
        self.assertAllEqual([1, 2, 1, 2], out.topk_lens)
        self.assertAllClose([[-3.781308, -5.741293], [-3.332158, -5.597181]],
                            out.topk_scores)
def testBeamSearchDecodeUseZeroAttenState(self, dtype=tf.float32):
    """Beam search on MTDecoderV1 with `use_zero_atten_state = True`.

    Runs in eval mode inside the test session and checks output shapes and
    golden topk ids / lens / scores.
    """
    with self.session(use_gpu=True) as sess, self.SetEval(True):
        tf.set_random_seed(_TF_RANDOM_SEED)
        batch_size = 2
        p = self._DecoderParams(dtype=dtype)
        # target_seq_len bounds the time axis of done_hyps and topk_ids below.
        max_steps = p.target_seq_len
        p.beam_search.num_hyps_per_beam = 2
        p.use_zero_atten_state = True
        p.rnn_cell_dim = 32
        dec = decoder.MTDecoderV1(p)
        encoder_outputs, _ = self._Inputs(dtype=dtype)
        decode = dec.BeamSearchDecode(encoder_outputs)
        # The MT decoder leaves topk_decoded as None, which sess.run cannot
        # fetch, so substitute a dummy constant tensor.
        decode = decode._replace(topk_decoded=tf.constant(0, tf.float32))
        tf.global_variables_initializer().run()
        out = sess.run(decode)

        per_beam = p.beam_search.num_hyps_per_beam
        total_hyps = batch_size * per_beam
        shape_checks = (
            ((max_steps, total_hyps), out.done_hyps),
            ((batch_size, per_beam), out.topk_hyps),
            ((total_hyps, max_steps), out.topk_ids),
            ((total_hyps,), out.topk_lens),
            ((batch_size, per_beam), out.topk_scores),
        )
        for want_shape, tensor in shape_checks:
            self.assertTupleEqual(want_shape, tensor.shape)

        self.assertAllEqual(
            [[2, 0, 0, 0, 0], [13, 2, 0, 0, 0], [0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0]],
            out.topk_ids)
        self.assertAllEqual([1, 2, 0, 0], out.topk_lens)
        self.assertAllClose([[-3.783176, -5.767704], [0., 0.]],
                            out.topk_scores)
def testDecoderConstruction(self):
    """Building MTDecoderV1 from the default test params does not raise."""
    params = self._DecoderParams()
    decoder.MTDecoderV1(params)