Пример #1
0
    def test_output_layer(self):
        """Tests the different kinds of `output_layer` argument.

        The decoder is constructed with `None` (together with an explicit
        vocabulary size), with a callable, and with a tensor. For the tensor
        case the decoder's `vocab_size` is compared with the fixture's.
        """
        # No output layer; the vocabulary size is given explicitly.
        no_layer_decoder = BasicRNNDecoder(vocab_size=self._vocab_size,
                                           output_layer=None)
        self.assertIsInstance(no_layer_decoder, BasicRNNDecoder)

        # A plain callable as the output layer.
        callable_decoder = BasicRNNDecoder(output_layer=tf.identity)
        self.assertIsInstance(callable_decoder, BasicRNNDecoder)

        # A `[emb_dim, vocab_size]` tensor as the output layer.
        weight_shape = [self._emb_dim, self._vocab_size]
        weights = tf.random_uniform(weight_shape, maxval=1, dtype=tf.float32)
        tensor_decoder = BasicRNNDecoder(output_layer=weights)
        self.assertIsInstance(tensor_decoder, BasicRNNDecoder)
        self.assertEqual(tensor_decoder.vocab_size, self._vocab_size)
Пример #2
0
    def test_seq_pg_agent(self):
        """Tests sampling from and feeding rewards to a `SeqPGAgent`.
        """
        batch_size = self._batch_size

        decoder = BasicRNNDecoder(vocab_size=self._vocab_size)
        outputs, _, sequence_length = decoder(
            decoding_strategy="infer_greedy",
            max_decoding_length=10,
            embedding=self._embedding,
            start_tokens=[1] * batch_size,
            end_token=2)

        agent = SeqPGAgent(outputs.sample_id,
                           outputs.logits,
                           sequence_length,
                           decoder.trainable_variables)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())

            agent.sess = sess

            train_feed = {context.global_mode(): tf.estimator.ModeKeys.TRAIN}
            for _ in range(2):
                sampled = agent.get_samples(feed_dict=train_feed)
                self.assertEqual(sampled['samples'].shape[0], batch_size)

                # Observe the same rewards once with the default arguments
                # and once with `train_policy=False`; both losses are scalars.
                trained_loss = agent.observe([1.] * batch_size)
                untrained_loss = agent.observe([1.] * batch_size,
                                               train_policy=False)
                for loss in (trained_loss, untrained_loss):
                    self.assertEqual(loss.shape, ())
Пример #3
0
        def _test_fn(helper):
            """Checks `helper` on its own and inside a `BasicRNNDecoder`.

            First calls `helper.next_inputs` directly and checks the shapes of
            the sample ids and of the next inputs, then decodes with the
            helper and checks the shapes of the fetched results.
            """
            batch_size = self._batch_size
            vocab_size = self._vocab_size

            dummy_outputs = tf.ones([batch_size, vocab_size])  # Not used
            dummy_sample_ids = tf.ones([batch_size, vocab_size])
            _, next_inputs, _ = helper.next_inputs(
                time=1,
                outputs=dummy_outputs,
                state=None,  # Not used
                sample_ids=dummy_sample_ids)

            self.assertEqual(helper.sample_ids_shape,
                             tf.TensorShape(vocab_size))
            self.assertEqual(next_inputs.get_shape(),
                             tf.TensorShape([batch_size, self._emb_dim]))

            # Test in an RNN decoder
            decoder = BasicRNNDecoder(
                vocab_size=vocab_size,
                output_layer=tf.layers.Dense(vocab_size))
            outputs, final_state, sequence_lengths = decoder(
                helper=helper, max_decoding_length=self._max_seq_length)

            cell_dim = decoder.hparams.rnn_cell.kwargs.num_units
            with self.test_session() as sess:
                sess.run(tf.global_variables_initializer())
                outputs_, final_state_, sequence_lengths_ = sess.run(
                    [outputs, final_state, sequence_lengths])

                # Logits and sample ids are both compared against
                # `[batch_size, max_length, vocab_size]`.
                max_length = max(sequence_lengths_)
                expected_shape = (batch_size, max_length, vocab_size)
                self.assertEqual(outputs_.logits.shape, expected_shape)
                self.assertEqual(outputs_.sample_id.shape, expected_shape)
                self.assertEqual(final_state_[0].shape,
                                 (batch_size, cell_dim))
Пример #4
0
    def test_decode_train(self):
        """Tests decoding in training mode.

        Covers four ways of invoking the decoder: with an explicit training
        helper, with `inputs`/`sequence_length` only, with the same arguments
        plus `decoding_strategy=None`, and with `embedding`/`start_tokens`/
        `end_token` in `EVAL` mode.
        """
        def _full_lengths():
            # A fresh list per call: every sequence spans all time steps.
            return [self._max_time] * self._batch_size

        decoder = BasicRNNDecoder(
            vocab_size=self._vocab_size,
            output_layer=tf.layers.Dense(self._vocab_size))

        # 1. Explicit helper built from the decoder's `helper_train` hparams.
        train_hparams = decoder.hparams.helper_train
        helper_train = get_helper(
            train_hparams.type,
            inputs=self._inputs,
            sequence_length=_full_lengths(),
            **train_hparams.kwargs.todict())
        outputs, final_state, sequence_lengths = decoder(helper=helper_train)
        self._test_outputs(decoder, outputs, final_state, sequence_lengths)

        # 2. and 3. Inputs and lengths passed directly, first without and
        # then with an explicit `decoding_strategy=None`.
        for extra_kwargs in ({}, {"decoding_strategy": None}):
            outputs, final_state, sequence_lengths = decoder(
                inputs=self._inputs,
                sequence_length=_full_lengths(),
                **extra_kwargs)
            self._test_outputs(decoder, outputs, final_state,
                               sequence_lengths)

        # 4. No inputs: decode from start tokens in EVAL mode.
        outputs, final_state, sequence_lengths = decoder(
            decoding_strategy=None,
            embedding=self._embedding,
            start_tokens=[1] * self._batch_size,
            end_token=2,
            mode=tf.estimator.ModeKeys.EVAL)
        self._test_outputs(
            decoder, outputs, final_state, sequence_lengths, test_mode=True)
Пример #5
0
    def test_decode_infer(self):
        """Tests decoding in inference mode.

        Builds the inference helper from the decoder's `helper_infer`
        hyperparameters, decodes with it, and checks the number of trainable
        variables and the shapes of the fetched results.
        """
        output_layer = tf.layers.Dense(self._vocab_size)
        decoder = BasicRNNDecoder(vocab_size=self._vocab_size,
                                  output_layer=output_layer)

        # Both the helper type and its kwargs come from `helper_infer`, so
        # that a non-empty `helper_train.kwargs` cannot leak into the
        # inference helper.
        helper_infer = get_helper(
            decoder.hparams.helper_infer.type,
            embedding=self._embedding,
            start_tokens=[self._vocab_size - 2] * self._batch_size,
            end_token=self._vocab_size - 1,
            **decoder.hparams.helper_infer.kwargs.todict())

        outputs, final_state, sequence_lengths = decoder(helper=helper_infer)

        # 4 trainable variables: cell-kernel, cell-bias,
        # fc-layer-weights, fc-layer-bias.
        # NOTE(review): the embedding is passed in by the test rather than
        # created by the decoder, so presumably it is not counted -- confirm.
        self.assertEqual(len(decoder.trainable_variables), 4)

        cell_dim = decoder.hparams.rnn_cell.kwargs.num_units
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            outputs_, final_state_, sequence_lengths_ = sess.run(
                [outputs, final_state, sequence_lengths],
                feed_dict={
                    context.global_mode(): tf.estimator.ModeKeys.PREDICT
                })
            self.assertIsInstance(outputs_, BasicRNNDecoderOutput)
            max_length = max(sequence_lengths_)
            self.assertEqual(outputs_.logits.shape,
                             (self._batch_size, max_length, self._vocab_size))
            self.assertEqual(outputs_.sample_id.shape,
                             (self._batch_size, max_length))
            self.assertEqual(final_state_[0].shape,
                             (self._batch_size, cell_dim))
Пример #6
0
    def test_decode_train_with_tf(self):
        """Compares decoding results with TF built-in decoder.

        The same cell, output layer and inputs are decoded once with
        `BasicRNNDecoder` and once with `tf.contrib.seq2seq.BasicDecoder`;
        the fetched logits, sample ids, final states and sequence lengths
        are required to be equal.
        """
        batch_size, max_time = self._batch_size, self._max_time

        _inputs_placeholder = tf.placeholder(
            tf.int32, [batch_size, max_time], name="inputs")
        _embedding_placeholder = tf.placeholder(
            tf.float32, [self._vocab_size, self._emb_dim], name="emb")
        inputs = tf.nn.embedding_lookup(_embedding_placeholder,
                                        _inputs_placeholder)

        output_layer = tf.layers.Dense(self._vocab_size)
        decoder = BasicRNNDecoder(vocab_size=self._vocab_size,
                                  output_layer=output_layer)

        # Decoder under test, driven by its own training helper.
        train_hparams = decoder.hparams.helper_train
        helper_train = get_helper(
            train_hparams.type,
            inputs=inputs,
            sequence_length=[max_time] * batch_size,
            **train_hparams.kwargs.todict())
        outputs, final_state, sequence_lengths = decoder(helper=helper_train)

        # Reference TF decoder sharing the cell and the output layer.
        tf_helper = tf.contrib.seq2seq.TrainingHelper(
            inputs, [max_time] * batch_size)
        initial_state = decoder.cell.zero_state(batch_size, tf.float32)
        tf_decoder = tf.contrib.seq2seq.BasicDecoder(
            decoder.cell, tf_helper, initial_state,
            output_layer=output_layer)
        tf_outputs, tf_final_state, tf_sequence_lengths = \
            tf.contrib.seq2seq.dynamic_decode(tf_decoder)

        cell_dim = decoder.hparams.rnn_cell.kwargs.num_units
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            inputs_ = np.random.randint(self._vocab_size,
                                        size=(batch_size, max_time),
                                        dtype=np.int32)
            embedding_ = np.random.randn(self._vocab_size, self._emb_dim)

            def _run_in_train_mode(fetches):
                # Both decoders are evaluated on identical feeds.
                return sess.run(
                    fetches,
                    feed_dict={
                        context.global_mode(): tf.estimator.ModeKeys.TRAIN,
                        _inputs_placeholder: inputs_,
                        _embedding_placeholder: embedding_
                    })

            outputs_, final_state_, sequence_lengths_ = _run_in_train_mode(
                [outputs, final_state, sequence_lengths])
            self.assertEqual(final_state_[0].shape, (batch_size, cell_dim))

            tf_outputs_, tf_final_state_, tf_sequence_lengths_ = \
                _run_in_train_mode(
                    [tf_outputs, tf_final_state, tf_sequence_lengths])

            compared_pairs = [
                (outputs_.logits, tf_outputs_.rnn_output),
                (outputs_.sample_id, tf_outputs_.sample_id),
                (final_state_.c, tf_final_state_.c),
                (final_state_.h, tf_final_state_.h),
                (sequence_lengths_, tf_sequence_lengths_),
            ]
            for actual, expected in compared_pairs:
                np.testing.assert_array_equal(actual, expected)