Example no. 1
0
    def test_transformer_beam1(self):
        """Beam search with beam size 1 must reduce to greedy decoding.

        Runs transformer beam search with size=1 and checks that (a) no
        attention scores are returned, (b) the output has shape
        (batch, max_output_length), and (c) the decoded ids match the
        output of `transformer_greedy` on the same inputs.
        """
        n_batch, n_beam, len_penalty, out_len = 2, 1, 1., 3
        src_mask, embed, decoder, encoder_output, \
        encoder_hidden = self._build(batch_size=n_batch)

        out, attn = beam_search(
            decoder=decoder,
            size=n_beam,
            bos_index=self.bos_index,
            eos_index=self.eos_index,
            pad_index=self.pad_index,
            encoder_output=encoder_output,
            encoder_hidden=encoder_hidden,
            src_mask=src_mask,
            max_output_length=out_len,
            alpha=len_penalty,
            embed=embed)

        # The transformer beam decoder does not produce attention scores.
        self.assertIsNone(attn)
        # Shape is batch x time.
        self.assertEqual(out.shape, (n_batch, out_len))
        np.testing.assert_equal(out, [[5, 5, 5], [5, 5, 5]])

        # With beam=1, beam search must be identical to greedy decoding.
        greedy_out, _ = transformer_greedy(
            decoder=decoder,
            bos_index=self.bos_index,
            encoder_output=encoder_output,
            encoder_hidden=encoder_hidden,
            src_mask=src_mask,
            max_output_length=out_len,
            embed=embed)
        np.testing.assert_equal(out, greedy_out)
Example no. 2
0
 def test_transformer_greedy(self):
     """Transformer greedy decoding returns the expected ids and no attention.

     Checks that `transformer_greedy` (a) returns None for attention
     scores, (b) produces an output of shape (batch, max_output_length),
     and (c) decodes the expected token ids for the fixed test model.
     """
     n_batch, out_len = 2, 3
     src_mask, model, encoder_output, encoder_hidden = self._build(
         batch_size=n_batch)

     out, attn = transformer_greedy(
         model=model,
         encoder_output=encoder_output,
         encoder_hidden=encoder_hidden,
         src_mask=src_mask,
         max_output_length=out_len)

     # The transformer greedy decoder does not produce attention scores.
     self.assertIsNone(attn)
     # Shape is batch x time.
     self.assertEqual(out.shape, (n_batch, out_len))
     np.testing.assert_equal(out, [[5, 5, 5], [5, 5, 5]])