def test_decode_infer(self):
        r"""Tests decoding in inference (eval) mode with several helpers.

        Builds one decoder, then runs it under every supported inference
        helper and checks the produced sequence lengths and outputs.
        """
        decoder = BasicRNNDecoder(input_size=self._emb_dim,
                                  vocab_size=self._vocab_size,
                                  hparams=self._hparams)
        decoder.eval()

        start_tokens = torch.tensor(
            [self._vocab_size - 2] * self._batch_size)
        end_token = self._vocab_size - 1

        # Helpers created through the decoder's own factory...
        helpers = [
            decoder.create_helper(decoding_strategy=strategy,
                                  embedding=self._embedding,
                                  start_tokens=start_tokens,
                                  end_token=end_token)
            for strategy in ('infer_greedy', 'infer_sample')
        ]
        # ...plus helpers instantiated directly by class name.
        helpers += [
            get_helper(klass,
                       embedding=self._embedding,
                       start_tokens=start_tokens,
                       end_token=end_token,
                       top_k=self._vocab_size // 2,
                       tau=2.0,
                       straight_through=True)
            for klass in ('TopKSampleEmbeddingHelper',
                          'SoftmaxEmbeddingHelper',
                          'GumbelSoftmaxEmbeddingHelper')
        ]

        max_length = 100
        for helper in helpers:
            outputs, final_state, sequence_lengths = decoder(
                helper=helper, max_decoding_length=max_length)
            # Decoding must respect the length cap.
            self.assertLessEqual(max(sequence_lengths), max_length)
            self._test_outputs(decoder, outputs, final_state,
                               sequence_lengths, test_mode=True,
                               helper=helper)
Example #2
0
    def test_output_layer(self):
        r"""Checks the accepted forms of the ``output_layer`` argument.

        The decoder must accept ``None`` (build its own layer from
        ``vocab_size``), an arbitrary callable, and a weight tensor
        (from which the vocab size is inferred).
        """
        # None: decoder constructs its own output layer.
        built = BasicRNNDecoder(vocab_size=self._vocab_size,
                                output_layer=None)
        self.assertIsInstance(built, BasicRNNDecoder)

        # A plain callable is accepted as-is.
        built = BasicRNNDecoder(output_layer=tf.identity)
        self.assertIsInstance(built, BasicRNNDecoder)

        # A [emb_dim, vocab_size] tensor: vocab size comes from its
        # last dimension.
        weights = tf.random_uniform(
            [self._emb_dim, self._vocab_size], maxval=1, dtype=tf.float32
        )
        built = BasicRNNDecoder(output_layer=weights)
        self.assertIsInstance(built, BasicRNNDecoder)
        self.assertEqual(built.vocab_size, self._vocab_size)
Example #3
0
        def _test_fn(helper):
            # Exercises `helper` both standalone (shape checks on its
            # next_inputs / sample_ids) and inside a full decoder run.
            _, next_inputs, _ = helper.next_inputs(
                time=1,
                outputs=tf.ones([self._batch_size,
                                 self._vocab_size]),  # Not used
                state=None,  # Not used
                sample_ids=tf.ones([self._batch_size, self._vocab_size]))

            # Per-example sample ids span the vocabulary (presumably soft
            # samples rather than single token indices — see the 3-D
            # sample_id assertion below).
            self.assertEqual(helper.sample_ids_shape,
                             tf.TensorShape(self._vocab_size))
            # Next inputs must match the embedding dimension.
            self.assertEqual(next_inputs.get_shape(),
                             tf.TensorShape([self._batch_size, self._emb_dim]))

            # Test in an RNN decoder
            output_layer = tf.layers.Dense(self._vocab_size)
            decoder = BasicRNNDecoder(vocab_size=self._vocab_size,
                                      output_layer=output_layer)
            outputs, final_state, sequence_lengths = decoder(
                helper=helper, max_decoding_length=self._max_seq_length)

            cell_dim = decoder.hparams.rnn_cell.kwargs.num_units
            with self.test_session() as sess:
                sess.run(tf.global_variables_initializer())
                outputs_, final_state_, sequence_lengths_ = sess.run(
                    [outputs, final_state, sequence_lengths])
                max_length = max(sequence_lengths_)
                self.assertEqual(
                    outputs_.logits.shape,
                    (self._batch_size, max_length, self._vocab_size))
                # sample_id here is [batch, time, vocab] — vocab-sized per
                # step, consistent with sample_ids_shape above.
                self.assertEqual(
                    outputs_.sample_id.shape,
                    (self._batch_size, max_length, self._vocab_size))
                # final_state[0] is the cell state component sized by the
                # RNN cell's num_units.
                self.assertEqual(final_state_[0].shape,
                                 (self._batch_size, cell_dim))
    def test_seq_pg_agent(self):
        """Tests that SeqPGAgent can sample from the decoder and compute
        scalar losses from observed rewards.
        """
        decoder = BasicRNNDecoder(vocab_size=self._vocab_size)
        # Greedy inference rollout, capped at 10 steps, starting at
        # token 1 and ending at token 2.
        outputs, _, sequence_length = decoder(decoding_strategy="infer_greedy",
                                              max_decoding_length=10,
                                              embedding=self._embedding,
                                              start_tokens=[1] *
                                              self._batch_size,
                                              end_token=2)

        # Policy-gradient agent driven by the decoder's samples/logits.
        agent = SeqPGAgent(outputs.sample_id, outputs.logits, sequence_length,
                           decoder.trainable_variables)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())

            agent.sess = sess

            feed_dict = {context.global_mode(): tf.estimator.ModeKeys.TRAIN}
            for _ in range(2):
                vals = agent.get_samples(feed_dict=feed_dict)
                # One sampled sequence per batch element.
                self.assertEqual(vals['samples'].shape[0], self._batch_size)

                # Observing a reward of 1.0 per example yields a scalar
                # loss, both with and without updating the policy.
                loss_1 = agent.observe([1.] * self._batch_size)
                loss_2 = agent.observe([1.] * self._batch_size,
                                       train_policy=False)
                self.assertEqual(loss_1.shape, ())
                self.assertEqual(loss_2.shape, ())
Example #5
0
    def test_decode_train(self):
        """Tests decoding in training mode through several equivalent
        call styles: an explicit helper, implicit inputs, an explicit
        (default) decoding strategy, and inference-style args in EVAL mode.
        """
        projection = tf.layers.Dense(self._vocab_size)
        decoder = BasicRNNDecoder(vocab_size=self._vocab_size,
                                  output_layer=projection)

        lengths = [self._max_time] * self._batch_size

        # 1. Explicitly constructed training helper from hparams.
        train_helper = get_helper(
            decoder.hparams.helper_train.type,
            inputs=self._inputs,
            sequence_length=lengths,
            **decoder.hparams.helper_train.kwargs.todict())
        out, state, seq_lens = decoder(helper=train_helper)
        self._test_outputs(decoder, out, state, seq_lens)

        # 2. Helper built implicitly from inputs / sequence_length.
        out, state, seq_lens = decoder(inputs=self._inputs,
                                       sequence_length=lengths)
        self._test_outputs(decoder, out, state, seq_lens)

        # 3. Same, with the decoding strategy passed explicitly as None.
        out, state, seq_lens = decoder(decoding_strategy=None,
                                       inputs=self._inputs,
                                       sequence_length=lengths)
        self._test_outputs(decoder, out, state, seq_lens)

        # 4. Inference-style arguments under EVAL mode.
        out, state, seq_lens = decoder(decoding_strategy=None,
                                       embedding=self._embedding,
                                       start_tokens=[1] * self._batch_size,
                                       end_token=2,
                                       mode=tf.estimator.ModeKeys.EVAL)
        self._test_outputs(decoder, out, state, seq_lens, test_mode=True)
    def test_decode_train_with_torch(self):
        r"""Compares decoding results with PyTorch built-in decoder.

        Shares the decoder cell's weights with an ``nn.LSTM`` and checks
        that teacher-forced decoding produces identical logits, sample
        ids, final states, and sequence lengths.
        """
        decoder = BasicRNNDecoder(input_size=self._emb_dim,
                                  vocab_size=self._vocab_size,
                                  hparams=self._hparams)

        input_size = self._emb_dim
        hidden_size = decoder.hparams.rnn_cell.kwargs.num_units
        num_layers = decoder.hparams.rnn_cell.num_layers
        torch_lstm = nn.LSTM(input_size, hidden_size, num_layers,
                             batch_first=True)

        # match parameters: point the torch LSTM's layer-0 weights at the
        # decoder cell's tensors so both modules compute the same transition.
        for name in ['weight_ih', 'weight_hh', 'bias_ih', 'bias_hh']:
            setattr(torch_lstm, f'{name}_l0',
                    getattr(decoder._cell._cell, name))
        torch_lstm.flatten_parameters()

        output_layer = decoder._output_layer
        # All sequences run the full max_time (no early stopping).
        input_lengths = torch.tensor([self._max_time] * self._batch_size)
        embedding = torch.randn(self._vocab_size, self._emb_dim)
        inputs = torch.randint(
            self._vocab_size, size=(self._batch_size, self._max_time))

        # decoder outputs (teacher-forced: default train helper)
        helper_train = decoder.create_helper(embedding=embedding)
        outputs, final_state, sequence_lengths = decoder(
            inputs=inputs,
            sequence_length=input_lengths,
            helper=helper_train)

        # torch LSTM outputs on the same embedded inputs, projected through
        # the same output layer; greedy sample ids via argmax.
        lstm_inputs = F.embedding(inputs, embedding)
        torch_outputs, torch_states = torch_lstm(lstm_inputs)
        torch_outputs = output_layer(torch_outputs)
        torch_sample_id = torch.argmax(torch_outputs, dim=-1)

        self.assertEqual(final_state[0].shape,
                         (self._batch_size, hidden_size))

        # torch_states carries a leading num_layers dim; squeeze it to
        # compare against the decoder's single-layer final state.
        self._assert_tensor_equal(outputs.logits, torch_outputs)
        self._assert_tensor_equal(outputs.sample_id, torch_sample_id)
        self._assert_tensor_equal(final_state[0], torch_states[0].squeeze(0))
        self._assert_tensor_equal(final_state[1], torch_states[1].squeeze(0))
        self._assert_tensor_equal(sequence_lengths, input_lengths)
 def setUp(self):
     r"""Creates the shared fixtures: sizes, random inputs, an embedding
     table, and default decoder hparams."""
     self._vocab_size = 4
     self._max_time = 8
     self._batch_size = 16
     self._emb_dim = 20
     # Random [batch, time, emb_dim] decoder inputs.
     self._inputs = torch.rand(self._batch_size, self._max_time,
                               self._emb_dim, dtype=torch.float)
     # Random [vocab, emb_dim] embedding table.
     self._embedding = torch.rand(self._vocab_size, self._emb_dim,
                                  dtype=torch.float)
     self._hparams = HParams(None, BasicRNNDecoder.default_hparams())
Example #8
0
    def test_decode_train(self):
        r"""Tests decoding in training mode.

        Runs the same teacher-forced decode through three equivalent
        helper configurations, then an inference-style forward call.
        """
        decoder = BasicRNNDecoder(token_embedder=self._embedder,
                                  input_size=self._emb_dim,
                                  vocab_size=self._vocab_size,
                                  hparams=self._hparams)
        seq_len = torch.tensor([self._max_time] * self._batch_size)

        def _decode_and_check(**kwargs):
            # Teacher-forced decode over the fixture inputs.
            out, state, lengths = decoder(inputs=self._inputs,
                                          sequence_length=seq_len,
                                          **kwargs)
            self._test_outputs(decoder, out, state, lengths)

        # Helper built from the default HParams.
        _decode_and_check(helper=decoder.create_helper())

        # Helper built from an explicit decoding strategy.
        _decode_and_check(helper=decoder.create_helper(
            decoding_strategy='train_greedy'))

        # No helper: the decoder picks the training helper implicitly.
        _decode_and_check()

        # Eval helper through forward args (inference mode).
        out, state, lengths = decoder(
            embedding=self._embedder,
            start_tokens=torch.tensor([1] * self._batch_size),
            end_token=2,
            infer_mode=True)
        self._test_outputs(decoder, out, state, lengths, test_mode=True)
Example #9
0
    def test_decode_infer(self):
        """Tests decoding in inference mode.

        Builds the hparams-configured inference helper and checks output
        shapes and the decoder's trainable-variable count.
        """
        output_layer = tf.layers.Dense(self._vocab_size)
        decoder = BasicRNNDecoder(vocab_size=self._vocab_size,
                                  output_layer=output_layer)

        # NOTE(review): kwargs come from helper_train, not helper_infer —
        # possibly intentional (shared kwargs), possibly a copy-paste slip;
        # confirm against the hparams layout.
        helper_infer = get_helper(
            decoder.hparams.helper_infer.type,
            embedding=self._embedding,
            start_tokens=[self._vocab_size - 2] * self._batch_size,
            end_token=self._vocab_size - 1,
            **decoder.hparams.helper_train.kwargs.todict())

        outputs, final_state, sequence_lengths = decoder(helper=helper_infer)

        # 4 trainable variables expected: cell-kernel, cell-bias,
        # fc-layer-weights, fc-layer-bias. (The original comment also
        # listed "embedding" while asserting 4 — presumably the embedding
        # is not owned by the decoder; TODO confirm.)
        self.assertEqual(len(decoder.trainable_variables), 4)

        cell_dim = decoder.hparams.rnn_cell.kwargs.num_units
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            outputs_, final_state_, sequence_lengths_ = sess.run(
                [outputs, final_state, sequence_lengths],
                feed_dict={
                    context.global_mode(): tf.estimator.ModeKeys.PREDICT
                })
            self.assertIsInstance(outputs_, BasicRNNDecoderOutput)
            max_length = max(sequence_lengths_)
            self.assertEqual(outputs_.logits.shape,
                             (self._batch_size, max_length, self._vocab_size))
            # Greedy inference: sample_id holds token indices, [batch, time].
            self.assertEqual(outputs_.sample_id.shape,
                             (self._batch_size, max_length))
            self.assertEqual(final_state_[0].shape,
                             (self._batch_size, cell_dim))
Example #10
0
    def test_decode_train_with_tf(self):
        """Compares decoding results with TF built-in decoder.

        Runs the same teacher-forced decode through this decoder and
        through ``tf.contrib.seq2seq`` with a shared cell/output layer,
        then asserts the numeric results agree.
        """
        # Placeholders so both decoders are fed identical numpy data.
        _inputs_placeholder = tf.placeholder(
            tf.int32, [self._batch_size, self._max_time], name="inputs")
        _embedding_placeholder = tf.placeholder(
            tf.float32, [self._vocab_size, self._emb_dim], name="emb")
        inputs = tf.nn.embedding_lookup(_embedding_placeholder,
                                        _inputs_placeholder)

        output_layer = tf.layers.Dense(self._vocab_size)
        decoder = BasicRNNDecoder(vocab_size=self._vocab_size,
                                  output_layer=output_layer)

        # Texar decode path: training helper from hparams.
        helper_train = get_helper(
            decoder.hparams.helper_train.type,
            inputs=inputs,
            sequence_length=[self._max_time] * self._batch_size,
            **decoder.hparams.helper_train.kwargs.todict())

        outputs, final_state, sequence_lengths = decoder(helper=helper_train)

        # TF reference path: same cell and output layer, so any difference
        # would come from the decoding loop itself.
        tf_helper = tf.contrib.seq2seq.TrainingHelper(
            inputs, [self._max_time] * self._batch_size)

        tf_decoder = tf.contrib.seq2seq.BasicDecoder(decoder.cell,
                                                     tf_helper,
                                                     decoder.cell.zero_state(
                                                         self._batch_size,
                                                         tf.float32),
                                                     output_layer=output_layer)

        tf_outputs, tf_final_state, tf_sequence_lengths = \
            tf.contrib.seq2seq.dynamic_decode(tf_decoder)

        cell_dim = decoder.hparams.rnn_cell.kwargs.num_units
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            # One shared random dataset for both runs.
            inputs_ = np.random.randint(self._vocab_size,
                                        size=(self._batch_size,
                                              self._max_time),
                                        dtype=np.int32)
            embedding_ = np.random.randn(self._vocab_size, self._emb_dim)

            outputs_, final_state_, sequence_lengths_ = sess.run(
                [outputs, final_state, sequence_lengths],
                feed_dict={
                    context.global_mode(): tf.estimator.ModeKeys.TRAIN,
                    _inputs_placeholder: inputs_,
                    _embedding_placeholder: embedding_
                })
            self.assertEqual(final_state_[0].shape,
                             (self._batch_size, cell_dim))

            tf_outputs_, tf_final_state_, tf_sequence_lengths_ = sess.run(
                [tf_outputs, tf_final_state, tf_sequence_lengths],
                feed_dict={
                    context.global_mode(): tf.estimator.ModeKeys.TRAIN,
                    _inputs_placeholder: inputs_,
                    _embedding_placeholder: embedding_
                })

            # Exact (bitwise) agreement is required, not approximate.
            np.testing.assert_array_equal(outputs_.logits,
                                          tf_outputs_.rnn_output)
            np.testing.assert_array_equal(outputs_.sample_id,
                                          tf_outputs_.sample_id)
            np.testing.assert_array_equal(final_state_.c, tf_final_state_.c)
            np.testing.assert_array_equal(final_state_.h, tf_final_state_.h)
            np.testing.assert_array_equal(sequence_lengths_,
                                          tf_sequence_lengths_)