Esempio n. 1
0
def test_dynamic_decode_rnn_with_scheduled_embedding_training_sampler():
    """Check that decoder outputs follow the global Keras dtype policy.

    Runs ``BasicDecoder`` (an ``LSTMCell`` driven by a
    ``ScheduledEmbeddingTrainingSampler`` with ``sampling_probability=1.0``)
    over a random batch and asserts that ``rnn_output`` has the policy's
    compute dtype.
    """
    # NOTE(review): the global policy is presumably configured by a
    # fixture/decorator outside this view (e.g. mixed precision) — confirm.
    mixed_precision = tf.keras.mixed_precision
    # ``mixed_precision.experimental`` is deprecated in favour of the stable
    # ``global_policy``; only fall back to it when the stable name is absent.
    get_global_policy = getattr(mixed_precision, "global_policy", None)
    if get_global_policy is None:
        get_global_policy = mixed_precision.experimental.global_policy
    policy = get_global_policy()

    sequence_length = [3, 4, 3, 1]
    batch_size = 4
    input_depth = 7
    cell_depth = 10
    vocab_size = 12
    max_time = max(sequence_length)

    embedding = tf.keras.layers.Embedding(vocab_size, input_depth)
    cell = tf.keras.layers.LSTMCell(cell_depth)
    sampler = sampler_py.ScheduledEmbeddingTrainingSampler(
        sampling_probability=tf.constant(1.0), embedding_fn=embedding
    )
    my_decoder = basic_decoder.BasicDecoder(cell=cell, sampler=sampler)

    inputs = tf.random.uniform([batch_size, max_time, input_depth])
    # Build the initial state in the policy's compute dtype so the decoder
    # runs entirely under the active policy.
    initial_state = cell.get_initial_state(
        batch_size=batch_size, dtype=policy.compute_dtype
    )
    final_outputs, _, _ = my_decoder(
        inputs, initial_state=initial_state, sequence_length=sequence_length
    )

    assert final_outputs.rnn_output.dtype == policy.compute_dtype
Esempio n. 2
0
def test_step_with_scheduled_embedding_training_helper():
    """Run one BasicDecoder step with a 50% ScheduledEmbeddingTrainingSampler.

    Checks the decoder's output size/dtype specs, its batch size, the shapes
    of the initial and first-step tensors, the ``finished`` flags expected
    for ``sequence_length``, and that every next input is either the
    embedding of the sampled id (``sample_id > -1``) or the ground-truth
    input at time step 1 (``sample_id == -1``).
    """
    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    vocabulary_size = 10

    inputs = np.random.randn(batch_size, max_time,
                             input_depth).astype(np.float32)
    input_t = tf.constant(inputs)
    embeddings = np.random.randn(vocabulary_size,
                                 input_depth).astype(np.float32)
    half = tf.constant(0.5)
    cell = tf.keras.layers.LSTMCell(vocabulary_size)
    sampler = sampler_py.ScheduledEmbeddingTrainingSampler(
        sampling_probability=half, time_major=False)
    initial_state = cell.get_initial_state(batch_size=batch_size,
                                           dtype=tf.float32)
    my_decoder = basic_decoder.BasicDecoder(cell=cell, sampler=sampler)
    first_finished, first_inputs, first_state = my_decoder.initialize(
        input_t,
        sequence_length=sequence_length,
        embedding=embeddings,
        initial_state=initial_state,
    )
    output_size = my_decoder.output_size
    output_dtype = my_decoder.output_dtype
    assert (basic_decoder.BasicDecoderOutput(vocabulary_size, tf.TensorShape(
        [])) == output_size)
    assert basic_decoder.BasicDecoderOutput(tf.float32,
                                            tf.int32) == output_dtype

    step_outputs, step_state, step_next_inputs, step_finished = (
        my_decoder.step(tf.constant(0), first_inputs, first_state))

    assert int(my_decoder.batch_size) == batch_size

    # Two state tensors, each (batch_size, vocabulary_size), before and
    # after the step.
    assert len(first_state) == 2
    assert len(step_state) == 2
    assert isinstance(step_outputs, basic_decoder.BasicDecoderOutput)
    assert (batch_size, vocabulary_size) == step_outputs[0].shape
    assert (batch_size, ) == step_outputs[1].shape
    assert (batch_size, vocabulary_size) == first_state[0].shape
    assert (batch_size, vocabulary_size) == first_state[1].shape
    assert (batch_size, vocabulary_size) == step_state[0].shape
    assert (batch_size, vocabulary_size) == step_state[1].shape
    assert (batch_size, input_depth) == step_next_inputs.shape

    # Expected: only the length-0 entry is finished at initialization; after
    # one step the length-1 entry is finished too.
    np.testing.assert_equal([False, False, False, False, True],
                            first_finished.numpy())
    np.testing.assert_equal([False, False, False, True, True],
                            step_finished.numpy())

    sample_ids = step_outputs.sample_id.numpy()
    assert output_dtype.sample_id == sample_ids.dtype
    # A sample id of -1 marks an entry that was not sampled this step.
    sampled = sample_ids > -1
    not_sampled = sample_ids == -1
    next_inputs = step_next_inputs.numpy()

    # Sampled entries are fed the embedding of their sampled id ...
    np.testing.assert_equal(next_inputs[sampled],
                            embeddings[sample_ids[sampled]])
    # ... the rest are fed the ground-truth input at time step 1.
    np.testing.assert_equal(next_inputs[not_sampled],
                            inputs[not_sampled, 1])
Esempio n. 3
0
    def testStepWithScheduledEmbeddingTrainingHelper(self):
        """Run one BasicDecoder step with a 50% ScheduledEmbeddingTrainingSampler.

        Session-based variant: builds the tensors, checks static specs and
        shapes, then evaluates everything in a single ``self.evaluate`` call
        so the sampled ids and next inputs come from the same run. Checks
        the batch size, the ``finished`` flags expected for
        ``sequence_length``, and that every next input is either the
        embedding of the sampled id (``sample_id > -1``) or the ground-truth
        input at time step 1 (``sample_id == -1``).
        """
        sequence_length = [3, 4, 3, 1, 0]
        batch_size = 5
        max_time = 8
        input_depth = 7
        vocabulary_size = 10

        with self.cached_session(use_gpu=True):
            inputs = np.random.randn(batch_size, max_time, input_depth).astype(
                np.float32
            )
            input_t = tf.constant(inputs)
            embeddings = np.random.randn(vocabulary_size, input_depth).astype(
                np.float32
            )
            half = tf.constant(0.5)
            cell = tf.keras.layers.LSTMCell(vocabulary_size)
            sampler = sampler_py.ScheduledEmbeddingTrainingSampler(
                sampling_probability=half, time_major=False
            )
            initial_state = cell.get_initial_state(
                batch_size=batch_size, dtype=tf.float32
            )
            my_decoder = basic_decoder.BasicDecoder(cell=cell, sampler=sampler)
            first_finished, first_inputs, first_state = my_decoder.initialize(
                input_t,
                sequence_length=sequence_length,
                embedding=embeddings,
                initial_state=initial_state,
            )
            output_size = my_decoder.output_size
            output_dtype = my_decoder.output_dtype
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(vocabulary_size, tf.TensorShape([])),
                output_size,
            )
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(tf.float32, tf.int32), output_dtype
            )

            (
                step_outputs,
                step_state,
                step_next_inputs,
                step_finished,
            ) = my_decoder.step(tf.constant(0), first_inputs, first_state)
            batch_size_t = my_decoder.batch_size

            # Two state tensors, each (batch_size, vocabulary_size), before
            # and after the step.
            self.assertLen(first_state, 2)
            self.assertLen(step_state, 2)
            self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
            self.assertEqual((batch_size, vocabulary_size), step_outputs[0].get_shape())
            self.assertEqual((batch_size,), step_outputs[1].get_shape())
            self.assertEqual((batch_size, vocabulary_size), first_state[0].get_shape())
            self.assertEqual((batch_size, vocabulary_size), first_state[1].get_shape())
            self.assertEqual((batch_size, vocabulary_size), step_state[0].get_shape())
            self.assertEqual((batch_size, vocabulary_size), step_state[1].get_shape())
            self.assertEqual((batch_size, input_depth), step_next_inputs.get_shape())

            self.evaluate(tf.compat.v1.global_variables_initializer())
            # Evaluate together so sample ids and next inputs are consistent.
            eval_result = self.evaluate(
                {
                    "batch_size": batch_size_t,
                    "first_finished": first_finished,
                    "step_outputs": step_outputs,
                    "step_next_inputs": step_next_inputs,
                    "step_finished": step_finished,
                }
            )

            self.assertEqual(batch_size, eval_result["batch_size"])
            # Expected: only the length-0 entry is finished at initialization;
            # after one step the length-1 entry is finished too.
            self.assertAllEqual(
                [False, False, False, False, True], eval_result["first_finished"]
            )
            self.assertAllEqual(
                [False, False, False, True, True], eval_result["step_finished"]
            )
            sample_ids = eval_result["step_outputs"].sample_id
            self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
            # A sample id of -1 marks an entry that was not sampled this step.
            sampled = sample_ids > -1
            not_sampled = sample_ids == -1
            next_inputs = eval_result["step_next_inputs"]
            # Sampled entries are fed the embedding of their sampled id ...
            self.assertAllClose(
                next_inputs[sampled], embeddings[sample_ids[sampled]]
            )
            # ... the rest are fed the ground-truth input at time step 1.
            self.assertAllClose(next_inputs[not_sampled], inputs[not_sampled, 1])