def test_decode_infer(self): r"""Tests decoding in inference mode.""" decoder = BasicRNNDecoder(input_size=self._emb_dim, vocab_size=self._vocab_size, hparams=self._hparams) decoder.eval() start_tokens = torch.tensor([self._vocab_size - 2] * self._batch_size) helpers = [] for strategy in ['infer_greedy', 'infer_sample']: helper = decoder.create_helper( decoding_strategy=strategy, embedding=self._embedding, start_tokens=start_tokens, end_token=self._vocab_size - 1) helpers.append(helper) for klass in ['TopKSampleEmbeddingHelper', 'SoftmaxEmbeddingHelper', 'GumbelSoftmaxEmbeddingHelper']: helper = get_helper( klass, embedding=self._embedding, start_tokens=start_tokens, end_token=self._vocab_size - 1, top_k=self._vocab_size // 2, tau=2.0, straight_through=True) helpers.append(helper) for helper in helpers: max_length = 100 outputs, final_state, sequence_lengths = decoder( helper=helper, max_decoding_length=max_length) self.assertLessEqual(max(sequence_lengths), max_length) self._test_outputs(decoder, outputs, final_state, sequence_lengths, test_mode=True, helper=helper)
    def test_output_layer(self):
        decoder = BasicRNNDecoder(vocab_size=self._vocab_size,
                                  output_layer=None)
        self.assertIsInstance(decoder, BasicRNNDecoder)

        decoder = BasicRNNDecoder(output_layer=tf.identity)
        self.assertIsInstance(decoder, BasicRNNDecoder)

        tensor = tf.random_uniform(
            [self._emb_dim, self._vocab_size], maxval=1, dtype=tf.float32)
        decoder = BasicRNNDecoder(output_layer=tensor)
        self.assertIsInstance(decoder, BasicRNNDecoder)
        self.assertEqual(decoder.vocab_size, self._vocab_size)
        def _test_fn(helper):
            _, next_inputs, _ = helper.next_inputs(
                time=1,
                outputs=tf.ones([self._batch_size, self._vocab_size]),  # Not used
                state=None,  # Not used
                sample_ids=tf.ones([self._batch_size, self._vocab_size]))
            self.assertEqual(helper.sample_ids_shape,
                             tf.TensorShape(self._vocab_size))
            self.assertEqual(next_inputs.get_shape(),
                             tf.TensorShape([self._batch_size,
                                             self._emb_dim]))

            # Test in an RNN decoder
            output_layer = tf.layers.Dense(self._vocab_size)
            decoder = BasicRNNDecoder(vocab_size=self._vocab_size,
                                      output_layer=output_layer)
            outputs, final_state, sequence_lengths = decoder(
                helper=helper, max_decoding_length=self._max_seq_length)
            cell_dim = decoder.hparams.rnn_cell.kwargs.num_units

            with self.test_session() as sess:
                sess.run(tf.global_variables_initializer())
                outputs_, final_state_, sequence_lengths_ = sess.run(
                    [outputs, final_state, sequence_lengths])
                max_length = max(sequence_lengths_)
                self.assertEqual(
                    outputs_.logits.shape,
                    (self._batch_size, max_length, self._vocab_size))
                # Soft helpers emit distributions over the vocabulary,
                # so sample_id is 3-D here rather than 2-D.
                self.assertEqual(
                    outputs_.sample_id.shape,
                    (self._batch_size, max_length, self._vocab_size))
                self.assertEqual(final_state_[0].shape,
                                 (self._batch_size, cell_dim))
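        # A hedged sketch (names assumed, not from the original test body):
        # `_test_fn` is presumably driven with the soft-sample helpers,
        # e.g. texar's SoftmaxEmbeddingHelper, roughly as:
        #
        #     helper = SoftmaxEmbeddingHelper(
        #         embedding=self._embedding,
        #         start_tokens=[self._vocab_size - 2] * self._batch_size,
        #         end_token=self._vocab_size - 1,
        #         tau=2.)
        #     _test_fn(helper)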
    def test_seq_pg_agent(self):
        """Tests sampling and reward observation with the sequence
        policy-gradient agent.
        """
        decoder = BasicRNNDecoder(vocab_size=self._vocab_size)
        outputs, _, sequence_length = decoder(
            decoding_strategy="infer_greedy",
            max_decoding_length=10,
            embedding=self._embedding,
            start_tokens=[1] * self._batch_size,
            end_token=2)

        agent = SeqPGAgent(outputs.sample_id, outputs.logits,
                           sequence_length, decoder.trainable_variables)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            agent.sess = sess

            feed_dict = {context.global_mode(): tf.estimator.ModeKeys.TRAIN}
            for _ in range(2):
                vals = agent.get_samples(feed_dict=feed_dict)
                self.assertEqual(vals['samples'].shape[0], self._batch_size)

                loss_1 = agent.observe([1.] * self._batch_size)
                loss_2 = agent.observe([1.] * self._batch_size,
                                       train_policy=False)
                self.assertEqual(loss_1.shape, ())
                self.assertEqual(loss_2.shape, ())
    def test_decode_train(self):
        """Tests decoding in training mode."""
        output_layer = tf.layers.Dense(self._vocab_size)
        decoder = BasicRNNDecoder(vocab_size=self._vocab_size,
                                  output_layer=output_layer)

        # Helper constructed from the default hparams
        helper_train = get_helper(
            decoder.hparams.helper_train.type,
            inputs=self._inputs,
            sequence_length=[self._max_time] * self._batch_size,
            **decoder.hparams.helper_train.kwargs.todict())
        outputs, final_state, sequence_lengths = decoder(helper=helper_train)
        self._test_outputs(decoder, outputs, final_state, sequence_lengths)

        # Implicit train helper from `inputs` and `sequence_length`
        outputs, final_state, sequence_lengths = decoder(
            inputs=self._inputs,
            sequence_length=[self._max_time] * self._batch_size)
        self._test_outputs(decoder, outputs, final_state, sequence_lengths)

        outputs, final_state, sequence_lengths = decoder(
            decoding_strategy=None,
            inputs=self._inputs,
            sequence_length=[self._max_time] * self._batch_size)
        self._test_outputs(decoder, outputs, final_state, sequence_lengths)

        # Inference-style decoding under EVAL mode
        outputs, final_state, sequence_lengths = decoder(
            decoding_strategy=None,
            embedding=self._embedding,
            start_tokens=[1] * self._batch_size,
            end_token=2,
            mode=tf.estimator.ModeKeys.EVAL)
        self._test_outputs(decoder, outputs, final_state, sequence_lengths,
                           test_mode=True)
    def test_decode_train_with_torch(self):
        r"""Compares decoding results with PyTorch's built-in decoder."""
        decoder = BasicRNNDecoder(input_size=self._emb_dim,
                                  vocab_size=self._vocab_size,
                                  hparams=self._hparams)

        input_size = self._emb_dim
        hidden_size = decoder.hparams.rnn_cell.kwargs.num_units
        num_layers = decoder.hparams.rnn_cell.num_layers
        torch_lstm = nn.LSTM(input_size, hidden_size, num_layers,
                             batch_first=True)

        # Match parameters: copy the Texar cell's weights into the torch
        # LSTM (single layer, hence the `_l0` suffix).
        for name in ['weight_ih', 'weight_hh', 'bias_ih', 'bias_hh']:
            setattr(torch_lstm, f'{name}_l0',
                    getattr(decoder._cell._cell, name))
        torch_lstm.flatten_parameters()

        output_layer = decoder._output_layer
        input_lengths = torch.tensor([self._max_time] * self._batch_size)
        embedding = torch.randn(self._vocab_size, self._emb_dim)
        inputs = torch.randint(
            self._vocab_size, size=(self._batch_size, self._max_time))

        # Decoder outputs
        helper_train = decoder.create_helper(embedding=embedding)
        outputs, final_state, sequence_lengths = decoder(
            inputs=inputs,
            sequence_length=input_lengths,
            helper=helper_train)

        # Torch LSTM outputs
        lstm_inputs = F.embedding(inputs, embedding)
        torch_outputs, torch_states = torch_lstm(lstm_inputs)
        torch_outputs = output_layer(torch_outputs)
        torch_sample_id = torch.argmax(torch_outputs, dim=-1)

        self.assertEqual(final_state[0].shape,
                         (self._batch_size, hidden_size))

        self._assert_tensor_equal(outputs.logits, torch_outputs)
        self._assert_tensor_equal(outputs.sample_id, torch_sample_id)
        # nn.LSTM states carry a leading num_layers dimension, hence the
        # squeeze before comparing with the decoder's flat final state.
        self._assert_tensor_equal(final_state[0], torch_states[0].squeeze(0))
        self._assert_tensor_equal(final_state[1], torch_states[1].squeeze(0))
        self._assert_tensor_equal(sequence_lengths, input_lengths)
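    # A hedged sketch (not part of the original test): for a multi-layer
    # cell, PyTorch's nn.LSTM names its parameters per layer k as
    # `weight_ih_l{k}`, `weight_hh_l{k}`, `bias_ih_l{k}`, `bias_hh_l{k}`,
    # so the matching loop above would generalize roughly as follows,
    # assuming a hypothetical `layer_cell(decoder, k)` accessor for the
    # k-th wrapped Texar cell:
    #
    #     for k in range(num_layers):
    #         for name in ['weight_ih', 'weight_hh', 'bias_ih', 'bias_hh']:
    #             setattr(torch_lstm, f'{name}_l{k}',
    #                     getattr(layer_cell(decoder, k), name))
    #     torch_lstm.flatten_parameters()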
    def setUp(self):
        self._vocab_size = 4
        self._max_time = 8
        self._batch_size = 16
        self._emb_dim = 20
        self._inputs = torch.rand(
            self._batch_size, self._max_time, self._emb_dim,
            dtype=torch.float)
        self._embedding = torch.rand(
            self._vocab_size, self._emb_dim, dtype=torch.float)
        self._hparams = HParams(None, BasicRNNDecoder.default_hparams())
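    def _test_outputs(self, decoder, outputs, final_state, sequence_lengths,
                      test_mode=False, helper=None):
        r"""Shared shape checks used by the tests above.

        A minimal sketch, assuming the real helper is defined elsewhere in
        this file; the exact assertions (e.g. 3-D soft sample IDs for the
        Softmax/Gumbel helpers) are illustrative, not the original code.
        """
        hidden_size = decoder.hparams.rnn_cell.kwargs.num_units

        self.assertIsInstance(outputs, BasicRNNDecoderOutput)
        max_time = (self._max_time if not test_mode
                    else sequence_lengths.max().item())
        self.assertEqual(outputs.logits.shape,
                         (self._batch_size, max_time, self._vocab_size))
        self.assertEqual(final_state[0].shape,
                         (self._batch_size, hidden_size))
        if not test_mode:
            self._assert_tensor_equal(
                sequence_lengths,
                torch.tensor([max_time] * self._batch_size))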
    def test_decode_train(self):
        r"""Tests decoding in training mode."""
        decoder = BasicRNNDecoder(token_embedder=self._embedder,
                                  input_size=self._emb_dim,
                                  vocab_size=self._vocab_size,
                                  hparams=self._hparams)
        sequence_length = torch.tensor([self._max_time] * self._batch_size)

        # Helper by default HParams
        helper_train = decoder.create_helper()
        outputs, final_state, sequence_lengths = decoder(
            helper=helper_train, inputs=self._inputs,
            sequence_length=sequence_length)
        self._test_outputs(decoder, outputs, final_state, sequence_lengths)

        # Helper by decoding strategy
        helper_train = decoder.create_helper(decoding_strategy='train_greedy')
        outputs, final_state, sequence_lengths = decoder(
            helper=helper_train, inputs=self._inputs,
            sequence_length=sequence_length)
        self._test_outputs(decoder, outputs, final_state, sequence_lengths)

        # Implicit helper
        outputs, final_state, sequence_lengths = decoder(
            inputs=self._inputs, sequence_length=sequence_length)
        self._test_outputs(decoder, outputs, final_state, sequence_lengths)

        # Eval helper through forward args
        outputs, final_state, sequence_lengths = decoder(
            embedding=self._embedder,
            start_tokens=torch.tensor([1] * self._batch_size),
            end_token=2, infer_mode=True)
        self._test_outputs(decoder, outputs, final_state, sequence_lengths,
                           test_mode=True)
    def test_decode_infer(self):
        """Tests decoding in inference mode."""
        output_layer = tf.layers.Dense(self._vocab_size)
        decoder = BasicRNNDecoder(vocab_size=self._vocab_size,
                                  output_layer=output_layer)

        helper_infer = get_helper(
            decoder.hparams.helper_infer.type,
            embedding=self._embedding,
            start_tokens=[self._vocab_size - 2] * self._batch_size,
            end_token=self._vocab_size - 1,
            **decoder.hparams.helper_infer.kwargs.todict())

        outputs, final_state, sequence_lengths = decoder(helper=helper_infer)

        # 4 trainable variables: cell-kernel, cell-bias,
        # fc-layer-weights, fc-layer-bias
        self.assertEqual(len(decoder.trainable_variables), 4)

        cell_dim = decoder.hparams.rnn_cell.kwargs.num_units
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            outputs_, final_state_, sequence_lengths_ = sess.run(
                [outputs, final_state, sequence_lengths],
                feed_dict={
                    context.global_mode(): tf.estimator.ModeKeys.PREDICT
                })
            self.assertIsInstance(outputs_, BasicRNNDecoderOutput)
            max_length = max(sequence_lengths_)
            self.assertEqual(
                outputs_.logits.shape,
                (self._batch_size, max_length, self._vocab_size))
            self.assertEqual(outputs_.sample_id.shape,
                             (self._batch_size, max_length))
            self.assertEqual(final_state_[0].shape,
                             (self._batch_size, cell_dim))
    def test_decode_train_with_tf(self):
        """Compares decoding results with the TF built-in decoder."""
        _inputs_placeholder = tf.placeholder(
            tf.int32, [self._batch_size, self._max_time], name="inputs")
        _embedding_placeholder = tf.placeholder(
            tf.float32, [self._vocab_size, self._emb_dim], name="emb")
        inputs = tf.nn.embedding_lookup(_embedding_placeholder,
                                        _inputs_placeholder)

        output_layer = tf.layers.Dense(self._vocab_size)
        decoder = BasicRNNDecoder(vocab_size=self._vocab_size,
                                  output_layer=output_layer)

        helper_train = get_helper(
            decoder.hparams.helper_train.type,
            inputs=inputs,
            sequence_length=[self._max_time] * self._batch_size,
            **decoder.hparams.helper_train.kwargs.todict())

        outputs, final_state, sequence_lengths = decoder(helper=helper_train)

        # Equivalent TF decoder sharing the same cell and output layer
        tf_helper = tf.contrib.seq2seq.TrainingHelper(
            inputs, [self._max_time] * self._batch_size)
        tf_decoder = tf.contrib.seq2seq.BasicDecoder(
            decoder.cell,
            tf_helper,
            decoder.cell.zero_state(self._batch_size, tf.float32),
            output_layer=output_layer)
        tf_outputs, tf_final_state, tf_sequence_lengths = \
            tf.contrib.seq2seq.dynamic_decode(tf_decoder)

        cell_dim = decoder.hparams.rnn_cell.kwargs.num_units
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())

            inputs_ = np.random.randint(
                self._vocab_size,
                size=(self._batch_size, self._max_time),
                dtype=np.int32)
            embedding_ = np.random.randn(self._vocab_size, self._emb_dim)

            outputs_, final_state_, sequence_lengths_ = sess.run(
                [outputs, final_state, sequence_lengths],
                feed_dict={
                    context.global_mode(): tf.estimator.ModeKeys.TRAIN,
                    _inputs_placeholder: inputs_,
                    _embedding_placeholder: embedding_
                })
            self.assertEqual(final_state_[0].shape,
                             (self._batch_size, cell_dim))

            tf_outputs_, tf_final_state_, tf_sequence_lengths_ = sess.run(
                [tf_outputs, tf_final_state, tf_sequence_lengths],
                feed_dict={
                    context.global_mode(): tf.estimator.ModeKeys.TRAIN,
                    _inputs_placeholder: inputs_,
                    _embedding_placeholder: embedding_
                })

            np.testing.assert_array_equal(outputs_.logits,
                                          tf_outputs_.rnn_output)
            np.testing.assert_array_equal(outputs_.sample_id,
                                          tf_outputs_.sample_id)
            np.testing.assert_array_equal(final_state_.c, tf_final_state_.c)
            np.testing.assert_array_equal(final_state_.h, tf_final_state_.h)
            np.testing.assert_array_equal(sequence_lengths_,
                                          tf_sequence_lengths_)