def test_decode_train(self):
    r"""Tests decoding in training mode.
    """
    decoder = BasicRNNDecoder(
        token_embedder=self._embedder,
        input_size=self._emb_dim,
        vocab_size=self._vocab_size,
        hparams=self._hparams)
    sequence_length = torch.tensor([self._max_time] * self._batch_size)

    # Helper constructed from the default HParams
    helper_train = decoder.create_helper()
    outputs, final_state, sequence_lengths = decoder(
        helper=helper_train, inputs=self._inputs,
        sequence_length=sequence_length)
    self._test_outputs(decoder, outputs, final_state, sequence_lengths)

    # Helper constructed by decoding-strategy name
    helper_train = decoder.create_helper(decoding_strategy='train_greedy')
    outputs, final_state, sequence_lengths = decoder(
        helper=helper_train, inputs=self._inputs,
        sequence_length=sequence_length)
    self._test_outputs(decoder, outputs, final_state, sequence_lengths)

    # Implicit helper: the decoder builds one from its arguments
    outputs, final_state, sequence_lengths = decoder(
        inputs=self._inputs, sequence_length=sequence_length)
    self._test_outputs(decoder, outputs, final_state, sequence_lengths)

    # Inference helper constructed through forward arguments
    outputs, final_state, sequence_lengths = decoder(
        embedding=self._embedder,
        start_tokens=torch.tensor([1] * self._batch_size),
        end_token=2,
        infer_mode=True)
    self._test_outputs(
        decoder, outputs, final_state, sequence_lengths, test_mode=True)
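
# For reference, the decoding-strategy names used in these tests presumably
# resolve to helper classes along these lines (a sketch of the assumed
# mapping; nothing here asserts it):
#
#     'train_greedy' -> TrainingHelper        (teacher forcing on given inputs)
#     'infer_greedy' -> GreedyEmbeddingHelper (feed back argmax predictions)
#     'infer_sample' -> SampleEmbeddingHelper (feed back sampled predictions)
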
def test_decode_infer(self):
    r"""Tests decoding in inference mode."""
    decoder = BasicRNNDecoder(
        token_embedder=self._embedder,
        input_size=self._emb_dim,
        vocab_size=self._vocab_size,
        hparams=self._hparams)
    decoder.eval()
    start_tokens = torch.tensor([self._vocab_size - 2] * self._batch_size)

    helpers = []
    # Helpers created by decoding-strategy name
    for strategy in ['infer_greedy', 'infer_sample']:
        helper = decoder.create_helper(
            decoding_strategy=strategy,
            start_tokens=start_tokens,
            end_token=self._vocab_size - 1)
        helpers.append(helper)
    # Helpers created by class name; each class consumes only the keyword
    # arguments it accepts (top_k, tau, straight_through)
    for klass in ['TopKSampleEmbeddingHelper', 'SoftmaxEmbeddingHelper',
                  'GumbelSoftmaxEmbeddingHelper']:
        helper = get_helper(
            klass, start_tokens=start_tokens,
            end_token=self._vocab_size - 1,
            top_k=self._vocab_size // 2, tau=2.0,
            straight_through=True)
        helpers.append(helper)

    max_length = 100
    for helper in helpers:
        outputs, final_state, sequence_lengths = decoder(
            helper=helper, max_decoding_length=max_length)
        self.assertLessEqual(max(sequence_lengths), max_length)
        self._test_outputs(decoder, outputs, final_state,
                           sequence_lengths, test_mode=True)
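
# Note: `decoder.eval()` above puts the module in evaluation mode (disabling
# dropout and similar train-only behavior). The sampling helpers still draw
# random samples, which is why the test asserts only an upper bound on the
# decoded lengths rather than exact outputs.
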
def test_decode_train_with_torch(self):
    r"""Compares decoding results with PyTorch's built-in LSTM.
    """
    decoder = BasicRNNDecoder(
        token_embedder=self._embedder,
        input_size=self._emb_dim,
        vocab_size=self._vocab_size,
        hparams=self._hparams)

    input_size = self._emb_dim
    hidden_size = decoder.hparams.rnn_cell.kwargs.num_units
    num_layers = decoder.hparams.rnn_cell.num_layers
    torch_lstm = nn.LSTM(input_size, hidden_size, num_layers,
                         batch_first=True)

    # Share parameters between the decoder cell and the torch LSTM so the
    # two compute identical results
    for name in ['weight_ih', 'weight_hh', 'bias_ih', 'bias_hh']:
        setattr(torch_lstm, f'{name}_l0',
                getattr(decoder._cell._cell, name))
    torch_lstm.flatten_parameters()

    output_layer = decoder._output_layer
    input_lengths = torch.tensor([self._max_time] * self._batch_size)
    inputs = torch.randint(
        self._vocab_size, size=(self._batch_size, self._max_time))

    # decoder outputs
    helper_train = decoder.create_helper()
    outputs, final_state, sequence_lengths = decoder(
        inputs=inputs,
        sequence_length=input_lengths,
        helper=helper_train)

    # torch LSTM outputs
    lstm_inputs = F.embedding(inputs, self._embedder.embedding)
    torch_outputs, torch_states = torch_lstm(lstm_inputs)
    torch_outputs = output_layer(torch_outputs)
    torch_sample_id = torch.argmax(torch_outputs, dim=-1)

    self.assertEqual(final_state[0].shape,
                     (self._batch_size, hidden_size))

    self._assert_tensor_equal(outputs.logits, torch_outputs)
    self._assert_tensor_equal(outputs.sample_id, torch_sample_id)
    self._assert_tensor_equal(final_state[0], torch_states[0].squeeze(0))
    self._assert_tensor_equal(final_state[1], torch_states[1].squeeze(0))
    self._assert_tensor_equal(sequence_lengths, input_lengths)
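
# The equality checks above rely on `train_greedy` decoding being plain
# teacher forcing: at every step the ground-truth input token is fed in, so
# the step-by-step unrolled decoder computes exactly the same function as
# running the whole embedded sequence through `nn.LSTM` once and applying
# the shared output layer.
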
def setUp(self):
    self._vocab_size = 4
    self._max_time = 8
    self._batch_size = 16
    self._emb_dim = 20
    self._inputs = torch.randint(
        self._vocab_size, size=(self._batch_size, self._max_time))
    embedding = torch.rand(
        self._vocab_size, self._emb_dim, dtype=torch.float)
    self._embedder = WordEmbedder(init_value=embedding)
    self._hparams = HParams(None, BasicRNNDecoder.default_hparams())
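
# For reference, a minimal sketch of the `_test_outputs` helper the tests
# above call (the real implementation lives elsewhere in this class; the
# checks below are assumptions inferred from how it is used):
#
#     def _test_outputs(self, decoder, outputs, final_state,
#                       sequence_lengths, test_mode=False):
#         hidden_size = decoder.hparams.rnn_cell.kwargs.num_units
#         self.assertIsInstance(outputs, BasicRNNDecoderOutput)
#         max_time = (self._max_time if not test_mode
#                     else max(sequence_lengths).item())
#         self.assertEqual(outputs.logits.shape,
#                          (self._batch_size, max_time, self._vocab_size))
#         self.assertEqual(final_state[0].shape,
#                          (self._batch_size, hidden_size))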