Exemplo n.º 1
0
    def test_greedy_embedding_helper(self):
        """Decodes with ``decoder_helpers.GreedyEmbeddingHelper``.

        Builds a decoder, calls ``eval()`` on it, runs it against the fixture
        memory with a greedy embedding helper, and checks that the first
        element of the returned pair is a ``TransformerDecoderOutput``.
        """
        model = TransformerDecoder(vocab_size=self._vocab_size,
                                   output_layer=self._output_layer)
        model.eval()

        greedy_helper = decoder_helpers.GreedyEmbeddingHelper(
            self._embedding, self._start_tokens, self._end_token)
        call_kwargs = {
            'memory': self._memory,
            'memory_sequence_length': self._memory_sequence_length,
            'memory_attention_bias': None,
            'helper': greedy_helper,
            'max_decoding_length': self._max_decode_len,
        }

        # The call must yield exactly two values; only the first is inspected.
        decoded, _ = model(**call_kwargs)

        self.assertIsInstance(decoded, TransformerDecoderOutput)
Exemplo n.º 2
0
    def test_beam_search(self):
        """Checks the output shapes produced by beam-search decoding.

        Runs the decoder with a beam width of 5 and asserts that
        ``log_prob`` has shape ``(batch_size, 5)`` and ``sample_id`` has
        shape ``(batch_size, max_decode_len, 5)``.
        """
        num_beams = 5

        model = TransformerDecoder(vocab_size=self._vocab_size,
                                   output_layer=self._output_layer)
        model.eval()

        call_kwargs = {
            'memory': self._memory,
            'memory_sequence_length': self._memory_sequence_length,
            'memory_attention_bias': None,
            'inputs': None,
            'beam_width': num_beams,
            'start_tokens': self._start_tokens,
            'end_token': self._end_token,
            'max_decoding_length': self._max_decode_len,
        }
        result = model(**call_kwargs)

        log_prob_shape = result['log_prob'].shape
        self.assertEqual(log_prob_shape, (self._batch_size, num_beams))

        sample_id_shape = result['sample_id'].shape
        self.assertEqual(
            sample_id_shape,
            (self._batch_size, self._max_decode_len, num_beams))
Exemplo n.º 3
0
    def test_infer_greedy_with_context_without_memory(self):
        """Tests ``infer_greedy`` decoding from a context with no memory.

        All memory-related arguments are passed as ``None``; decoding is
        seeded from the fixture context. The first element of the returned
        pair must be a ``TransformerDecoderOutput``.
        """
        model = TransformerDecoder(vocab_size=self._vocab_size,
                                   output_layer=self._output_layer)
        model.eval()

        call_kwargs = {
            'memory': None,
            'memory_sequence_length': None,
            'memory_attention_bias': None,
            'inputs': None,
            'decoding_strategy': 'infer_greedy',
            'context': self._context,
            'context_sequence_length': self._context_length,
            'end_token': self._end_token,
            'embedding': self._embedding_fn,
            'max_decoding_length': self._max_decode_len,
        }

        # The call must yield exactly two values; only the first is inspected.
        decoded, _ = model(**call_kwargs)

        self.assertIsInstance(decoded, TransformerDecoderOutput)
Exemplo n.º 4
0
    def test_decode_infer_sample(self):
        """Decodes with ``decoder_helpers.SampleEmbeddingHelper``.

        Builds a decoder, calls ``eval()`` on it, runs it against the fixture
        memory with a sampling helper, and checks that the first element of
        the returned pair is a ``TransformerDecoderOutput``.
        """
        model = TransformerDecoder(vocab_size=self._vocab_size,
                                   output_layer=self._output_layer)
        model.eval()

        sampling_helper = decoder_helpers.SampleEmbeddingHelper(
            self._embedding_fn, self._start_tokens, self._end_token)
        call_kwargs = {
            'memory': self._memory,
            'memory_sequence_length': self._memory_sequence_length,
            'memory_attention_bias': None,
            'inputs': None,
            'helper': sampling_helper,
            'max_decoding_length': self._max_decode_len,
        }

        # The call must yield exactly two values; only the first is inspected.
        decoded, _ = model(**call_kwargs)

        self.assertIsInstance(decoded, TransformerDecoderOutput)
Exemplo n.º 5
0
    def test_beam_search(self):
        """Tests beam-search decoding output sizes.

        Runs the decoder with ``beam_width=5`` and checks that:

        * ``log_prob`` has size ``(batch_size, beam_width)``;
        * ``sample_id`` has ``batch_size`` entries along axis 0;
        * ``sample_id`` is no longer than ``max_decode_len`` along axis 1;
        * ``sample_id`` has ``beam_width`` entries along axis 2.
        """
        decoder = TransformerDecoder(token_pos_embedder=self._embedding_fn,
                                     vocab_size=self._vocab_size,
                                     output_layer=self._output_layer)
        decoder.eval()
        beam_width = 5
        outputs = decoder(memory=self._memory,
                          memory_sequence_length=self._memory_sequence_length,
                          memory_attention_bias=None,
                          inputs=None,
                          beam_width=beam_width,
                          start_tokens=self._start_tokens,
                          end_token=self._end_token,
                          max_decoding_length=self._max_decode_len)

        self.assertEqual(outputs['log_prob'].size(),
                         (self._batch_size, beam_width))
        self.assertEqual(outputs['sample_id'].size(0), self._batch_size)
        # Axis 2 is the beam axis (asserted below), so the decoded-length
        # bound has to be checked on axis 1. Checking axis 2 here would only
        # compare beam_width against max_decode_len and never inspect the
        # number of decoded steps.
        self.assertLessEqual(outputs['sample_id'].size(1),
                             self._max_decode_len)
        self.assertEqual(outputs['sample_id'].size(2), beam_width)