示例#1
0
def test_step_with_inference_helper_multilabel():
    """Run a single BasicDecoder step with a multilabel InferenceSampler.

    The sampler draws boolean samples with ``bernoulli_sample``. The test
    checks the decoder's declared output size and dtype, the shapes of the
    outputs and states after one step, and that the finished flags and next
    inputs agree with what ``end_fn`` and ``next_inputs_fn`` compute from the
    sampled ids.
    """
    batch_size = 5
    vocabulary_size = 7
    cell_depth = vocabulary_size
    start_token = 0
    end_token = 6
    expected_shape = (batch_size, cell_depth)

    start_ids = np.ones(batch_size, dtype=np.int32) * start_token
    start_inputs = tf.one_hot(start_ids, vocabulary_size)

    def sample_fn(x):
        # Draw boolean samples from the logits.
        return sampler_py.bernoulli_sample(logits=x, dtype=tf.bool)

    def next_inputs_fn(x):
        # The sampled labels are fed back as float32.
        return tf.cast(x, tf.float32)

    def end_fn(sample_ids):
        # The end_token column of the sample is used as the finished flag.
        return sample_ids[:, end_token]

    cell = tf.keras.layers.LSTMCell(vocabulary_size)
    sampler = sampler_py.InferenceSampler(
        sample_fn,
        sample_shape=[cell_depth],
        sample_dtype=tf.bool,
        end_fn=end_fn,
        next_inputs_fn=next_inputs_fn,
    )
    initial_state = cell.get_initial_state(
        batch_size=batch_size, dtype=tf.float32)
    my_decoder = basic_decoder.BasicDecoder(cell=cell, sampler=sampler)
    first_finished, first_inputs, first_state = my_decoder.initialize(
        start_inputs, initial_state=initial_state)

    output_size = my_decoder.output_size
    output_dtype = my_decoder.output_dtype
    expected_size = basic_decoder.BasicDecoderOutput(cell_depth, cell_depth)
    assert expected_size == output_size
    expected_dtype = basic_decoder.BasicDecoderOutput(tf.float32, tf.bool)
    assert expected_dtype == output_dtype

    step_result = my_decoder.step(tf.constant(0), first_inputs, first_state)
    step_outputs, step_state, step_next_inputs, step_finished = step_result
    batch_size_t = my_decoder.batch_size

    assert len(first_state) == 2
    assert len(step_state) == 2
    assert isinstance(step_outputs, basic_decoder.BasicDecoderOutput)
    # Both decoder outputs and both halves of each state share one shape.
    shaped_tensors = (
        step_outputs[0],
        step_outputs[1],
        first_state[0],
        first_state[1],
        step_state[0],
        step_state[1],
    )
    for tensor in shaped_tensors:
        assert expected_shape == tensor.shape

    # Materialize everything as numpy values, mirroring a session-style fetch.
    eval_result = dict(
        batch_size=batch_size_t.numpy(),
        first_finished=first_finished.numpy(),
        first_inputs=first_inputs.numpy(),
        first_state=np.asanyarray(first_state),
        step_outputs=step_outputs,
        step_state=np.asanyarray(step_state),
        step_next_inputs=step_next_inputs.numpy(),
        step_finished=step_finished.numpy(),
    )

    sampled = eval_result["step_outputs"].sample_id.numpy()
    assert output_dtype.sample_id == sampled.dtype
    np.testing.assert_equal(sampled[:, end_token],
                            eval_result["step_finished"])
    np.testing.assert_equal(sampled.astype(np.float32),
                            eval_result["step_next_inputs"])
示例#2
0
    def testStepWithInferenceHelperMultilabel(self):
        """Run one `BasicDecoder` step driven by a multilabel `InferenceSampler`.

        The sampler draws boolean samples with `bernoulli_sample`. The test
        checks the decoder's declared output size and dtype, the static shapes
        of the outputs and states after one step, and that the evaluated
        `step_finished` / `step_next_inputs` agree with what `end_fn` and
        `next_inputs_fn` compute from the sampled ids.
        """
        batch_size = 5
        vocabulary_size = 7
        cell_depth = vocabulary_size
        start_token = 0
        end_token = 6

        start_inputs = tf.one_hot(
            np.ones(batch_size, dtype=np.int32) * start_token, vocabulary_size
        )

        # The sample function samples independent bernoullis from the logits.
        def sample_fn(x):
            return sampler_py.bernoulli_sample(logits=x, dtype=tf.bool)

        # The next inputs are the sampled boolean labels cast to float32.
        def next_inputs_fn(x):
            return tf.cast(x, tf.float32)

        # The end_token column of the sample is used as the finished flag.
        def end_fn(sample_ids):
            return sample_ids[:, end_token]

        with self.cached_session(use_gpu=True):
            cell = tf.keras.layers.LSTMCell(vocabulary_size)
            sampler = sampler_py.InferenceSampler(
                sample_fn,
                sample_shape=[cell_depth],
                sample_dtype=tf.bool,
                end_fn=end_fn,
                next_inputs_fn=next_inputs_fn,
            )
            initial_state = cell.get_initial_state(
                batch_size=batch_size, dtype=tf.float32
            )
            my_decoder = basic_decoder.BasicDecoder(cell=cell, sampler=sampler)
            (first_finished, first_inputs, first_state) = my_decoder.initialize(
                start_inputs, initial_state=initial_state
            )
            output_size = my_decoder.output_size
            output_dtype = my_decoder.output_dtype
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(cell_depth, cell_depth), output_size
            )
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(tf.float32, tf.bool), output_dtype
            )

            (
                step_outputs,
                step_state,
                step_next_inputs,
                step_finished,
            ) = my_decoder.step(tf.constant(0), first_inputs, first_state)
            batch_size_t = my_decoder.batch_size

            # Use the TestCase assertion helpers throughout: unlike bare
            # `assert`, they are not stripped under `python -O` and they
            # report both operands on failure.
            self.assertLen(first_state, 2)
            self.assertLen(step_state, 2)
            self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
            self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
            self.assertEqual((batch_size, cell_depth), step_outputs[1].get_shape())
            self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
            self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())

            self.evaluate(tf.compat.v1.global_variables_initializer())
            eval_result = self.evaluate(
                {
                    "batch_size": batch_size_t,
                    "first_finished": first_finished,
                    "first_inputs": first_inputs,
                    "first_state": first_state,
                    "step_outputs": step_outputs,
                    "step_state": step_state,
                    "step_next_inputs": step_next_inputs,
                    "step_finished": step_finished,
                }
            )

            sample_ids = eval_result["step_outputs"].sample_id
            self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
            # Recompute end_fn / next_inputs_fn in numpy from the sampled ids.
            expected_step_finished = sample_ids[:, end_token]
            expected_step_next_inputs = sample_ids.astype(np.float32)
            self.assertAllEqual(expected_step_finished, eval_result["step_finished"])
            self.assertAllEqual(
                expected_step_next_inputs, eval_result["step_next_inputs"]
            )
示例#3
0
def test_step_with_inference_helper_categorical():
    """Run a single BasicDecoder step with a categorical InferenceSampler.

    The sampler draws int32 ids with ``categorical_sample``. The test checks
    the decoder's declared output size and dtype, the shapes of the outputs
    and states after one step, and that the finished flags and next inputs
    agree with what ``end_fn`` and ``next_inputs_fn`` compute from the
    sampled ids.
    """
    batch_size = 5
    vocabulary_size = 7
    cell_depth = vocabulary_size
    start_token = 0
    end_token = 6
    state_shape = (batch_size, cell_depth)

    start_ids = np.ones(batch_size, dtype=np.int32) * start_token
    start_inputs = tf.one_hot(start_ids, vocabulary_size)

    def sample_fn(x):
        # Draw one categorical sample per row of logits.
        return sampler_py.categorical_sample(logits=x)

    def next_inputs_fn(x):
        # The next inputs are a one-hot encoding of the sampled ids.
        return tf.one_hot(x, vocabulary_size, dtype=tf.float32)

    def end_fn(sample_ids):
        # A row is finished when its sampled id equals end_token.
        return tf.equal(sample_ids, end_token)

    cell = tf.keras.layers.LSTMCell(vocabulary_size)
    sampler = sampler_py.InferenceSampler(
        sample_fn,
        sample_shape=(),
        sample_dtype=tf.int32,
        end_fn=end_fn,
        next_inputs_fn=next_inputs_fn,
    )
    initial_state = cell.get_initial_state(
        batch_size=batch_size, dtype=tf.float32)
    my_decoder = basic_decoder.BasicDecoder(cell=cell, sampler=sampler)
    first_finished, first_inputs, first_state = my_decoder.initialize(
        start_inputs, initial_state=initial_state)

    output_size = my_decoder.output_size
    output_dtype = my_decoder.output_dtype
    expected_size = basic_decoder.BasicDecoderOutput(
        cell_depth, tf.TensorShape([]))
    assert expected_size == output_size
    expected_dtype = basic_decoder.BasicDecoderOutput(tf.float32, tf.int32)
    assert expected_dtype == output_dtype

    step_result = my_decoder.step(tf.constant(0), first_inputs, first_state)
    step_outputs, step_state, step_next_inputs, step_finished = step_result

    assert len(first_state) == 2
    assert len(step_state) == 2
    assert isinstance(step_outputs, basic_decoder.BasicDecoderOutput)
    assert state_shape == step_outputs[0].shape
    assert (batch_size, ) == step_outputs[1].shape
    # Both halves of the initial and the stepped state share one shape.
    state_tensors = (
        first_state[0],
        first_state[1],
        step_state[0],
        step_state[1],
    )
    for tensor in state_tensors:
        assert state_shape == tensor.shape

    sampled = step_outputs.sample_id.numpy()
    assert output_dtype.sample_id == sampled.dtype
    # Row i of the identity matrix is the one-hot vector for id i.
    expected_next_inputs = np.eye(vocabulary_size)[sampled]
    np.testing.assert_equal(sampled == end_token, step_finished)
    np.testing.assert_equal(expected_next_inputs, step_next_inputs)
示例#4
0
    def testStepWithInferenceHelperCategorical(self):
        """Run one `BasicDecoder` step driven by a categorical `InferenceSampler`.

        The sampler draws int32 ids with `categorical_sample`. The test checks
        the decoder's declared output size and dtype, the static shapes of the
        outputs and states after one step, and that the evaluated
        `step_finished` / `step_next_inputs` agree with what `end_fn` and
        `next_inputs_fn` compute from the sampled ids.
        """
        batch_size = 5
        vocabulary_size = 7
        cell_depth = vocabulary_size
        start_token = 0
        end_token = 6

        start_inputs = tf.one_hot(
            np.ones(batch_size, dtype=np.int32) * start_token, vocabulary_size
        )

        # The sample function samples categorically from the logits.
        def sample_fn(x):
            return sampler_py.categorical_sample(logits=x)

        # The next inputs are a one-hot encoding of the sampled labels.
        def next_inputs_fn(x):
            return tf.one_hot(x, vocabulary_size, dtype=tf.float32)

        # A row is finished when its sampled id equals end_token.
        def end_fn(sample_ids):
            return tf.equal(sample_ids, end_token)

        with self.cached_session(use_gpu=True):
            cell = tf.keras.layers.LSTMCell(vocabulary_size)
            sampler = sampler_py.InferenceSampler(
                sample_fn,
                sample_shape=(),
                sample_dtype=tf.int32,
                end_fn=end_fn,
                next_inputs_fn=next_inputs_fn,
            )
            initial_state = cell.get_initial_state(
                batch_size=batch_size, dtype=tf.float32
            )
            my_decoder = basic_decoder.BasicDecoder(cell=cell, sampler=sampler)
            (first_finished, first_inputs, first_state) = my_decoder.initialize(
                start_inputs, initial_state=initial_state
            )

            output_size = my_decoder.output_size
            output_dtype = my_decoder.output_dtype
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(cell_depth, tf.TensorShape([])),
                output_size,
            )
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(tf.float32, tf.int32), output_dtype
            )

            (
                step_outputs,
                step_state,
                step_next_inputs,
                step_finished,
            ) = my_decoder.step(tf.constant(0), first_inputs, first_state)
            batch_size_t = my_decoder.batch_size

            self.assertLen(first_state, 2)
            self.assertLen(step_state, 2)
            # assertIsInstance reports the actual type on failure, which
            # assertTrue(isinstance(...)) does not.
            self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
            self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
            self.assertEqual((batch_size,), step_outputs[1].get_shape())
            self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
            self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())

            self.evaluate(tf.compat.v1.global_variables_initializer())
            eval_result = self.evaluate(
                {
                    "batch_size": batch_size_t,
                    "first_finished": first_finished,
                    "first_inputs": first_inputs,
                    "first_state": first_state,
                    "step_outputs": step_outputs,
                    "step_state": step_state,
                    "step_next_inputs": step_next_inputs,
                    "step_finished": step_finished,
                }
            )

            sample_ids = eval_result["step_outputs"].sample_id
            self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
            # Recompute end_fn / next_inputs_fn in numpy from the sampled ids.
            expected_step_finished = sample_ids == end_token
            expected_step_next_inputs = np.zeros((batch_size, vocabulary_size))
            expected_step_next_inputs[np.arange(batch_size), sample_ids] = 1.0
            self.assertAllEqual(expected_step_finished, eval_result["step_finished"])
            self.assertAllEqual(
                expected_step_next_inputs, eval_result["step_next_inputs"]
            )
示例#5
0
    def testStepWithInferenceHelperMultilabel(self):
        """Run one `BasicDecoder` step driven by a multilabel `InferenceSampler`.

        The sampler draws boolean samples with `bernoulli_sample`. The test
        checks the decoder's declared output size and dtype, the state types
        and static shapes after one step, and that the evaluated
        `step_finished` / `step_next_inputs` agree with what `end_fn` and
        `next_inputs_fn` compute from the sampled ids.
        """
        batch_size = 5
        vocabulary_size = 7
        cell_depth = vocabulary_size
        start_token = 0
        end_token = 6

        start_inputs = array_ops.one_hot(
            np.ones(batch_size, dtype=np.int32) * start_token, vocabulary_size)

        # The sample function samples independent bernoullis from the logits.
        def sample_fn(x):
            return sampler_py.bernoulli_sample(logits=x, dtype=dtypes.bool)

        # The next inputs are the sampled boolean labels cast to float32.
        # `math_ops.cast` is used instead of the deprecated
        # `math_ops.to_float`.
        def next_inputs_fn(x):
            return math_ops.cast(x, dtypes.float32)

        # The end_token column of the sample is used as the finished flag.
        def end_fn(sample_ids):
            return sample_ids[:, end_token]

        with self.cached_session(use_gpu=True):
            cell = rnn_cell.LSTMCell(vocabulary_size)
            sampler = sampler_py.InferenceSampler(
                sample_fn,
                sample_shape=[cell_depth],
                sample_dtype=dtypes.bool,
                end_fn=end_fn,
                next_inputs_fn=next_inputs_fn)
            initial_state = cell.zero_state(dtype=dtypes.float32,
                                            batch_size=batch_size)
            my_decoder = basic_decoder.BasicDecoder(cell=cell, sampler=sampler)
            (first_finished, first_inputs,
             first_state) = my_decoder.initialize(start_inputs,
                                                  initial_state=initial_state)
            output_size = my_decoder.output_size
            output_dtype = my_decoder.output_dtype
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(cell_depth, cell_depth),
                output_size)
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.bool),
                output_dtype)

            (step_outputs, step_state, step_next_inputs,
             step_finished) = my_decoder.step(constant_op.constant(0),
                                              first_inputs, first_state)
            batch_size_t = my_decoder.batch_size

            # assertIsInstance reports the actual type on failure, which
            # assertTrue(isinstance(...)) does not.
            self.assertIsInstance(first_state, rnn_cell.LSTMStateTuple)
            self.assertIsInstance(step_state, rnn_cell.LSTMStateTuple)
            self.assertIsInstance(step_outputs,
                                  basic_decoder.BasicDecoderOutput)
            self.assertEqual((batch_size, cell_depth),
                             step_outputs[0].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             step_outputs[1].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             first_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             first_state[1].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             step_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             step_state[1].get_shape())

            self.evaluate(variables.global_variables_initializer())
            eval_result = self.evaluate({
                "batch_size": batch_size_t,
                "first_finished": first_finished,
                "first_inputs": first_inputs,
                "first_state": first_state,
                "step_outputs": step_outputs,
                "step_state": step_state,
                "step_next_inputs": step_next_inputs,
                "step_finished": step_finished
            })

            sample_ids = eval_result["step_outputs"].sample_id
            self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
            # Recompute end_fn / next_inputs_fn in numpy from the sampled ids.
            expected_step_finished = sample_ids[:, end_token]
            expected_step_next_inputs = sample_ids.astype(np.float32)
            self.assertAllEqual(expected_step_finished,
                                eval_result["step_finished"])
            self.assertAllEqual(expected_step_next_inputs,
                                eval_result["step_next_inputs"])