# ----- Example 1 -----
    def testStepWithInferenceHelperMultilabel(self):
        """Runs one decoder step with an InferenceHelper that samples
        independent Bernoulli labels (multilabel) from the cell logits.

        Checks the decoder's declared output sizes/dtypes, the static shapes
        of outputs and LSTM state, and that the helper's end_fn and
        next_inputs_fn were applied to the sampled ids.
        """
        batch_size = 5
        vocabulary_size = 7
        # Cell output feeds the sampler directly, so its depth equals the
        # vocabulary size.
        cell_depth = vocabulary_size
        start_token = 0
        end_token = 6

        # Initial decoder inputs: a one-hot encoding of the start token for
        # every batch entry.
        start_inputs = tf.one_hot(
            np.ones(batch_size) * start_token, vocabulary_size)

        # The sample function samples independent bernoullis from the logits.
        sample_fn = (
            lambda x: seq2seq.bernoulli_sample(logits=x, dtype=tf.bool))
        # The next inputs are a one-hot encoding of the sampled labels.
        next_inputs_fn = tf.to_float
        # A sequence is "finished" when the label at position end_token was
        # sampled True.
        end_fn = lambda sample_ids: sample_ids[:, end_token]

        with self.session(use_gpu=True) as sess:
            with tf.variable_scope("testStepWithInferenceHelper",
                                   initializer=tf.constant_initializer(0.01)):
                cell = tf.nn.rnn_cell.LSTMCell(vocabulary_size)
                helper = seq2seq.InferenceHelper(sample_fn,
                                                 sample_shape=[cell_depth],
                                                 sample_dtype=tf.bool,
                                                 start_inputs=start_inputs,
                                                 end_fn=end_fn,
                                                 next_inputs_fn=next_inputs_fn)
                my_decoder = seq2seq.BasicDecoder(
                    cell=cell,
                    helper=helper,
                    initial_state=cell.zero_state(dtype=tf.float32,
                                                  batch_size=batch_size))
                output_size = my_decoder.output_size
                output_dtype = my_decoder.output_dtype
                # Multilabel sampling produces one bool per vocabulary entry,
                # so sample_id has depth cell_depth and dtype bool.
                self.assertEqual(
                    seq2seq.BasicDecoderOutput(cell_depth, cell_depth),
                    output_size)
                self.assertEqual(
                    seq2seq.BasicDecoderOutput(tf.float32, tf.bool),
                    output_dtype)

                (first_finished, first_inputs,
                 first_state) = my_decoder.initialize()
                (step_outputs, step_state, step_next_inputs,
                 step_finished) = my_decoder.step(tf.constant(0), first_inputs,
                                                  first_state)
                batch_size_t = my_decoder.batch_size

                # Static-shape checks: (rnn_output, sample_id) and both LSTM
                # state tensors (c, h) are [batch_size, cell_depth].
                self.assertIsInstance(first_state,
                                      tf.nn.rnn_cell.LSTMStateTuple)
                self.assertIsInstance(step_state,
                                      tf.nn.rnn_cell.LSTMStateTuple)
                self.assertIsInstance(step_outputs, seq2seq.BasicDecoderOutput)
                self.assertEqual((batch_size, cell_depth),
                                 step_outputs[0].get_shape())
                self.assertEqual((batch_size, cell_depth),
                                 step_outputs[1].get_shape())
                self.assertEqual((batch_size, cell_depth),
                                 first_state[0].get_shape())
                self.assertEqual((batch_size, cell_depth),
                                 first_state[1].get_shape())
                self.assertEqual((batch_size, cell_depth),
                                 step_state[0].get_shape())
                self.assertEqual((batch_size, cell_depth),
                                 step_state[1].get_shape())

                sess.run(tf.global_variables_initializer())
                # Fetch everything in one run so all values come from the same
                # (random) sample.
                sess_results = sess.run({
                    "batch_size": batch_size_t,
                    "first_finished": first_finished,
                    "first_inputs": first_inputs,
                    "first_state": first_state,
                    "step_outputs": step_outputs,
                    "step_state": step_state,
                    "step_next_inputs": step_next_inputs,
                    "step_finished": step_finished
                })

                sample_ids = sess_results["step_outputs"].sample_id
                self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
                # end_fn: finished iff the end_token label was sampled.
                expected_step_finished = sample_ids[:, end_token]
                # next_inputs_fn (tf.to_float): the bool samples cast to
                # float32 become the next decoder inputs.
                expected_step_next_inputs = sample_ids.astype(np.float32)
                self.assertAllEqual(expected_step_finished,
                                    sess_results["step_finished"])
                self.assertAllEqual(expected_step_next_inputs,
                                    sess_results["step_next_inputs"])
# ----- Example 2 -----
    def _testStepWithScheduledOutputTrainingHelper(  # pylint:disable=invalid-name
            self, sampling_probability, use_next_inputs_fn,
            use_auxiliary_inputs):
        """Runs one decoder step with a ScheduledOutputTrainingHelper.

        Args:
            sampling_probability: float probability of feeding back the
                decoder output (instead of the ground-truth input) at each
                step.
            use_next_inputs_fn: if True, pass a deterministic next_inputs_fn
                (one-hot of the argmax) so sampled outputs are transformed
                before being fed back.
            use_auxiliary_inputs: if True, concatenate extra per-timestep
                auxiliary inputs onto the next decoder inputs.
        """
        # Entry 4 has length 0, so it is finished at initialize(); entry 3
        # (length 1) becomes finished after the first step.
        sequence_length = [3, 4, 3, 1, 0]
        batch_size = 5
        max_time = 8
        input_depth = 7
        # Output must match input depth so sampled outputs can replace the
        # ground-truth inputs.
        cell_depth = input_depth
        if use_auxiliary_inputs:
            auxiliary_input_depth = 4
            auxiliary_inputs = np.random.randn(
                batch_size, max_time, auxiliary_input_depth).astype(np.float32)
        else:
            auxiliary_inputs = None

        with self.session(use_gpu=True) as sess:
            inputs = np.random.randn(batch_size, max_time,
                                     input_depth).astype(np.float32)
            cell = tf.nn.rnn_cell.LSTMCell(cell_depth)
            sampling_probability = tf.constant(sampling_probability)

            if use_next_inputs_fn:

                def next_inputs_fn(outputs):
                    # Use deterministic function for test.
                    samples = tf.argmax(outputs, axis=1)
                    return tf.one_hot(samples, cell_depth, dtype=tf.float32)
            else:
                next_inputs_fn = None

            helper = seq2seq.ScheduledOutputTrainingHelper(
                inputs=inputs,
                sequence_length=sequence_length,
                sampling_probability=sampling_probability,
                time_major=False,
                next_inputs_fn=next_inputs_fn,
                auxiliary_inputs=auxiliary_inputs)

            my_decoder = seq2seq.BasicDecoder(cell=cell,
                                              helper=helper,
                                              initial_state=cell.zero_state(
                                                  dtype=tf.float32,
                                                  batch_size=batch_size))

            output_size = my_decoder.output_size
            output_dtype = my_decoder.output_dtype
            self.assertEqual(
                seq2seq.BasicDecoderOutput(cell_depth, tf.TensorShape([])),
                output_size)
            self.assertEqual(seq2seq.BasicDecoderOutput(tf.float32, tf.int32),
                             output_dtype)

            (first_finished, first_inputs,
             first_state) = my_decoder.initialize()
            (step_outputs, step_state, step_next_inputs,
             step_finished) = my_decoder.step(tf.constant(0), first_inputs,
                                              first_state)

            if use_next_inputs_fn:
                # Recompute the transformed outputs so we can compare them
                # against what the helper actually fed forward.
                output_after_next_inputs_fn = next_inputs_fn(
                    step_outputs.rnn_output)

            batch_size_t = my_decoder.batch_size

            # Static-shape checks for outputs and both LSTM state tensors.
            self.assertIsInstance(first_state, tf.nn.rnn_cell.LSTMStateTuple)
            self.assertIsInstance(step_state, tf.nn.rnn_cell.LSTMStateTuple)
            self.assertIsInstance(step_outputs, seq2seq.BasicDecoderOutput)
            self.assertEqual((batch_size, cell_depth),
                             step_outputs[0].get_shape())
            self.assertEqual((batch_size, ), step_outputs[1].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             first_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             first_state[1].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             step_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             step_state[1].get_shape())

            sess.run(tf.global_variables_initializer())

            # Fetch everything in one run so all values reflect the same
            # random sampling decision.
            fetches = {
                "batch_size": batch_size_t,
                "first_finished": first_finished,
                "first_inputs": first_inputs,
                "first_state": first_state,
                "step_outputs": step_outputs,
                "step_state": step_state,
                "step_next_inputs": step_next_inputs,
                "step_finished": step_finished
            }
            if use_next_inputs_fn:
                fetches[
                    "output_after_next_inputs_fn"] = output_after_next_inputs_fn

            sess_results = sess.run(fetches)

            # Finished flags follow sequence_length = [3, 4, 3, 1, 0].
            self.assertAllEqual([False, False, False, False, True],
                                sess_results["first_finished"])
            self.assertAllEqual([False, False, False, True, True],
                                sess_results["step_finished"])

            sample_ids = sess_results["step_outputs"].sample_id
            self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
            # NOTE(review): sample_ids here appear to be 0/1 flags marking
            # which batch entries had their output sampled (fed back) —
            # confirm against ScheduledOutputTrainingHelper.
            batch_where_not_sampling = np.where(np.logical_not(sample_ids))
            batch_where_sampling = np.where(sample_ids)

            # Auxiliary inputs for time step 1 (the step being fed next), or
            # an empty [batch_size, 0] array so concatenation is a no-op.
            auxiliary_inputs_to_concat = (
                auxiliary_inputs[:, 1] if use_auxiliary_inputs else np.array(
                    []).reshape(batch_size, 0).astype(np.float32))

            # Sampled entries are fed the (possibly transformed) decoder
            # output, concatenated with the auxiliary inputs.
            expected_next_sampling_inputs = np.concatenate(
                (sess_results["output_after_next_inputs_fn"]
                 [batch_where_sampling] if use_next_inputs_fn else
                 sess_results["step_outputs"].rnn_output[batch_where_sampling],
                 auxiliary_inputs_to_concat[batch_where_sampling]),
                axis=-1)
            self.assertAllClose(
                sess_results["step_next_inputs"][batch_where_sampling],
                expected_next_sampling_inputs)

            # Non-sampled entries are fed the ground-truth inputs at time 1,
            # concatenated with the auxiliary inputs.
            self.assertAllClose(
                sess_results["step_next_inputs"][batch_where_not_sampling],
                np.concatenate(
                    (np.squeeze(inputs[batch_where_not_sampling, 1], axis=0),
                     auxiliary_inputs_to_concat[batch_where_not_sampling]),
                    axis=-1))
# ----- Example 3 -----
    def testStepWithInferenceHelperCategorical(self):
        """Runs one decoder step with an InferenceHelper that samples a
        single categorical id from the cell logits.

        Checks the decoder's declared output sizes/dtypes, the static shapes
        of outputs and LSTM state, and that the helper's end_fn and
        next_inputs_fn were applied to the sampled ids.
        """
        batch_size = 5
        vocabulary_size = 7
        # Cell output feeds the sampler directly, so its depth equals the
        # vocabulary size.
        cell_depth = vocabulary_size
        start_token = 0
        end_token = 6

        # Initial decoder inputs: a one-hot encoding of the start token for
        # every batch entry.
        start_inputs = tf.one_hot(
            np.ones(batch_size) * start_token, vocabulary_size)

        # The sample function samples categorically from the logits.
        sample_fn = lambda x: seq2seq.categorical_sample(logits=x)
        # The next inputs are a one-hot encoding of the sampled labels.
        next_inputs_fn = (
            lambda x: tf.one_hot(x, vocabulary_size, dtype=tf.float32))
        # A sequence is "finished" when the sampled id equals end_token.
        end_fn = lambda sample_ids: tf.equal(sample_ids, end_token)

        with self.session(use_gpu=True) as sess:
            with tf.variable_scope("testStepWithInferenceHelper",
                                   initializer=tf.constant_initializer(0.01)):
                cell = tf.nn.rnn_cell.LSTMCell(vocabulary_size)
                helper = seq2seq.InferenceHelper(sample_fn,
                                                 sample_shape=(),
                                                 sample_dtype=tf.int32,
                                                 start_inputs=start_inputs,
                                                 end_fn=end_fn,
                                                 next_inputs_fn=next_inputs_fn)
                my_decoder = seq2seq.BasicDecoder(
                    cell=cell,
                    helper=helper,
                    initial_state=cell.zero_state(dtype=tf.float32,
                                                  batch_size=batch_size))
                output_size = my_decoder.output_size
                output_dtype = my_decoder.output_dtype
                # Categorical sampling yields one scalar int32 id per batch
                # entry.
                self.assertEqual(
                    seq2seq.BasicDecoderOutput(cell_depth, tf.TensorShape([])),
                    output_size)
                self.assertEqual(
                    seq2seq.BasicDecoderOutput(tf.float32, tf.int32),
                    output_dtype)

                (first_finished, first_inputs,
                 first_state) = my_decoder.initialize()
                (step_outputs, step_state, step_next_inputs,
                 step_finished) = my_decoder.step(tf.constant(0), first_inputs,
                                                  first_state)
                batch_size_t = my_decoder.batch_size

                # Static-shape checks for outputs and both LSTM state
                # tensors (c, h).
                self.assertIsInstance(first_state,
                                      tf.nn.rnn_cell.LSTMStateTuple)
                self.assertIsInstance(step_state,
                                      tf.nn.rnn_cell.LSTMStateTuple)
                self.assertIsInstance(step_outputs, seq2seq.BasicDecoderOutput)
                self.assertEqual((batch_size, cell_depth),
                                 step_outputs[0].get_shape())
                self.assertEqual((batch_size, ), step_outputs[1].get_shape())
                self.assertEqual((batch_size, cell_depth),
                                 first_state[0].get_shape())
                self.assertEqual((batch_size, cell_depth),
                                 first_state[1].get_shape())
                self.assertEqual((batch_size, cell_depth),
                                 step_state[0].get_shape())
                self.assertEqual((batch_size, cell_depth),
                                 step_state[1].get_shape())

                sess.run(tf.global_variables_initializer())
                # Fetch everything in one run so all values come from the
                # same (random) sample.
                sess_results = sess.run({
                    "batch_size": batch_size_t,
                    "first_finished": first_finished,
                    "first_inputs": first_inputs,
                    "first_state": first_state,
                    "step_outputs": step_outputs,
                    "step_state": step_state,
                    "step_next_inputs": step_next_inputs,
                    "step_finished": step_finished
                })

                sample_ids = sess_results["step_outputs"].sample_id
                self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
                # end_fn: finished iff the sampled id equals end_token.
                expected_step_finished = (sample_ids == end_token)
                # next_inputs_fn: build the expected one-hot rows from the
                # sampled ids.
                expected_step_next_inputs = np.zeros(
                    (batch_size, vocabulary_size))
                expected_step_next_inputs[np.arange(batch_size),
                                          sample_ids] = 1.0
                self.assertAllEqual(expected_step_finished,
                                    sess_results["step_finished"])
                self.assertAllEqual(expected_step_next_inputs,
                                    sess_results["step_next_inputs"])
# ----- Example 4 -----
    def _testStepWithTrainingHelper(self, use_output_layer):  # pylint:disable=invalid-name
        """Runs one decoder step with a TrainingHelper.

        Args:
            use_output_layer: if True, attach a bias-free Dense projection so
                the decoder output depth differs from the cell depth.
        """
        # Entry 4 has length 0, so it is finished at initialize(); entry 3
        # (length 1) becomes finished after the first step.
        sequence_length = [3, 4, 3, 1, 0]
        batch_size = 5
        max_time = 8
        input_depth = 7
        cell_depth = 10
        output_layer_depth = 3

        with self.session(use_gpu=True) as sess:
            inputs = np.random.randn(batch_size, max_time,
                                     input_depth).astype(np.float32)
            cell = tf.nn.rnn_cell.LSTMCell(cell_depth)
            helper = seq2seq.TrainingHelper(inputs,
                                            sequence_length,
                                            time_major=False)
            if use_output_layer:
                # use_bias=False so the layer creates exactly one variable
                # (the kernel), which the assertion below relies on.
                output_layer = tf.layers.Dense(output_layer_depth,
                                               use_bias=False)
                expected_output_depth = output_layer_depth
            else:
                output_layer = None
                expected_output_depth = cell_depth
            my_decoder = seq2seq.BasicDecoder(cell=cell,
                                              helper=helper,
                                              initial_state=cell.zero_state(
                                                  dtype=tf.float32,
                                                  batch_size=batch_size),
                                              output_layer=output_layer)
            output_size = my_decoder.output_size
            output_dtype = my_decoder.output_dtype
            self.assertEqual(
                seq2seq.BasicDecoderOutput(expected_output_depth,
                                           tf.TensorShape([])), output_size)
            self.assertEqual(seq2seq.BasicDecoderOutput(tf.float32, tf.int32),
                             output_dtype)

            (first_finished, first_inputs,
             first_state) = my_decoder.initialize()
            (step_outputs, step_state, step_next_inputs,
             step_finished) = my_decoder.step(tf.constant(0), first_inputs,
                                              first_state)
            batch_size_t = my_decoder.batch_size

            # Static-shape checks: rnn_output is projected (or not) to
            # expected_output_depth; LSTM state stays at cell_depth.
            self.assertIsInstance(first_state, tf.nn.rnn_cell.LSTMStateTuple)
            self.assertIsInstance(step_state, tf.nn.rnn_cell.LSTMStateTuple)
            self.assertIsInstance(step_outputs, seq2seq.BasicDecoderOutput)
            self.assertEqual((batch_size, expected_output_depth),
                             step_outputs[0].get_shape())
            self.assertEqual((batch_size, ), step_outputs[1].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             first_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             first_state[1].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             step_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             step_state[1].get_shape())

            if use_output_layer:
                # The output layer was accessed
                self.assertEqual(len(output_layer.variables), 1)

            sess.run(tf.global_variables_initializer())
            # Fetch everything in one run for a consistent snapshot.
            sess_results = sess.run({
                "batch_size": batch_size_t,
                "first_finished": first_finished,
                "first_inputs": first_inputs,
                "first_state": first_state,
                "step_outputs": step_outputs,
                "step_state": step_state,
                "step_next_inputs": step_next_inputs,
                "step_finished": step_finished
            })

            # Finished flags follow sequence_length = [3, 4, 3, 1, 0].
            self.assertAllEqual([False, False, False, False, True],
                                sess_results["first_finished"])
            self.assertAllEqual([False, False, False, True, True],
                                sess_results["step_finished"])
            self.assertEqual(output_dtype.sample_id,
                             sess_results["step_outputs"].sample_id.dtype)
            # TrainingHelper's sample_ids are the argmax over the (possibly
            # projected) rnn_output.
            self.assertAllEqual(
                np.argmax(sess_results["step_outputs"].rnn_output, -1),
                sess_results["step_outputs"].sample_id)