Ejemplo n.º 1
0
def test_step_with_scheduled_output_training_helper(sampling_probability,
                                                    use_next_inputs_fn,
                                                    use_auxiliary_inputs):
    """Check a single decoding step with ``ScheduledOutputTrainingSampler``.

    Wraps an ``LSTMCell`` (depth equal to the input depth) in a
    ``BasicDecoder``, runs ``initialize`` followed by one ``step`` at time 0,
    and verifies output sizes/dtypes, state shapes, the ``finished`` flags and
    the next inputs produced by the sampler.

    For each batch entry the step's ``sample_id`` selects the next input:

    * non-zero ``sample_id``: the cell output (passed through
      ``next_inputs_fn`` when one is given), concatenated with the auxiliary
      inputs at time 1;
    * zero ``sample_id``: the ground-truth ``inputs`` at time 1, concatenated
      with the same auxiliary inputs.

    Args:
        sampling_probability: Sampling probability handed to the sampler.
        use_next_inputs_fn: If true, feed back a deterministic argmax/one-hot
            transform of the cell output instead of the raw output.
        use_auxiliary_inputs: If true, pass random auxiliary inputs that must
            be concatenated onto every next input.
    """
    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    # Cell outputs and ground-truth inputs end up in the same next-input
    # tensor (checked below), so their depths must agree.
    cell_depth = input_depth
    if use_auxiliary_inputs:
        auxiliary_input_depth = 4
        auxiliary_inputs = np.random.randn(
            batch_size, max_time, auxiliary_input_depth).astype(np.float32)
    else:
        auxiliary_inputs = None

    inputs = np.random.randn(batch_size, max_time,
                             input_depth).astype(np.float32)
    input_t = tf.constant(inputs)
    cell = tf.keras.layers.LSTMCell(cell_depth)
    sampling_probability = tf.constant(sampling_probability)

    if use_next_inputs_fn:

        def next_inputs_fn(outputs):
            # Use deterministic function for test.
            samples = tf.argmax(outputs, axis=1)
            return tf.one_hot(samples, cell_depth, dtype=tf.float32)

    else:
        next_inputs_fn = None

    sampler = sampler_py.ScheduledOutputTrainingSampler(
        sampling_probability=sampling_probability,
        time_major=False,
        next_inputs_fn=next_inputs_fn,
    )
    initial_state = cell.get_initial_state(batch_size=batch_size,
                                           dtype=tf.float32)

    my_decoder = basic_decoder.BasicDecoder(cell=cell, sampler=sampler)

    first_finished, first_inputs, first_state = my_decoder.initialize(
        input_t,
        sequence_length=sequence_length,
        initial_state=initial_state,
        auxiliary_inputs=auxiliary_inputs,
    )
    output_size = my_decoder.output_size
    output_dtype = my_decoder.output_dtype
    assert (basic_decoder.BasicDecoderOutput(cell_depth, tf.TensorShape(
        [])) == output_size)
    assert basic_decoder.BasicDecoderOutput(tf.float32,
                                            tf.int32) == output_dtype

    step_outputs, step_state, step_next_inputs, step_finished = (
        my_decoder.step(tf.constant(0), first_inputs, first_state))

    # LSTM state is the (h, c) pair, each (batch_size, cell_depth).
    assert len(first_state) == 2
    assert len(step_state) == 2
    assert isinstance(step_outputs, basic_decoder.BasicDecoderOutput)
    assert (batch_size, cell_depth) == step_outputs[0].shape
    assert (batch_size, ) == step_outputs[1].shape
    assert (batch_size, cell_depth) == first_state[0].shape
    assert (batch_size, cell_depth) == first_state[1].shape
    assert (batch_size, cell_depth) == step_state[0].shape
    assert (batch_size, cell_depth) == step_state[1].shape

    # One batch entry per element of sequence_length.
    assert batch_size == my_decoder.batch_size.numpy()

    # The expected flags follow sequence_length: the zero-length entry is
    # finished from the start, the length-1 entry after the first step.
    np.testing.assert_equal(
        np.asanyarray([False, False, False, False, True]),
        first_finished.numpy(),
    )
    np.testing.assert_equal(
        np.asanyarray([False, False, False, True, True]),
        step_finished.numpy(),
    )

    sample_ids = step_outputs.sample_id.numpy()
    assert output_dtype.sample_id == sample_ids.dtype
    # Boolean row masks: non-zero sample ids mark entries whose next input is
    # taken from the (transformed) cell output instead of ground truth.
    sampling = sample_ids != 0
    not_sampling = ~sampling

    # Auxiliary inputs are read at the next time step (index 1). Without
    # them, an empty (batch_size, 0) array keeps the concatenations uniform.
    if use_auxiliary_inputs:
        auxiliary_next = auxiliary_inputs[:, 1]
    else:
        auxiliary_next = np.zeros((batch_size, 0), dtype=np.float32)

    if use_next_inputs_fn:
        fed_back = next_inputs_fn(step_outputs.rnn_output).numpy()
    else:
        fed_back = step_outputs.rnn_output.numpy()

    next_inputs = step_next_inputs.numpy()

    np.testing.assert_equal(
        next_inputs[sampling],
        np.concatenate((fed_back[sampling], auxiliary_next[sampling]),
                       axis=-1),
    )

    np.testing.assert_equal(
        next_inputs[not_sampling],
        np.concatenate(
            (inputs[:, 1][not_sampling], auxiliary_next[not_sampling]),
            axis=-1,
        ),
    )
Ejemplo n.º 2
0
    def _testStepWithScheduledOutputTrainingHelper(
        self, sampling_probability, use_next_inputs_fn, use_auxiliary_inputs
    ):
        """Check a single decoding step with ``ScheduledOutputTrainingSampler``.

        Wraps an ``LSTMCell`` (depth equal to the input depth) in a
        ``BasicDecoder``, runs ``initialize`` followed by one ``step`` at
        time 0 inside a cached test session, and verifies output sizes and
        dtypes, state shapes, the ``finished`` flags and the sampler's next
        inputs.

        For each batch entry the step's ``sample_id`` selects the next input:

        * non-zero ``sample_id``: the cell output (passed through
          ``next_inputs_fn`` when one is given), concatenated with the
          auxiliary inputs at time 1;
        * zero ``sample_id``: the ground-truth ``inputs`` at time 1,
          concatenated with the same auxiliary inputs.

        Args:
            sampling_probability: Sampling probability handed to the sampler.
            use_next_inputs_fn: If true, feed back a deterministic
                argmax/one-hot transform of the cell output.
            use_auxiliary_inputs: If true, pass random auxiliary inputs that
                must be concatenated onto every next input.
        """
        sequence_length = [3, 4, 3, 1, 0]
        batch_size = 5
        max_time = 8
        input_depth = 7
        # Cell outputs and ground-truth inputs end up in the same next-input
        # tensor (checked below), so their depths must agree.
        cell_depth = input_depth
        if use_auxiliary_inputs:
            auxiliary_input_depth = 4
            auxiliary_inputs = np.random.randn(
                batch_size, max_time, auxiliary_input_depth
            ).astype(np.float32)
        else:
            auxiliary_inputs = None

        with self.cached_session(use_gpu=True):
            inputs = np.random.randn(batch_size, max_time, input_depth).astype(
                np.float32
            )
            input_t = tf.constant(inputs)
            cell = tf.keras.layers.LSTMCell(cell_depth)
            sampling_probability = tf.constant(sampling_probability)

            if use_next_inputs_fn:

                def next_inputs_fn(outputs):
                    # Use deterministic function for test.
                    samples = tf.argmax(outputs, axis=1)
                    return tf.one_hot(samples, cell_depth, dtype=tf.float32)

            else:
                next_inputs_fn = None

            sampler = sampler_py.ScheduledOutputTrainingSampler(
                sampling_probability=sampling_probability,
                time_major=False,
                next_inputs_fn=next_inputs_fn,
            )
            initial_state = cell.get_initial_state(
                batch_size=batch_size, dtype=tf.float32
            )

            my_decoder = basic_decoder.BasicDecoder(cell=cell, sampler=sampler)

            (first_finished, first_inputs, first_state) = my_decoder.initialize(
                input_t,
                sequence_length=sequence_length,
                initial_state=initial_state,
                auxiliary_inputs=auxiliary_inputs,
            )
            output_size = my_decoder.output_size
            output_dtype = my_decoder.output_dtype
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(cell_depth, tf.TensorShape([])),
                output_size,
            )
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(tf.float32, tf.int32), output_dtype
            )

            (
                step_outputs,
                step_state,
                step_next_inputs,
                step_finished,
            ) = my_decoder.step(tf.constant(0), first_inputs, first_state)

            if use_next_inputs_fn:
                output_after_next_inputs_fn = next_inputs_fn(step_outputs.rnn_output)

            batch_size_t = my_decoder.batch_size

            # LSTM state is the (h, c) pair, each (batch_size, cell_depth).
            self.assertLen(first_state, 2)
            self.assertLen(step_state, 2)
            self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
            self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
            self.assertEqual((batch_size,), step_outputs[1].get_shape())
            self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
            self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())

            self.evaluate(tf.compat.v1.global_variables_initializer())

            fetches = {
                "batch_size": batch_size_t,
                "first_finished": first_finished,
                "first_inputs": first_inputs,
                "first_state": first_state,
                "step_outputs": step_outputs,
                "step_state": step_state,
                "step_next_inputs": step_next_inputs,
                "step_finished": step_finished,
            }
            if use_next_inputs_fn:
                fetches["output_after_next_inputs_fn"] = output_after_next_inputs_fn

            eval_result = self.evaluate(fetches)

            # One batch entry per element of sequence_length.
            self.assertEqual(batch_size, eval_result["batch_size"])

            # The expected flags follow sequence_length: the zero-length entry
            # is finished from the start, the length-1 entry after one step.
            self.assertAllEqual(
                [False, False, False, False, True], eval_result["first_finished"]
            )
            self.assertAllEqual(
                [False, False, False, True, True], eval_result["step_finished"]
            )

            sample_ids = eval_result["step_outputs"].sample_id
            self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
            # Non-zero sample ids mark entries whose next input comes from the
            # (transformed) cell output rather than from ground truth.
            batch_where_not_sampling = np.where(np.logical_not(sample_ids))
            batch_where_sampling = np.where(sample_ids)

            # Auxiliary inputs are read at the next time step (index 1).
            # Without them, an empty (batch_size, 0) array keeps the
            # concatenations below uniform.
            auxiliary_inputs_to_concat = (
                auxiliary_inputs[:, 1]
                if use_auxiliary_inputs
                else np.array([]).reshape(batch_size, 0).astype(np.float32)
            )

            if use_next_inputs_fn:
                fed_back = eval_result["output_after_next_inputs_fn"]
            else:
                fed_back = eval_result["step_outputs"].rnn_output

            expected_next_sampling_inputs = np.concatenate(
                (
                    fed_back[batch_where_sampling],
                    auxiliary_inputs_to_concat[batch_where_sampling],
                ),
                axis=-1,
            )
            self.assertAllClose(
                eval_result["step_next_inputs"][batch_where_sampling],
                expected_next_sampling_inputs,
            )

            # Ground-truth rows at time 1 for the entries that did not sample.
            ground_truth_next = inputs[:, 1][batch_where_not_sampling]
            self.assertAllClose(
                eval_result["step_next_inputs"][batch_where_not_sampling],
                np.concatenate(
                    (
                        ground_truth_next,
                        auxiliary_inputs_to_concat[batch_where_not_sampling],
                    ),
                    axis=-1,
                ),
            )