Esempio n. 1
0
    def test_recurrent_greedy(self):
        """Greedy recurrent decoding reproduces fixed reference results.

        Builds the fixture with ``self._build`` for a batch of two, decodes
        greedily for three steps, and compares both the predicted indices and
        the returned attention scores against hard-coded reference values.
        """
        batch_size = 2
        max_output_length = 3
        built = self._build(batch_size=batch_size)
        src_mask, emb, decoder, encoder_output, encoder_hidden = built

        decode_kwargs = dict(
            src_mask=src_mask,
            embed=emb,
            bos_index=self.bos_index,
            eos_index=self.eos_index,
            max_output_length=max_output_length,
            decoder=decoder,
            encoder_output=encoder_output,
            encoder_hidden=encoder_hidden,
        )
        hypotheses, attention = recurrent_greedy(**decode_kwargs)

        # One row per batch entry, one column per decoding step.
        self.assertEqual(hypotheses.shape, (batch_size, max_output_length))
        np.testing.assert_equal(hypotheses, [[4, 0, 4], [4, 4, 4]])

        # Reference attention weights, shaped (batch, step, 4); the last axis
        # presumably spans the source positions of the fixture — confirm in
        # _build. Compared to numpy's default of 6 decimals.
        reference_attention = np.array([
            [
                [0.22914883, 0.24638498, 0.21247596, 0.3119903],
                [0.22970565, 0.24540883, 0.21261126, 0.31227428],
                [0.22903332, 0.2459198, 0.2110187, 0.3140282],
            ],
            [
                [0.252522, 0.29074305, 0.257121, 0.19961396],
                [0.2519883, 0.2895494, 0.25718424, 0.201278],
                [0.2523954, 0.28959078, 0.25769445, 0.2003194],
            ],
        ])
        np.testing.assert_array_almost_equal(attention, reference_attention)
        self.assertEqual(attention.shape, (batch_size, max_output_length, 4))
Esempio n. 2
0
    def test_recurrent_beam1(self):
        """Beam search with a beam of size one must agree with greedy decoding.

        Both decoders run on the same fixture and receive the same
        decoding arguments. The hypotheses they return are required to be
        exactly equal.
        """
        batch_size = 2
        max_output_length = 3
        (src_mask, emb, decoder,
         encoder_output, encoder_hidden) = self._build(batch_size=batch_size)

        # Arguments that both decoding strategies accept unchanged.
        shared = dict(
            src_mask=src_mask,
            embed=emb,
            bos_index=self.bos_index,
            eos_index=self.eos_index,
            max_output_length=max_output_length,
            decoder=decoder,
            encoder_output=encoder_output,
            encoder_hidden=encoder_hidden,
        )
        greedy_hyps, _ = recurrent_greedy(**shared)

        beam_width = 1
        alpha = 1.0
        beam_hyps, _ = beam_search(
            size=beam_width,
            pad_index=self.pad_index,
            n_best=1,
            alpha=alpha,
            **shared,
        )
        np.testing.assert_array_equal(greedy_hyps, beam_hyps)
Esempio n. 3
0
    def test_recurrent_beam1(self):
        """Beam search with a single beam must agree with greedy decoding.

        This variant targets the model-based decoding API. Both decoders get
        the same model, encoder state and length limit, and their returned
        hypotheses must be exactly equal.
        """
        batch_size = 2
        max_output_length = 3
        fixture = self._build(batch_size=batch_size)
        src_mask, model, encoder_output, encoder_hidden = fixture

        # Keyword arguments common to both decoding strategies.
        common = {
            "src_mask": src_mask,
            "max_output_length": max_output_length,
            "model": model,
            "encoder_output": encoder_output,
            "encoder_hidden": encoder_hidden,
        }
        greedy_hyps, _ = recurrent_greedy(**common)

        beam_width = 1
        alpha = 1.0
        beam_hyps, _ = beam_search(
            size=beam_width,
            n_best=1,
            alpha=alpha,
            **common,
        )
        np.testing.assert_array_equal(greedy_hyps, beam_hyps)