def test_multicell_decode_infer(self):
    r"""Tests decoding in inference mode with the multi-cell hparams.

    Builds an ``AttentionRNNDecoder`` from ``self._hparams_multicell``,
    switches it to eval mode, and decodes with an inference helper whose
    start token is ``1`` for every batch entry and whose end token is
    ``2``. Then checks the trainable-variable count and the decoder
    outputs via ``self._test_outputs``.
    """
    batch_size = self._batch_size
    # randint's upper bound is exclusive, so "+ 1" yields lengths in
    # [1, self._max_time].
    memory_lengths = torch.tensor(
        np.random.randint(self._max_time, size=[batch_size]) + 1)

    # NOTE(review): input_size adds 64 to the embedding dim, presumably for
    # an attention vector of size 64 fed back as input -- confirm against
    # AttentionRNNDecoder.
    decoder = AttentionRNNDecoder(
        encoder_output_size=64,
        vocab_size=self._vocab_size,
        input_size=self._emb_dim + 64,
        hparams=self._hparams_multicell)
    decoder.eval()

    start_tokens = torch.tensor([1] * batch_size)
    helper_infer = decoder.create_helper(
        embedding=self._embedding,
        start_tokens=start_tokens,
        end_token=2)

    outputs, final_state, sequence_lengths = decoder(
        memory=self._encoder_output,
        memory_sequence_length=memory_lengths,
        helper=helper_infer)

    self.assertEqual(len(decoder.trainable_variables), 15)
    self._test_outputs(decoder, outputs, final_state, sequence_lengths,
                       test_mode=True, is_multi=True)
def test_decode_infer(self):
    r"""Tests decoding in inference mode for every configured hparams set.

    For each ``(cell_type, is_multi) -> hparams`` entry in
    ``self._test_hparams``, builds an ``AttentionRNNDecoder`` that embeds
    tokens with ``self._embedder``, switches it to eval mode, and decodes
    with an inference helper whose start token is ``1`` for every batch
    entry and whose end token is ``2``. Then validates the outputs via
    ``self._test_outputs``, forwarding ``is_multi`` so that multi-cell
    configurations are checked as such.
    """
    # randint's upper bound is exclusive, so "+ 1" yields lengths in
    # [1, self._max_time].
    seq_length = np.random.randint(
        self._max_time, size=[self._batch_size]) + 1
    encoder_values_length = torch.tensor(seq_length)

    for (cell_type, is_multi), hparams in self._test_hparams.items():
        # subTest labels any failure with the configuration that caused it.
        with self.subTest(cell_type=cell_type, is_multi=is_multi):
            decoder = AttentionRNNDecoder(
                encoder_output_size=64,
                token_embedder=self._embedder,
                vocab_size=self._vocab_size,
                input_size=self._emb_dim,
                hparams=hparams)
            decoder.eval()

            helper_infer = decoder.create_helper(
                start_tokens=torch.tensor([1] * self._batch_size),
                end_token=2)

            outputs, final_state, sequence_lengths = decoder(
                memory=self._encoder_output,
                memory_sequence_length=encoder_values_length,
                helper=helper_infer)

            # Previously is_multi was unpacked but never passed, so
            # multi-cell configs were validated as if single-cell.
            self._test_outputs(decoder, outputs, final_state,
                               sequence_lengths, test_mode=True,
                               is_multi=is_multi)