Example #1
0
    def test_multicell_decode_infer(self):
        r"""Tests decoding in inference mode.
        """
        # Random valid lengths in [1, max_time], one per batch element.
        lengths = 1 + np.random.randint(self._max_time,
                                        size=[self._batch_size])
        memory_lengths = torch.tensor(lengths)

        decoder = AttentionRNNDecoder(
            encoder_output_size=64,
            vocab_size=self._vocab_size,
            input_size=self._emb_dim + 64,
            hparams=self._hparams_multicell)
        decoder.eval()  # switch to inference mode

        start = torch.tensor([1] * self._batch_size)
        infer_helper = decoder.create_helper(embedding=self._embedding,
                                             start_tokens=start,
                                             end_token=2)

        outputs, final_state, sequence_lengths = decoder(
            memory=self._encoder_output,
            memory_sequence_length=memory_lengths,
            helper=infer_helper)

        # The multi-cell attention decoder is expected to expose exactly
        # 15 trainable parameters.
        self.assertEqual(len(decoder.trainable_variables), 15)

        self._test_outputs(decoder, outputs, final_state, sequence_lengths,
                           test_mode=True, is_multi=True)
Example #2
0
    def test_decode_infer(self):
        r"""Tests decoding in inference mode.
        """
        # Random valid lengths in [1, max_time], one per batch element.
        lengths = 1 + np.random.randint(self._max_time,
                                        size=[self._batch_size])
        memory_lengths = torch.tensor(lengths)

        # Exercise every configured (cell_type, is_multi) hparams variant.
        for (cell_type, is_multi), hparams in self._test_hparams.items():
            decoder = AttentionRNNDecoder(
                encoder_output_size=64,
                token_embedder=self._embedder,
                vocab_size=self._vocab_size,
                input_size=self._emb_dim,
                hparams=hparams)
            decoder.eval()  # switch to inference mode

            start = torch.tensor([1] * self._batch_size)
            infer_helper = decoder.create_helper(start_tokens=start,
                                                 end_token=2)

            outputs, final_state, sequence_lengths = decoder(
                memory=self._encoder_output,
                memory_sequence_length=memory_lengths,
                helper=infer_helper)

            self._test_outputs(decoder, outputs, final_state,
                               sequence_lengths, test_mode=True)