def test_greedy_embedding_helper(self):
    """Checks that greedy decoding driven by
    ``decoder_helpers.GreedyEmbeddingHelper`` yields a
    ``TransformerDecoderOutput``.
    """
    transformer = TransformerDecoder(
        vocab_size=self._vocab_size,
        output_layer=self._output_layer,
    )
    # Switch the module to eval mode before decoding.
    transformer.eval()

    greedy_helper = decoder_helpers.GreedyEmbeddingHelper(
        self._embedding, self._start_tokens, self._end_token)

    decode_kwargs = dict(
        memory=self._memory,
        memory_sequence_length=self._memory_sequence_length,
        memory_attention_bias=None,
        helper=greedy_helper,
        max_decoding_length=self._max_decode_len,
    )
    # The decoder returns (outputs, lengths); only the outputs are checked.
    result, _ = transformer(**decode_kwargs)

    self.assertIsInstance(result, TransformerDecoderOutput)
def test_beam_search_without_token_pos_embedder(self):
    """Tests beam-search decoding on a decoder constructed without a
    ``token_pos_embedder``, using a beam width of 5.

    Checks the shapes of ``log_prob`` (batch, beam) and ``sample_id``
    (batch, time, beam).
    """
    # NOTE(review): this method used to be named ``test_beam_search``, but a
    # later method of the same name in this class body replaced it, so it
    # was never collected or run. It now has a distinct name. Confirm that
    # beam search is supported when ``TransformerDecoder`` is built without
    # a ``token_pos_embedder``.
    beam_width = 5
    decoder = TransformerDecoder(vocab_size=self._vocab_size,
                                 output_layer=self._output_layer)
    decoder.eval()
    outputs = decoder(memory=self._memory,
                      memory_sequence_length=self._memory_sequence_length,
                      memory_attention_bias=None,
                      inputs=None,
                      beam_width=beam_width,
                      start_tokens=self._start_tokens,
                      end_token=self._end_token,
                      max_decoding_length=self._max_decode_len)

    self.assertEqual(outputs['log_prob'].shape,
                     (self._batch_size, beam_width))

    sample_id_shape = outputs['sample_id'].shape
    self.assertEqual(sample_id_shape[0], self._batch_size)
    # Decoding presumably stops early once every beam emits ``end_token``,
    # so the time axis is bounded by, not equal to, ``max_decoding_length``.
    self.assertLessEqual(sample_id_shape[1], self._max_decode_len)
    self.assertEqual(sample_id_shape[2], beam_width)
def test_infer_greedy_with_context_without_memory(self):
    """Tests ``infer_greedy`` decoding with a context and no memory.

    ``memory`` and its lengths/bias are all ``None``, and no
    ``start_tokens`` are passed. Only ``context`` and
    ``context_sequence_length`` are supplied, along with ``end_token`` and
    the ``embedding`` function.
    """
    decoder = TransformerDecoder(vocab_size=self._vocab_size,
                                 output_layer=self._output_layer)
    decoder.eval()
    # The decoder returns (outputs, lengths); only the outputs are checked.
    outputs, _ = decoder(memory=None,
                         memory_sequence_length=None,
                         memory_attention_bias=None,
                         inputs=None,
                         decoding_strategy='infer_greedy',
                         context=self._context,
                         context_sequence_length=self._context_length,
                         end_token=self._end_token,
                         embedding=self._embedding_fn,
                         max_decoding_length=self._max_decode_len)

    self.assertIsInstance(outputs, TransformerDecoderOutput)
def test_decode_infer_sample(self):
    """Checks sampling-based decoding through
    ``decoder_helpers.SampleEmbeddingHelper`` returns a
    ``TransformerDecoderOutput``.
    """
    decoder = TransformerDecoder(
        vocab_size=self._vocab_size, output_layer=self._output_layer)
    decoder.eval()

    sampler = decoder_helpers.SampleEmbeddingHelper(
        self._embedding_fn, self._start_tokens, self._end_token)

    call_args = {
        'memory': self._memory,
        'memory_sequence_length': self._memory_sequence_length,
        'memory_attention_bias': None,
        'inputs': None,
        'helper': sampler,
        'max_decoding_length': self._max_decode_len,
    }
    # The decoder returns (outputs, lengths); only the outputs are checked.
    decoded, _ = decoder(**call_args)

    self.assertIsInstance(decoded, TransformerDecoderOutput)
def test_beam_search(self):
    """Tests beam-search decoding on a decoder built with a
    ``token_pos_embedder``.

    Checks that ``log_prob`` has shape (batch, beam) and that ``sample_id``
    has shape (batch, time, beam), with ``time <= max_decoding_length``.
    """
    decoder = TransformerDecoder(token_pos_embedder=self._embedding_fn,
                                 vocab_size=self._vocab_size,
                                 output_layer=self._output_layer)
    decoder.eval()
    beam_width = 5
    outputs = decoder(memory=self._memory,
                      memory_sequence_length=self._memory_sequence_length,
                      memory_attention_bias=None,
                      inputs=None,
                      beam_width=beam_width,
                      start_tokens=self._start_tokens,
                      end_token=self._end_token,
                      max_decoding_length=self._max_decode_len)

    self.assertEqual(outputs['log_prob'].size(),
                     (self._batch_size, beam_width))

    sample_id = outputs['sample_id']
    self.assertEqual(sample_id.size(0), self._batch_size)
    # Dim 1 is the decoded-time axis and dim 2 is the beam axis, so the
    # length bound must be checked on dim 1.
    self.assertLessEqual(sample_id.size(1), self._max_decode_len)
    self.assertEqual(sample_id.size(2), beam_width)