Beispiel #1
0
  def testBahdanauMonotonicNormalized(self):
    """Golden-summary test for BahdanauMonotonicAttentionV2, normalize=True.

    The attention mechanism gets an all-ones kernel initializer and
    `normalize=True`; alignment history is recorded and a query layer is
    created. The expected cell state is an `rnn_cell.LSTMStateTuple`, and all
    expectations are (shape, dtype, mean) summaries handed to
    `self._testWithAttention`.
    """

    def summary(shape, mean, dtype="float32"):
      """Builds one expected ResultSummary record."""
      return ResultSummary(shape=shape, dtype=np.dtype(dtype), mean=mean)

    attention_kwargs = dict(kernel_initializer="ones", normalize=True)

    final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=summary((5, 3, 6), 4.5706983),
        sample_id=summary((5, 3), 0.0, dtype="int32"))
    final_state = wrapper.AttentionWrapperState(
        cell_state=rnn_cell.LSTMStateTuple(
            c=summary((5, 9), 1.6005473),
            h=summary((5, 9), 0.77863038)),
        attention=summary((5, 6), 7.3326721),
        time=3,
        alignments=summary((5, 8), 0.12258384),
        attention_state=summary((5, 8), 0.12258384),
        alignment_history=())
    history = summary((3, 5, 8), 0.12258384)

    self._testWithAttention(
        wrapper.BahdanauMonotonicAttentionV2,
        final_output,
        final_state,
        alignment_history=True,
        expected_final_alignment_history=history,
        create_query_layer=True,
        create_attention_kwargs=attention_kwargs)
Beispiel #2
0
    def testLuongNotNormalized(self):
        """Golden-summary test for LuongAttentionV2 with no extra kwargs.

        Uses an attention mechanism depth of 9 and expects a two-element list
        cell state; the expected (shape, dtype, mean) summaries are handed to
        `self._testWithAttention`.
        """

        def summary(shape, mean, dtype="float32"):
            """Builds one expected ResultSummary record."""
            return ResultSummary(shape=shape, dtype=np.dtype(dtype), mean=mean)

        final_output = basic_decoder.BasicDecoderOutput(
            rnn_output=summary((5, 3, 6), 0.05481226),
            sample_id=summary((5, 3), 3.13333333, dtype="int32"))
        final_state = wrapper.AttentionWrapperState(
            cell_state=[
                summary((5, 9), 0.38453412),
                summary((5, 9), 0.5785929),
            ],
            attention=summary((5, 6), 0.16311775),
            time=3,
            alignments=summary((5, 8), 0.125),
            attention_state=summary((5, 8), 0.125),
            alignment_history=())

        self._testWithAttention(wrapper.LuongAttentionV2,
                                final_output,
                                final_state,
                                attention_mechanism_depth=9)
  def testBahdanauMonotonicNormalized(self):
    """Golden-summary test for BahdanauMonotonicAttentionV2, normalize=True.

    The attention mechanism gets an all-ones kernel initializer and
    `normalize=True`; alignment history is recorded and a query layer is
    created. The expected cell state is a two-element list, and all
    expectations are (shape, dtype, mean) summaries handed to
    `self._testWithAttention`.
    """

    def summary(shape, mean, dtype="float32"):
      """Builds one expected ResultSummary record."""
      return ResultSummary(shape=shape, dtype=np.dtype(dtype), mean=mean)

    attention_kwargs = dict(kernel_initializer="ones", normalize=True)

    final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=summary((5, 3, 6), 0.043294173),
        sample_id=summary((5, 3), 3.53333333, dtype="int32"))
    final_state = wrapper.AttentionWrapperState(
        cell_state=[
            summary((5, 9), 0.40034312),
            summary((5, 9), 0.5925445),
        ],
        attention=summary((5, 6), 0.096119694),
        time=3,
        alignments=summary((5, 8), 0.1211452),
        attention_state=summary((5, 8), 0.1211452),
        alignment_history=())
    history = summary((3, 5, 8), 0.12258384)

    self._testWithAttention(
        wrapper.BahdanauMonotonicAttentionV2,
        final_output,
        final_state,
        alignment_history=True,
        expected_final_alignment_history=history,
        create_query_layer=True,
        create_attention_kwargs=attention_kwargs)
Beispiel #4
0
  def testLuongMonotonicScaled(self):
    """Golden-summary test for LuongMonotonicAttentionV2 with scale=True.

    Uses an attention mechanism depth of 9 and records alignment history.
    The expected cell state is an `rnn_cell.LSTMStateTuple`; all expectations
    are (shape, dtype, mean) summaries handed to `self._testWithAttention`.
    """

    def summary(shape, mean, dtype="float32"):
      """Builds one expected ResultSummary record."""
      return ResultSummary(shape=shape, dtype=np.dtype(dtype), mean=mean)

    attention_kwargs = dict(scale=True)

    final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=summary((5, 3, 6), 3.159497),
        sample_id=summary((5, 3), 0.0, dtype="int32"))
    final_state = wrapper.AttentionWrapperState(
        cell_state=rnn_cell.LSTMStateTuple(
            c=summary((5, 9), 1.072384),
            h=summary((5, 9), 0.50331038)),
        attention=summary((5, 6), 5.3079605),
        time=3,
        alignments=summary((5, 8), 0.11467695),
        attention_state=summary((5, 8), 0.11467695),
        alignment_history=())
    history = summary((3, 5, 8), 0.11899644)

    self._testWithAttention(
        wrapper.LuongMonotonicAttentionV2,
        final_output,
        final_state,
        attention_mechanism_depth=9,
        alignment_history=True,
        expected_final_alignment_history=history,
        create_attention_kwargs=attention_kwargs)
  def testNotUseAttentionLayer(self):
    """Golden-summary test for BahdanauAttentionV2 with no attention layer.

    Passes `attention_layer_size=None`, an all-ones kernel initializer and a
    query layer. Note that the expected rnn_output and attention summaries
    here are 10 wide. All expectations are (shape, dtype, mean) summaries
    handed to `self._testWithAttention`.
    """

    def summary(shape, mean, dtype="float32"):
      """Builds one expected ResultSummary record."""
      return ResultSummary(shape=shape, dtype=np.dtype(dtype), mean=mean)

    attention_kwargs = dict(kernel_initializer="ones")

    final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=summary((5, 3, 10), 0.072406612),
        sample_id=summary((5, 3), 3.86666666, dtype="int32"))
    final_state = wrapper.AttentionWrapperState(
        cell_state=[
            summary((5, 9), 0.61177742),
            summary((5, 9), 1.032002),
        ],
        attention=summary((5, 10), 0.011346335),
        time=3,
        alignments=summary((5, 8), 0.125),
        attention_state=summary((5, 8), 0.125),
        alignment_history=())

    self._testWithAttention(
        wrapper.BahdanauAttentionV2,
        final_output,
        final_state,
        attention_layer_size=None,
        create_query_layer=True,
        create_attention_kwargs=attention_kwargs)
  def testBahdanauMonotonicNotNormalized(self):
    """Golden-summary test for BahdanauMonotonicAttentionV2, no normalize.

    The attention mechanism gets only an all-ones kernel initializer;
    alignment history is recorded and a query layer is created. The expected
    cell state is a two-element list, and all expectations are
    (shape, dtype, mean) summaries handed to `self._testWithAttention`.
    """

    def summary(shape, mean, dtype="float32"):
      """Builds one expected ResultSummary record."""
      return ResultSummary(shape=shape, dtype=np.dtype(dtype), mean=mean)

    attention_kwargs = dict(kernel_initializer="ones")

    final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=summary((5, 3, 6), 0.041342419),
        sample_id=summary((5, 3), 3.53333333, dtype="int32"))
    final_state = wrapper.AttentionWrapperState(
        cell_state=[
            summary((5, 9), 0.33866978),
            summary((5, 9), 0.46913195),
        ],
        attention=summary((5, 6), 0.092498459),
        time=3,
        alignments=summary((5, 8), 0.12079944),
        attention_state=summary((5, 8), 0.12079944),
        alignment_history=())
    history = summary((3, 5, 8), 0.121448785067)

    self._testWithAttention(
        wrapper.BahdanauMonotonicAttentionV2,
        final_output,
        final_state,
        alignment_history=True,
        expected_final_alignment_history=history,
        create_query_layer=True,
        create_attention_kwargs=attention_kwargs)
  def testBahdanauNotNormalized(self):
    """Golden-summary test for BahdanauAttentionV2 without normalization.

    The attention mechanism gets only an all-ones kernel initializer;
    alignment history is recorded and a query layer is created. The expected
    cell state is a two-element list, and all expectations are
    (shape, dtype, mean) summaries handed to `self._testWithAttention`.
    """

    def summary(shape, mean, dtype=np.float32):
      """Builds one expected ResultSummary record."""
      return ResultSummary(shape=shape, dtype=np.dtype(dtype), mean=mean)

    attention_kwargs = dict(kernel_initializer="ones")

    final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=summary((5, 3, 6), 0.051747426),
        sample_id=summary((5, 3), 3.33333333, dtype=np.int32))
    final_state = wrapper.AttentionWrapperState(
        cell_state=[
            summary((5, 9), 0.44189346),
            summary((5, 9), 0.65429491),
        ],
        attention=summary((5, 6), 0.073610783),
        time=3,
        alignments=summary((5, 8), 0.125),
        attention_state=summary((5, 8), 0.125),
        alignment_history=())
    history = summary((3, 5, 8), 0.125)

    self._testWithAttention(
        wrapper.BahdanauAttentionV2,
        final_output,
        final_state,
        alignment_history=True,
        create_query_layer=True,
        expected_final_alignment_history=history,
        create_attention_kwargs=attention_kwargs)
  def testBahdanauNormalized(self):
    """Golden-summary test for BahdanauAttentionV2 with normalize=True.

    The attention mechanism gets an all-ones kernel initializer and
    `normalize=True`, and a query layer is created. The expected cell state
    is a two-element list; all expectations are (shape, dtype, mean)
    summaries handed to `self._testWithAttention`.
    """

    def summary(shape, mean, dtype="float32"):
      """Builds one expected ResultSummary record."""
      return ResultSummary(shape=shape, dtype=np.dtype(dtype), mean=mean)

    attention_kwargs = dict(kernel_initializer="ones", normalize=True)

    final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=summary((5, 3, 6), 0.047594748),
        sample_id=summary((5, 3), 3.6, dtype="int32"))
    final_state = wrapper.AttentionWrapperState(
        cell_state=[
            summary((5, 9), 0.41311637),
            summary((5, 9), 0.61683208),
        ],
        attention=summary((5, 6), 0.090581432),
        time=3,
        alignments=summary((5, 8), 0.125),
        attention_state=summary((5, 8), 0.125),
        alignment_history=())

    self._testWithAttention(
        wrapper.BahdanauAttentionV2,
        final_output,
        final_state,
        create_query_layer=True,
        create_attention_kwargs=attention_kwargs)
  def testLuongMonotonicScaled(self):
    """Golden-summary test for LuongMonotonicAttentionV2 with scale=True.

    Uses an attention mechanism depth of 9 and records alignment history.
    The expected cell state is a two-element list; all expectations are
    (shape, dtype, mean) summaries handed to `self._testWithAttention`.
    """

    def summary(shape, mean, dtype="float32"):
      """Builds one expected ResultSummary record."""
      return ResultSummary(shape=shape, dtype=np.dtype(dtype), mean=mean)

    attention_kwargs = dict(scale=True)

    final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=summary((5, 3, 6), 0.027387079),
        sample_id=summary((5, 3), 3.13333333, dtype="int32"))
    final_state = wrapper.AttentionWrapperState(
        cell_state=[
            summary((5, 9), 0.32660431),
            summary((5, 9), 0.52464348),
        ],
        attention=summary((5, 6), 0.089345723),
        time=3,
        alignments=summary((5, 8), 0.11831035),
        attention_state=summary((5, 8), 0.11831035),
        alignment_history=())
    history = summary((3, 5, 8), 0.12194442004)

    self._testWithAttention(
        wrapper.LuongMonotonicAttentionV2,
        final_output,
        final_state,
        attention_mechanism_depth=9,
        alignment_history=True,
        expected_final_alignment_history=history,
        create_attention_kwargs=attention_kwargs)
Beispiel #10
0
  def testLuongScaled(self):
    """Golden-summary test for LuongAttentionV2 with scale=True.

    Uses an attention mechanism depth of 9. The expected cell state is an
    `rnn_cell.LSTMStateTuple`; all expectations are (shape, dtype, mean)
    summaries handed to `self._testWithAttention`.
    """

    def summary(shape, mean, dtype="float32"):
      """Builds one expected ResultSummary record."""
      return ResultSummary(shape=shape, dtype=np.dtype(dtype), mean=mean)

    attention_kwargs = dict(scale=True)

    final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=summary((5, 3, 6), 2.6605489),
        sample_id=summary((5, 3), 0.0, dtype="int32"))
    final_state = wrapper.AttentionWrapperState(
        cell_state=rnn_cell.LSTMStateTuple(
            c=summary((5, 9), 0.88403547),
            h=summary((5, 9), 0.37819088)),
        attention=summary((5, 6), 4.0846314),
        time=3,
        alignments=summary((5, 8), 0.125),
        attention_state=summary((5, 8), 0.125),
        alignment_history=())

    self._testWithAttention(
        wrapper.LuongAttentionV2,
        final_output,
        final_state,
        attention_mechanism_depth=9,
        create_attention_kwargs=attention_kwargs)
Beispiel #11
0
  def testBahdanauMonotonicNotNormalized(self):
    """Golden-summary test for BahdanauMonotonicAttentionV2, no normalize.

    The attention mechanism gets only an all-ones kernel initializer;
    alignment history is recorded and a query layer is created. The expected
    cell state is an `rnn_cell.LSTMStateTuple`, and all expectations are
    (shape, dtype, mean) summaries handed to `self._testWithAttention`.
    """

    def summary(shape, mean, dtype="float32"):
      """Builds one expected ResultSummary record."""
      return ResultSummary(shape=shape, dtype=np.dtype(dtype), mean=mean)

    attention_kwargs = dict(kernel_initializer="ones")

    final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=summary((5, 3, 6), 5.9850435),
        sample_id=summary((5, 3), 0.0, dtype="int32"))
    final_state = wrapper.AttentionWrapperState(
        cell_state=rnn_cell.LSTMStateTuple(
            c=summary((5, 9), 1.6752492),
            h=summary((5, 9), 0.76052248)),
        attention=summary((5, 6), 8.361186),
        time=3,
        alignments=summary((5, 8), 0.10989678),
        attention_state=summary((5, 8), 0.10989678),
        alignment_history=())
    history = summary((3, 5, 8), 0.117412611)

    self._testWithAttention(
        wrapper.BahdanauMonotonicAttentionV2,
        final_output,
        final_state,
        alignment_history=True,
        expected_final_alignment_history=history,
        create_query_layer=True,
        create_attention_kwargs=attention_kwargs)
Beispiel #12
0
  def testBahdanauNormalized(self):
    """Golden-summary test for BahdanauAttentionV2 with normalize=True.

    The attention mechanism gets an all-ones kernel initializer and
    `normalize=True`, and a query layer is created. The expected cell state
    is an `rnn_cell.LSTMStateTuple`; all expectations are (shape, dtype, mean)
    summaries handed to `self._testWithAttention`.
    """

    def summary(shape, mean, dtype="float32"):
      """Builds one expected ResultSummary record."""
      return ResultSummary(shape=shape, dtype=np.dtype(dtype), mean=mean)

    attention_kwargs = dict(kernel_initializer="ones", normalize=True)

    final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=summary((5, 3, 6), 3.9548259),
        sample_id=summary((5, 3), 0.0, dtype="int32"))
    final_state = wrapper.AttentionWrapperState(
        cell_state=rnn_cell.LSTMStateTuple(
            c=summary((5, 9), 1.4652209),
            h=summary((5, 9), 0.70997983)),
        attention=summary((5, 6), 6.3075728),
        time=3,
        alignments=summary((5, 8), 0.125),
        attention_state=summary((5, 8), 0.125),
        alignment_history=())

    self._testWithAttention(
        wrapper.BahdanauAttentionV2,
        final_output,
        final_state,
        create_query_layer=True,
        create_attention_kwargs=attention_kwargs)
Beispiel #13
0
  def testBahdanauNotNormalized(self):
    """Golden-summary test for BahdanauAttentionV2 without normalization.

    The attention mechanism gets only an all-ones kernel initializer;
    alignment history is recorded and a query layer is created. The expected
    cell state is an `rnn_cell.LSTMStateTuple`, and all expectations are
    (shape, dtype, mean) summaries handed to `self._testWithAttention`.
    """

    def summary(shape, mean, dtype=np.float32):
      """Builds one expected ResultSummary record."""
      return ResultSummary(shape=shape, dtype=np.dtype(dtype), mean=mean)

    attention_kwargs = dict(kernel_initializer="ones")

    final_output = basic_decoder.BasicDecoderOutput(
        rnn_output=summary((5, 3, 6), 4.8290324),
        sample_id=summary((5, 3), 0, dtype=np.int32))
    final_state = wrapper.AttentionWrapperState(
        cell_state=rnn_cell.LSTMStateTuple(
            c=summary((5, 9), 1.6432636),
            h=summary((5, 9), 0.75866824)),
        attention=summary((5, 6), 6.7445569),
        time=3,
        alignments=summary((5, 8), 0.125),
        attention_state=summary((5, 8), 0.125),
        alignment_history=())
    history = summary((3, 5, 8), 0.125)

    self._testWithAttention(
        wrapper.BahdanauAttentionV2,
        final_output,
        final_state,
        alignment_history=True,
        create_query_layer=True,
        expected_final_alignment_history=history,
        create_attention_kwargs=attention_kwargs)
Beispiel #14
0
  def testStepWithSampleEmbeddingHelper(self):
    """Runs one BasicDecoder step driven by a SampleEmbeddingHelper.

    Checks the decoder's declared output size and dtype, and the static
    shapes of the initial and first-step LSTM states and outputs. After one
    step, it also checks two things:
    - `step_finished` is true exactly where the sampled id equals `end_token`.
    - `step_next_inputs` are the embedding rows of the sampled ids.
    """
    batch_size = 5
    vocabulary_size = 7
    cell_depth = vocabulary_size  # cell's logits must match vocabulary size
    input_depth = 10
    np.random.seed(0)
    start_tokens = np.random.randint(0, vocabulary_size, size=batch_size)
    end_token = 1

    with self.session(use_gpu=True) as sess:
      with variable_scope.variable_scope(
          "testStepWithSampleEmbeddingHelper",
          initializer=init_ops.constant_initializer(0.01)):
        embeddings = np.random.randn(vocabulary_size,
                                     input_depth).astype(np.float32)
        cell = rnn_cell.LSTMCell(vocabulary_size)
        helper = helper_py.SampleEmbeddingHelper(embeddings, start_tokens,
                                                 end_token, seed=0)
        my_decoder = basic_decoder.BasicDecoder(
            cell=cell,
            helper=helper,
            initial_state=cell.zero_state(
                dtype=dtypes.float32, batch_size=batch_size))
        output_size = my_decoder.output_size
        output_dtype = my_decoder.output_dtype
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(cell_depth,
                                             tensor_shape.TensorShape([])),
            output_size)
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
            output_dtype)

        (first_finished, first_inputs, first_state) = my_decoder.initialize()
        (step_outputs, step_state, step_next_inputs,
         step_finished) = my_decoder.step(
             constant_op.constant(0), first_inputs, first_state)
        batch_size_t = my_decoder.batch_size

        # assertIsInstance names the offending type on failure; the previous
        # assertTrue(isinstance(...)) could only report "False is not true".
        self.assertIsInstance(first_state, rnn_cell.LSTMStateTuple)
        self.assertIsInstance(step_state, rnn_cell.LSTMStateTuple)
        self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
        self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
        self.assertEqual((batch_size,), step_outputs[1].get_shape())
        self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
        self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
        self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
        self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())

        sess.run(variables.global_variables_initializer())
        sess_results = sess.run({
            "batch_size": batch_size_t,
            "first_finished": first_finished,
            "first_inputs": first_inputs,
            "first_state": first_state,
            "step_outputs": step_outputs,
            "step_state": step_state,
            "step_next_inputs": step_next_inputs,
            "step_finished": step_finished
        })

        sample_ids = sess_results["step_outputs"].sample_id
        self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
        expected_step_finished = (sample_ids == end_token)
        # The next input for each batch entry is the embedding of its sample.
        expected_step_next_inputs = embeddings[sample_ids]
        self.assertAllEqual(expected_step_finished,
                            sess_results["step_finished"])
        self.assertAllEqual(expected_step_next_inputs,
                            sess_results["step_next_inputs"])
    def _testStepWithScheduledOutputTrainingHelper(self, use_next_input_layer):
        """Runs one BasicDecoder step with a ScheduledOutputTrainingHelper.

        Args:
            use_next_input_layer: when true, the LSTM cell depth is 6 and a
                bias-free `layers_core.Dense(input_depth)` layer is passed to
                the helper as `next_input_layer`; otherwise the cell depth
                equals the input depth (7) and no layer is passed.

        The test checks the decoder's declared output size and dtype, the
        static state and output shapes, and the finished flags. It also checks
        that after the first step each batch entry's next input is one of:
        - the (optionally projected) RNN output, where a sample was taken;
        - the ground-truth input at time 1, where no sample was taken.
        """
        sequence_length = [3, 4, 3, 1, 0]
        batch_size = 5
        max_time = 8
        input_depth = 7
        cell_depth = input_depth
        if use_next_input_layer:
            cell_depth = 6

        with self.test_session() as sess:
            inputs = np.random.randn(batch_size, max_time,
                                     input_depth).astype(np.float32)
            cell = core_rnn_cell.LSTMCell(cell_depth)
            half = constant_op.constant(0.5)

            next_input_layer = None
            if use_next_input_layer:
                next_input_layer = layers_core.Dense(input_depth,
                                                     use_bias=False)

            helper = helper_py.ScheduledOutputTrainingHelper(
                inputs=inputs,
                sequence_length=sequence_length,
                sampling_probability=half,
                time_major=False,
                next_input_layer=next_input_layer)

            my_decoder = basic_decoder.BasicDecoder(
                cell=cell,
                helper=helper,
                initial_state=cell.zero_state(dtype=dtypes.float32,
                                              batch_size=batch_size))

            output_size = my_decoder.output_size
            output_dtype = my_decoder.output_dtype
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(cell_depth,
                                                 tensor_shape.TensorShape([])),
                output_size)
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
                output_dtype)

            (first_finished, first_inputs,
             first_state) = my_decoder.initialize()
            (step_outputs, step_state, step_next_inputs,
             step_finished) = my_decoder.step(constant_op.constant(0),
                                              first_inputs, first_state)

            if use_next_input_layer:
                output_after_next_input_layer = next_input_layer(
                    step_outputs.rnn_output)

            batch_size_t = my_decoder.batch_size

            # assertIsInstance names the offending type on failure.
            self.assertIsInstance(first_state, core_rnn_cell.LSTMStateTuple)
            self.assertIsInstance(step_state, core_rnn_cell.LSTMStateTuple)
            self.assertIsInstance(step_outputs,
                                  basic_decoder.BasicDecoderOutput)
            self.assertEqual((batch_size, cell_depth),
                             step_outputs[0].get_shape())
            self.assertEqual((batch_size, ), step_outputs[1].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             first_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             first_state[1].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             step_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             step_state[1].get_shape())

            sess.run(variables.global_variables_initializer())

            fetches = {
                "batch_size": batch_size_t,
                "first_finished": first_finished,
                "first_inputs": first_inputs,
                "first_state": first_state,
                "step_outputs": step_outputs,
                "step_state": step_state,
                "step_next_inputs": step_next_inputs,
                "step_finished": step_finished
            }
            if use_next_input_layer:
                fetches["output_after_next_input_layer"] = (
                    output_after_next_input_layer)

            sess_results = sess.run(fetches)

            # Expected flags per sequence_length: index 4 (length 0) is done
            # at initialization, index 3 (length 1) after the first step.
            self.assertAllEqual([False, False, False, False, True],
                                sess_results["first_finished"])
            self.assertAllEqual([False, False, False, True, True],
                                sess_results["step_finished"])

            sample_ids = sess_results["step_outputs"].sample_id
            # Use the bare index arrays, not the 1-tuples np.where returns.
            # `inputs[(idx,), 1]` treats the tuple as a (1, k) index array and
            # yields shape (1, k, depth). Squeezing axis 1 of that fails
            # whenever k != 1. `inputs[idx, 1]` gives (k, depth) for any k,
            # including 0.
            batch_where_not_sampling = np.where(np.logical_not(sample_ids))[0]
            batch_where_sampling = np.where(sample_ids)[0]
            if use_next_input_layer:
                self.assertAllClose(
                    sess_results["step_next_inputs"][batch_where_sampling],
                    sess_results["output_after_next_input_layer"]
                    [batch_where_sampling])
            else:
                self.assertAllClose(
                    sess_results["step_next_inputs"][batch_where_sampling],
                    sess_results["step_outputs"].
                    rnn_output[batch_where_sampling])
            # Entries that were not sampled take the ground-truth input at t=1.
            self.assertAllClose(
                sess_results["step_next_inputs"][batch_where_not_sampling],
                inputs[batch_where_not_sampling, 1])
Beispiel #16
0
  def testStepWithInferenceHelperCategorical(self):
    """Runs one BasicDecoder step with a categorical-sampling InferenceHelper.

    The helper samples ids categorically from the cell logits, feeds back
    their one-hot encodings, and marks entries finished when the sampled id
    equals `end_token`. The test checks the declared output size and dtype,
    the static shapes of states and outputs, and the finished flags and next
    inputs produced by the first step.
    """
    batch_size = 5
    vocabulary_size = 7
    cell_depth = vocabulary_size
    start_token = 0
    end_token = 6

    start_inputs = array_ops.one_hot(
        np.ones(batch_size) * start_token,
        vocabulary_size)

    # The sample function samples categorically from the logits.
    sample_fn = lambda x: helper_py.categorical_sample(logits=x)
    # The next inputs are a one-hot encoding of the sampled labels.
    next_inputs_fn = (
        lambda x: array_ops.one_hot(x, vocabulary_size, dtype=dtypes.float32))
    end_fn = lambda sample_ids: math_ops.equal(sample_ids, end_token)

    with self.session(use_gpu=True) as sess:
      with variable_scope.variable_scope(
          "testStepWithInferenceHelper",
          initializer=init_ops.constant_initializer(0.01)):
        cell = rnn_cell.LSTMCell(vocabulary_size)
        helper = helper_py.InferenceHelper(
            sample_fn, sample_shape=(), sample_dtype=dtypes.int32,
            start_inputs=start_inputs, end_fn=end_fn,
            next_inputs_fn=next_inputs_fn)
        my_decoder = basic_decoder.BasicDecoder(
            cell=cell,
            helper=helper,
            initial_state=cell.zero_state(
                dtype=dtypes.float32, batch_size=batch_size))
        output_size = my_decoder.output_size
        output_dtype = my_decoder.output_dtype
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(cell_depth,
                                             tensor_shape.TensorShape([])),
            output_size)
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
            output_dtype)

        (first_finished, first_inputs, first_state) = my_decoder.initialize()
        (step_outputs, step_state, step_next_inputs,
         step_finished) = my_decoder.step(
             constant_op.constant(0), first_inputs, first_state)
        batch_size_t = my_decoder.batch_size

        # assertIsInstance names the offending type on failure; the previous
        # assertTrue(isinstance(...)) could only report "False is not true".
        self.assertIsInstance(first_state, rnn_cell.LSTMStateTuple)
        self.assertIsInstance(step_state, rnn_cell.LSTMStateTuple)
        self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
        self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
        self.assertEqual((batch_size,), step_outputs[1].get_shape())
        self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
        self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
        self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
        self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())

        sess.run(variables.global_variables_initializer())
        sess_results = sess.run({
            "batch_size": batch_size_t,
            "first_finished": first_finished,
            "first_inputs": first_inputs,
            "first_state": first_state,
            "step_outputs": step_outputs,
            "step_state": step_state,
            "step_next_inputs": step_next_inputs,
            "step_finished": step_finished
        })

        sample_ids = sess_results["step_outputs"].sample_id
        self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
        expected_step_finished = (sample_ids == end_token)
        # One-hot row per batch entry, hot at its sampled id.
        expected_step_next_inputs = np.zeros((batch_size, vocabulary_size))
        expected_step_next_inputs[np.arange(batch_size), sample_ids] = 1.0
        self.assertAllEqual(expected_step_finished,
                            sess_results["step_finished"])
        self.assertAllEqual(expected_step_next_inputs,
                            sess_results["step_next_inputs"])
    def testBahndahauNormalized(self):
        """Golden-value test for Bahdanau attention with ``normalize=True``.

        Builds an attention-mechanism factory from ``wrapper.BahdanauAttention``
        (``normalize=True``, ``attention_r_initializer=2.0``), spells out the
        expected final decoder output and final
        ``wrapper.DynamicAttentionWrapperState`` as literals, and passes them
        to ``self._testWithAttention``. That helper is defined elsewhere in
        the class; presumably it builds and runs the decoder and compares the
        results against these values -- confirm there.

        NOTE(review): the method name misspells "Bahdanau"; it is kept as-is
        so existing test IDs stay stable.
        """
        # functools.partial pre-binds the keyword arguments; the remaining
        # constructor arguments are presumably supplied by _testWithAttention
        # (TODO confirm).
        create_attention_mechanism = functools.partial(
            wrapper.BahdanauAttention,
            normalize=True,
            attention_r_initializer=2.0)

        # Short local aliases so the literals below read like numpy repr
        # output (presumably pasted from a reference run).
        array = np.array
        float32 = np.float32
        int32 = np.int32

        # Expected final decoder output. By its nesting the rnn_output literal
        # is 5 x 3 x 6 (presumably batch x time x output depth -- confirm in
        # _testWithAttention); sample_id is a 5 x 3 block of zeros.
        expected_final_output = basic_decoder.BasicDecoderOutput(
            rnn_output=array(
                [[[
                    1.72670335e-02, -5.83671592e-03, 6.38638902e-03,
                    -8.11776379e-04, 1.12681929e-03, -1.24236047e-02
                ],
                  [
                      1.75918192e-02, -5.73426578e-03, 6.29768707e-03,
                      -8.63141613e-04, 2.03352375e-03, -1.21420780e-02
                  ],
                  [
                      1.72424167e-02, -5.66471322e-03, 6.63427915e-03,
                      -6.23903936e-04, 1.68706616e-03, -1.22524602e-02
                  ]],
                 [[
                     1.79958157e-02, -9.80986748e-03, 4.73218597e-03,
                     -3.89962713e-03, 1.41502675e-02, -1.48344040e-02
                 ],
                  [
                      1.82184577e-02, -9.88379307e-03, 4.90130857e-03,
                      -3.91892251e-03, 1.36479288e-02, -1.53291579e-02
                  ],
                  [
                      1.83001235e-02, -1.00617753e-02, 4.97077405e-03,
                      -3.94908339e-03, 1.37211196e-02, -1.52311027e-02
                  ]],
                 [[
                     7.93476030e-03, -8.46967567e-03, -7.16930721e-04,
                     4.37953044e-04, 1.04503892e-03, -1.82424393e-02
                 ],
                  [
                      7.90629163e-03, -8.48819874e-03, -5.57833235e-04,
                      5.02390554e-04, 6.79406337e-04, -1.84837580e-02
                  ],
                  [
                      8.14734399e-03, -8.23053624e-03, -5.92814526e-04,
                      4.16347990e-04, 1.29250437e-03, -1.84548404e-02
                  ]],
                 [[
                     1.21026095e-02, -1.26739489e-02, 1.78718648e-04,
                     2.68748170e-03, 7.80996867e-03, -9.69076063e-04
                 ],
                  [
                      1.17978491e-02, -1.32678337e-02, 6.00410858e-05,
                      2.66301399e-03, 7.00691342e-03, -1.10030361e-03
                  ],
                  [
                      1.15651665e-02, -1.30795036e-02, -2.74205930e-04,
                      2.48012133e-03, 6.94250735e-03, -8.47495161e-04
                  ]],
                 [[
                     1.02377674e-02, -8.72955937e-03, 1.22555892e-03,
                     2.03830865e-03, 8.93574394e-03, -7.28237582e-03
                 ],
                  [
                      1.05115287e-02, -8.92531779e-03, 1.14568521e-03,
                      1.91635895e-03, 8.94328393e-03, -7.39541650e-03
                  ],
                  [
                      1.07398070e-02, -8.56867433e-03, 1.52354129e-03,
                      2.06834078e-03, 9.36511997e-03, -7.64556089e-03
                  ]]],
                dtype=float32),
            sample_id=array(
                [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
                dtype=int32))

        # Expected final wrapper state: the LSTM c and h literals are 5 x 9
        # and attention is 5 x 6. Each attention row matches, to the printed
        # precision, the third (final) inner row of the corresponding
        # rnn_output block above.
        expected_final_state = wrapper.DynamicAttentionWrapperState(
            cell_state=core_rnn_cell.LSTMStateTuple(
                c=array([[
                    -0.02209264, -0.00794879, -0.00157153, 0.01614309,
                    -0.01383773, -0.00750943, -0.00824213, -0.01210296,
                    0.01794949
                ],
                         [
                             0.01726926, -0.01418139, -0.0040099, 0.0319339,
                             -0.03545783, -0.02142831, -0.00609501,
                             -0.00195033, -0.01938949
                         ],
                         [
                             -0.01159083, 0.0087524, -0.01639001, -0.01400012,
                             0.01342422, -0.01041037, 0.00620991, -0.00960796,
                             -0.00650131
                         ],
                         [
                             -0.04763237, -0.01192762, -0.00019377, 0.04103839,
                             -0.00138058, 0.02126443, -0.02793917, -0.05467755,
                             -0.02912025
                         ],
                         [
                             0.02241185, -0.00141741, 0.01911988, 0.00547728,
                             -0.01280068, -0.00307024, -0.00494239, 0.02169247,
                             0.01631995
                         ]],
                        dtype=float32),
                h=array([[
                    -1.10821165e-02, -3.92766716e-03, -7.99638336e-04,
                    7.92923011e-03, -7.04019284e-03, -3.77124036e-03,
                    -4.19876305e-03, -6.17261464e-03, 8.95325281e-03
                ],
                         [
                             8.60597286e-03, -7.16368994e-03, -1.94644753e-03,
                             1.62479617e-02, -1.76739115e-02, -1.06403306e-02,
                             -3.01484042e-03, -9.74688213e-04, -9.96260438e-03
                         ],
                         [
                             -5.78098884e-03, 4.48751403e-03, -8.12216662e-03,
                             -6.94991415e-03, 6.72604749e-03, -5.10144979e-03,
                             3.08637507e-03, -4.71517537e-03, -3.20256175e-03
                         ],
                         [
                             -2.38018110e-02, -5.89307398e-03, -9.74484938e-05,
                             2.01694984e-02, -6.82370039e-04, 1.07099237e-02,
                             -1.42087601e-02, -2.70793457e-02, -1.44684138e-02
                         ],
                         [
                             1.11825848e-02, -6.99267141e-04, 9.82748345e-03,
                             2.74566701e-03, -6.56377291e-03, -1.53681310e-03,
                             -2.48806458e-03, 1.10462429e-02, 7.97568541e-03
                         ]],
                        dtype=float32)),
            attention=array([[
                0.01724242, -0.00566471, 0.00663428, -0.0006239, 0.00168707,
                -0.01225246
            ],
                             [
                                 0.01830012, -0.01006178, 0.00497077,
                                 -0.00394908, 0.01372112, -0.0152311
                             ],
                             [
                                 0.00814734, -0.00823054, -0.00059281,
                                 0.00041635, 0.0012925, -0.01845484
                             ],
                             [
                                 0.01156517, -0.0130795, -0.00027421,
                                 0.00248012, 0.00694251, -0.0008475
                             ],
                             [
                                 0.01073981, -0.00856867, 0.00152354,
                                 0.00206834, 0.00936512, -0.00764556
                             ]],
                            dtype=float32))

        # attention_mechanism_depth is not passed here, so whatever default
        # _testWithAttention declares applies.
        self._testWithAttention(create_attention_mechanism,
                                expected_final_output, expected_final_state)
Beispiel #18
0
  def testStepWithTrainingHelper(self):
    """Runs one BasicDecoder step driven by a TrainingHelper.

    Checks the decoder's static output size and dtype, the static shapes of
    the tensors returned by ``initialize()`` and ``step()``, and -- after
    running the graph -- the dynamic batch size, the finished flags, the
    sample_id dtype, and that every sample_id is the argmax of its
    rnn_output row.
    """
    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    cell_depth = 10

    with self.test_session() as sess:
      inputs = np.random.randn(batch_size, max_time,
                               input_depth).astype(np.float32)
      cell = core_rnn_cell.LSTMCell(cell_depth)
      helper = helper_py.TrainingHelper(
          inputs, sequence_length, time_major=False)
      my_decoder = basic_decoder.BasicDecoder(
          cell=cell,
          helper=helper,
          initial_state=cell.zero_state(
              dtype=dtypes.float32, batch_size=batch_size))
      output_size = my_decoder.output_size
      output_dtype = my_decoder.output_dtype
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(cell_depth,
                                           tensor_shape.TensorShape([])),
          output_size)
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
          output_dtype)

      (first_finished, first_inputs, first_state) = my_decoder.initialize()
      (step_outputs, step_state, step_next_inputs,
       step_finished) = my_decoder.step(
           constant_op.constant(0), first_inputs, first_state)
      batch_size_t = my_decoder.batch_size

      # assertIsInstance names the offending type on failure, which
      # assertTrue(isinstance(...)) does not.
      self.assertIsInstance(first_state, core_rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_state, core_rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
      self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
      self.assertEqual((batch_size,), step_outputs[1].get_shape())
      self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
      self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
      self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
      self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())

      sess.run(variables.global_variables_initializer())
      sess_results = sess.run({
          "batch_size": batch_size_t,
          "first_finished": first_finished,
          "first_inputs": first_inputs,
          "first_state": first_state,
          "step_outputs": step_outputs,
          "step_state": step_state,
          "step_next_inputs": step_next_inputs,
          "step_finished": step_finished
      })

      # The dynamic batch size is fetched above; actually check it.
      self.assertEqual(batch_size, sess_results["batch_size"])
      # Only the length-0 sequence is finished before any step; after step 0
      # the length-1 sequence is finished as well.
      self.assertAllEqual([False, False, False, False, True],
                          sess_results["first_finished"])
      self.assertAllEqual([False, False, False, True, True],
                          sess_results["step_finished"])

      sample_ids = sess_results["step_outputs"].sample_id
      self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
      self.assertAllEqual(
          np.argmax(sess_results["step_outputs"].rnn_output, -1),
          sample_ids)
    def testLuongNormalized(self):
        """Golden-value test for Luong attention with ``normalize=True``.

        Builds an attention-mechanism factory from ``wrapper.LuongAttention``
        (``normalize=True``, ``attention_r_initializer=2.0``), spells out the
        expected final decoder output and final
        ``wrapper.DynamicAttentionWrapperState`` as literals, and passes them
        to ``self._testWithAttention`` together with
        ``attention_mechanism_depth=9``. That helper is defined elsewhere in
        the class; presumably it builds and runs the decoder and compares the
        results against these values -- confirm there.
        """
        # functools.partial pre-binds the keyword arguments; the remaining
        # constructor arguments are presumably supplied by _testWithAttention
        # (TODO confirm).
        create_attention_mechanism = functools.partial(
            wrapper.LuongAttention,
            normalize=True,
            attention_r_initializer=2.0)

        # Short local aliases so the literals below read like numpy repr
        # output (presumably pasted from a reference run).
        array = np.array
        float32 = np.float32
        int32 = np.int32

        # Expected final decoder output. By its nesting the rnn_output literal
        # is 5 x 3 x 6 (presumably batch x time x output depth -- confirm in
        # _testWithAttention); sample_id is a 5 x 3 block of zeros.
        expected_final_output = basic_decoder.BasicDecoderOutput(
            rnn_output=array(
                [[[
                    1.23956744e-02, -6.88115368e-03, 3.15234554e-03,
                    -1.97300944e-03, 4.79680905e-03, -1.38076628e-02
                ],
                  [
                      1.28376717e-02, -6.78718928e-03, 3.07988771e-03,
                      -2.03956687e-03, 5.68403490e-03, -1.35601182e-02
                  ],
                  [
                      1.23463338e-02, -6.76322030e-03, 3.28891934e-03,
                      -1.86874042e-03, 5.47897862e-03, -1.37654068e-02
                  ]],
                 [[
                     1.54412268e-02, -1.07613346e-02, 4.43824846e-03,
                     -8.81063985e-04, 1.26828086e-02, -1.21067995e-02
                 ],
                  [
                      1.57206059e-02, -1.08218864e-02, 4.61952807e-03,
                      -9.61483689e-04, 1.22140013e-02, -1.26614980e-02
                  ],
                  [
                      1.57821011e-02, -1.09842420e-02, 4.66934917e-03,
                      -9.85997496e-04, 1.22719472e-02, -1.25438003e-02
                  ]],
                 [[
                     9.27361846e-03, -9.66077764e-03, -9.69522633e-04,
                     1.48308463e-05, 3.88664147e-03, -1.64083000e-02
                 ],
                  [
                      9.26287938e-03, -9.74234194e-03, -8.32488062e-04,
                      5.83778601e-05, 3.52663640e-03, -1.66827720e-02
                  ],
                  [
                      9.50474478e-03, -9.49789397e-03, -8.71829456e-04,
                      -3.09986062e-05, 4.13423358e-03, -1.66635048e-02
                  ]],
                 [[
                     1.21398102e-02, -1.27454493e-02, 1.57688977e-04,
                     2.70034792e-03, 7.79653806e-03, -8.36936757e-04
                 ],
                  [
                      1.18234595e-02, -1.33170560e-02, 4.55579720e-05,
                      2.67185434e-03, 6.99766818e-03, -1.00935437e-03
                  ],
                  [
                      1.16009805e-02, -1.31483339e-02, -2.94458936e-04,
                      2.49248254e-03, 6.92958105e-03, -7.20315147e-04
                  ]],
                 [[
                     1.02377674e-02, -8.72955937e-03, 1.22555892e-03,
                     2.03830865e-03, 8.93574394e-03, -7.28237582e-03
                 ],
                  [
                      1.05115287e-02, -8.92531779e-03, 1.14568521e-03,
                      1.91635895e-03, 8.94328393e-03, -7.39541650e-03
                  ],
                  [
                      1.07398070e-02, -8.56867433e-03, 1.52354129e-03,
                      2.06834078e-03, 9.36511997e-03, -7.64556089e-03
                  ]]],
                dtype=float32),
            sample_id=array(
                [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
                dtype=int32))
        # Expected final wrapper state: the LSTM c and h literals are 5 x 9
        # and attention is 5 x 6. Each attention row is identical to the
        # third (final) inner row of the corresponding rnn_output block above.
        expected_final_state = wrapper.DynamicAttentionWrapperState(
            cell_state=core_rnn_cell.LSTMStateTuple(
                c=array([[
                    -0.02204949, -0.00805957, -0.001603, 0.01609283,
                    -0.01380462, -0.0074945, -0.00816895, -0.01210009,
                    0.01795324
                ],
                         [
                             0.01727016, -0.01420708, -0.00399973, 0.03195432,
                             -0.03547529, -0.02138673, -0.00610332,
                             -0.00191565, -0.01937822
                         ],
                         [
                             -0.01160676, 0.00876512, -0.01641791, -0.01400807,
                             0.01347767, -0.01036341, 0.00627499, -0.00963627,
                             -0.00650573
                         ],
                         [
                             -0.04763342, -0.01192671, -0.00019402, 0.04103871,
                             -0.00138017, 0.02126611, -0.02793773, -0.05467714,
                             -0.02912043
                         ],
                         [
                             0.02241185, -0.00141741, 0.01911988, 0.00547728,
                             -0.01280068, -0.00307024, -0.00494239, 0.02169247,
                             0.01631995
                         ]],
                        dtype=float32),
                h=array([[
                    -1.10610286e-02, -3.98253463e-03, -8.15684092e-04,
                    7.90454168e-03, -7.02364743e-03, -3.76377185e-03,
                    -4.16135695e-03, -6.17104582e-03, 8.95532966e-03
                ],
                         [
                             8.60653073e-03, -7.17685232e-03, -1.94147974e-03,
                             1.62585936e-02, -1.76823437e-02, -1.06195193e-02,
                             -3.01911240e-03, -9.57308919e-04, -9.95720550e-03
                         ],
                         [
                             -5.78888878e-03, 4.49400023e-03, -8.13617278e-03,
                             -6.95386063e-03, 6.75271638e-03, -5.07823005e-03,
                             3.11873178e-03, -4.72912844e-03, -3.20472987e-03
                         ],
                         [
                             -2.38023344e-02, -5.89262368e-03, -9.75721487e-05,
                             2.01696623e-02, -6.82163402e-04, 1.07107637e-02,
                             -1.42080421e-02, -2.70791352e-02, -1.44685050e-02
                         ],
                         [
                             1.11825848e-02, -6.99267141e-04, 9.82748345e-03,
                             2.74566701e-03, -6.56377291e-03, -1.53681310e-03,
                             -2.48806458e-03, 1.10462429e-02, 7.97568541e-03
                         ]],
                        dtype=float32)),
            attention=array(
                [[
                    1.23463338e-02, -6.76322030e-03, 3.28891934e-03,
                    -1.86874042e-03, 5.47897862e-03, -1.37654068e-02
                ],
                 [
                     1.57821011e-02, -1.09842420e-02, 4.66934917e-03,
                     -9.85997496e-04, 1.22719472e-02, -1.25438003e-02
                 ],
                 [
                     9.50474478e-03, -9.49789397e-03, -8.71829456e-04,
                     -3.09986062e-05, 4.13423358e-03, -1.66635048e-02
                 ],
                 [
                     1.16009805e-02, -1.31483339e-02, -2.94458936e-04,
                     2.49248254e-03, 6.92958105e-03, -7.20315147e-04
                 ],
                 [
                     1.07398070e-02, -8.56867433e-03, 1.52354129e-03,
                     2.06834078e-03, 9.36511997e-03, -7.64556089e-03
                 ]],
                dtype=float32))
        # attention_mechanism_depth is overridden to 9 for this variant; its
        # exact meaning is defined by _testWithAttention (confirm there).
        self._testWithAttention(create_attention_mechanism,
                                expected_final_output,
                                expected_final_state,
                                attention_mechanism_depth=9)
    def testLuongNotNormalized(self):
        """Golden-value test for Luong attention with default arguments.

        Uses ``wrapper.LuongAttention`` itself as the factory (no extra
        keyword arguments; given the test name, presumably normalization is
        off by default -- confirm in the wrapper), spells out the expected
        final decoder output and final ``wrapper.DynamicAttentionWrapperState``
        as literals, and passes them to ``self._testWithAttention`` together
        with ``attention_mechanism_depth=9``. That helper is defined elsewhere
        in the class; presumably it builds and runs the decoder and compares
        the results against these values -- confirm there.
        """
        # The class itself is the factory: no arguments are pre-bound here.
        create_attention_mechanism = wrapper.LuongAttention

        # Short local aliases so the literals below read like numpy repr
        # output (presumably pasted from a reference run).
        array = np.array
        float32 = np.float32
        int32 = np.int32

        # Expected final decoder output. By its nesting the rnn_output literal
        # is 5 x 3 x 6 (presumably batch x time x output depth -- confirm in
        # _testWithAttention); sample_id is a 5 x 3 block of zeros.
        expected_final_output = basic_decoder.BasicDecoderOutput(
            rnn_output=array(
                [[[
                    1.23641128e-02, -6.82715839e-03, 3.24165262e-03,
                    -1.90772023e-03, 4.69654519e-03, -1.37025211e-02
                ],
                  [
                      1.29463980e-02, -6.79699238e-03, 3.10124992e-03,
                      -2.02869414e-03, 5.66399656e-03, -1.35517996e-02
                  ],
                  [
                      1.22659411e-02, -6.81970268e-03, 3.15135531e-03,
                      -1.96937821e-03, 5.62768336e-03, -1.39173865e-02
                  ]],
                 [[
                     1.53944232e-02, -1.07725551e-02, 4.42822604e-03,
                     -8.30623554e-04, 1.26549732e-02, -1.20573286e-02
                 ],
                  [
                      1.57453734e-02, -1.08157266e-02, 4.62466478e-03,
                      -9.88351414e-04, 1.22286947e-02, -1.26876952e-02
                  ],
                  [
                      1.57857724e-02, -1.09536834e-02, 4.64798324e-03,
                      -1.01319887e-03, 1.22695938e-02, -1.25500849e-02
                  ]],
                 [[
                     9.23123397e-03, -9.42669343e-03, -9.09919385e-04,
                     6.09827694e-05, 3.90436035e-03, -1.63374804e-02
                 ],
                  [
                      9.22935922e-03, -9.57853813e-03, -7.92966573e-04,
                      8.89014918e-05, 3.52671882e-03, -1.66499857e-02
                  ],
                  [
                      9.49526206e-03, -9.39475093e-03, -8.49372707e-04,
                      -1.72815053e-05, 4.16132808e-03, -1.66336838e-02
                  ]],
                 [[
                     1.21248290e-02, -1.27166547e-02, 1.66158192e-04,
                     2.69516627e-03, 7.80194718e-03, -8.90152063e-04
                 ],
                  [
                      1.17861275e-02, -1.32453050e-02, 6.66640699e-05,
                      2.65894993e-03, 7.01114535e-03, -1.14195189e-03
                  ],
                  [
                      1.15833860e-02, -1.31145213e-02, -2.84505659e-04,
                      2.48642010e-03, 6.93593081e-03, -7.82784075e-04
                  ]],
                 [[
                     1.02377674e-02, -8.72955937e-03, 1.22555892e-03,
                     2.03830865e-03, 8.93574394e-03, -7.28237582e-03
                 ],
                  [
                      1.05115287e-02, -8.92531779e-03, 1.14568521e-03,
                      1.91635895e-03, 8.94328393e-03, -7.39541650e-03
                  ],
                  [
                      1.07398070e-02, -8.56867433e-03, 1.52354129e-03,
                      2.06834078e-03, 9.36511997e-03, -7.64556089e-03
                  ]]],
                dtype=float32),
            sample_id=array(
                [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
                dtype=int32))
        # Expected final wrapper state: the LSTM c and h literals are 5 x 9
        # and attention is 5 x 6. Each attention row is identical to the
        # third (final) inner row of the corresponding rnn_output block above.
        expected_final_state = wrapper.DynamicAttentionWrapperState(
            cell_state=core_rnn_cell.LSTMStateTuple(
                c=array([[
                    -0.02204997, -0.00805805, -0.00160245, 0.01609369,
                    -0.01380494, -0.00749439, -0.00817, -0.01209992, 0.01795316
                ],
                         [
                             0.01727016, -0.01420713, -0.00399972, 0.03195436,
                             -0.03547532, -0.02138666, -0.00610335,
                             -0.00191557, -0.01937821
                         ],
                         [
                             -0.01160429, 0.00876595, -0.01641685, -0.01400784,
                             0.01348004, -0.01036458, 0.00627241, -0.00963544,
                             -0.00650568
                         ],
                         [
                             -0.04763246, -0.01192755, -0.00019379, 0.04103841,
                             -0.00138055, 0.02126456, -0.02793905, -0.0546775,
                             -0.02912027
                         ],
                         [
                             0.02241185, -0.00141741, 0.01911988, 0.00547728,
                             -0.01280068, -0.00307024, -0.00494239, 0.02169247,
                             0.01631995
                         ]],
                        dtype=float32),
                h=array([[
                    -1.10612623e-02, -3.98178305e-03, -8.15406092e-04,
                    7.90496264e-03, -7.02379830e-03, -3.76371504e-03,
                    -4.16189339e-03, -6.17096573e-03, 8.95528216e-03
                ],
                         [
                             8.60652886e-03, -7.17687514e-03, -1.94147555e-03,
                             1.62586085e-02, -1.76823605e-02, -1.06194830e-02,
                             -3.01912241e-03, -9.57269047e-04, -9.95719433e-03
                         ],
                         [
                             -5.78764686e-03, 4.49441886e-03, -8.13564472e-03,
                             -6.95375400e-03, 6.75391173e-03, -5.07880514e-03,
                             3.11744539e-03, -4.72871540e-03, -3.20470310e-03
                         ],
                         [
                             -2.38018595e-02, -5.89303859e-03, -9.74571449e-05,
                             2.01695058e-02, -6.82353624e-04, 1.07099945e-02,
                             -1.42086931e-02, -2.70793252e-02, -1.44684194e-02
                         ],
                         [
                             1.11825848e-02, -6.99267141e-04, 9.82748345e-03,
                             2.74566701e-03, -6.56377291e-03, -1.53681310e-03,
                             -2.48806458e-03, 1.10462429e-02, 7.97568541e-03
                         ]],
                        dtype=float32)),
            attention=array(
                [[
                    1.22659411e-02, -6.81970268e-03, 3.15135531e-03,
                    -1.96937821e-03, 5.62768336e-03, -1.39173865e-02
                ],
                 [
                     1.57857724e-02, -1.09536834e-02, 4.64798324e-03,
                     -1.01319887e-03, 1.22695938e-02, -1.25500849e-02
                 ],
                 [
                     9.49526206e-03, -9.39475093e-03, -8.49372707e-04,
                     -1.72815053e-05, 4.16132808e-03, -1.66336838e-02
                 ],
                 [
                     1.15833860e-02, -1.31145213e-02, -2.84505659e-04,
                     2.48642010e-03, 6.93593081e-03, -7.82784075e-04
                 ],
                 [
                     1.07398070e-02, -8.56867433e-03, 1.52354129e-03,
                     2.06834078e-03, 9.36511997e-03, -7.64556089e-03
                 ]],
                dtype=float32))

        # attention_mechanism_depth is overridden to 9 for this variant; its
        # exact meaning is defined by _testWithAttention (confirm there).
        self._testWithAttention(create_attention_mechanism,
                                expected_final_output,
                                expected_final_state,
                                attention_mechanism_depth=9)
Beispiel #21
0
  def _testStepWithScheduledOutputTrainingHelper(
      self, sampling_probability, use_next_inputs_fn, use_auxiliary_inputs):
    """Runs one BasicDecoder step driven by a ScheduledOutputTrainingHelper.

    After checking the decoder's static output size/dtype and the static
    shapes of the initialize()/step() results, runs the graph and verifies
    the next inputs produced by step 0. For batch entries whose sample_id is
    non-zero, the next input must be the step's rnn_output (or
    ``next_inputs_fn`` applied to it). For the other entries, it must be the
    ground-truth input at time 1. In both cases the time-1 auxiliary inputs,
    when used, are concatenated on the last axis.

    Args:
      sampling_probability: value wrapped with ``constant_op.constant`` and
        handed to the helper as its sampling probability.
      use_next_inputs_fn: if true, pass a deterministic argmax/one-hot
        ``next_inputs_fn`` to the helper.
      use_auxiliary_inputs: if true, pass random auxiliary inputs of depth 4
        to the helper.
    """
    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    cell_depth = input_depth
    if use_auxiliary_inputs:
      auxiliary_input_depth = 4
      auxiliary_inputs = np.random.randn(
          batch_size, max_time, auxiliary_input_depth).astype(np.float32)
    else:
      auxiliary_inputs = None

    with self.session(use_gpu=True) as sess:
      inputs = np.random.randn(batch_size, max_time,
                               input_depth).astype(np.float32)
      cell = rnn_cell.LSTMCell(cell_depth)
      sampling_probability = constant_op.constant(sampling_probability)

      if use_next_inputs_fn:
        def next_inputs_fn(outputs):
          # Use deterministic function for test.
          samples = math_ops.argmax(outputs, axis=1)
          return array_ops.one_hot(samples, cell_depth, dtype=dtypes.float32)
      else:
        next_inputs_fn = None

      helper = helper_py.ScheduledOutputTrainingHelper(
          inputs=inputs,
          sequence_length=sequence_length,
          sampling_probability=sampling_probability,
          time_major=False,
          next_inputs_fn=next_inputs_fn,
          auxiliary_inputs=auxiliary_inputs)

      my_decoder = basic_decoder.BasicDecoder(
          cell=cell,
          helper=helper,
          initial_state=cell.zero_state(
              dtype=dtypes.float32, batch_size=batch_size))

      output_size = my_decoder.output_size
      output_dtype = my_decoder.output_dtype
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(cell_depth,
                                           tensor_shape.TensorShape([])),
          output_size)
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
          output_dtype)

      (first_finished, first_inputs, first_state) = my_decoder.initialize()
      (step_outputs, step_state, step_next_inputs,
       step_finished) = my_decoder.step(
           constant_op.constant(0), first_inputs, first_state)

      if use_next_inputs_fn:
        output_after_next_inputs_fn = next_inputs_fn(step_outputs.rnn_output)

      batch_size_t = my_decoder.batch_size

      # assertIsInstance names the offending type on failure, which
      # assertTrue(isinstance(...)) does not.
      self.assertIsInstance(first_state, rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_state, rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
      self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
      self.assertEqual((batch_size,), step_outputs[1].get_shape())
      self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
      self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
      self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
      self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())

      sess.run(variables.global_variables_initializer())

      fetches = {
          "batch_size": batch_size_t,
          "first_finished": first_finished,
          "first_inputs": first_inputs,
          "first_state": first_state,
          "step_outputs": step_outputs,
          "step_state": step_state,
          "step_next_inputs": step_next_inputs,
          "step_finished": step_finished
      }
      if use_next_inputs_fn:
        fetches["output_after_next_inputs_fn"] = output_after_next_inputs_fn

      sess_results = sess.run(fetches)

      # The dynamic batch size is fetched above; actually check it.
      self.assertEqual(batch_size, sess_results["batch_size"])
      # Only the length-0 sequence is finished before any step; after step 0
      # the length-1 sequence is finished as well.
      self.assertAllEqual([False, False, False, False, True],
                          sess_results["first_finished"])
      self.assertAllEqual([False, False, False, True, True],
                          sess_results["step_finished"])

      sample_ids = sess_results["step_outputs"].sample_id
      self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
      # A non-zero sample_id marks a batch entry whose next input comes from
      # the step output rather than from `inputs` (checked below).
      batch_where_not_sampling = np.where(np.logical_not(sample_ids))
      batch_where_sampling = np.where(sample_ids)

      # Without auxiliary inputs, concatenate a (batch_size, 0) array so the
      # expectations below need no special case.
      auxiliary_inputs_to_concat = (
          auxiliary_inputs[:, 1] if use_auxiliary_inputs else
          np.array([]).reshape(batch_size, 0).astype(np.float32))

      expected_next_sampling_inputs = np.concatenate(
          (sess_results["output_after_next_inputs_fn"][batch_where_sampling]
           if use_next_inputs_fn else
           sess_results["step_outputs"].rnn_output[batch_where_sampling],
           auxiliary_inputs_to_concat[batch_where_sampling]),
          axis=-1)
      self.assertAllClose(
          sess_results["step_next_inputs"][batch_where_sampling],
          expected_next_sampling_inputs)

      # Non-sampled entries read the ground-truth input at time 1. Selecting
      # the time-1 slice first and then the batch rows is equivalent to the
      # earlier nested-tuple fancy index followed by np.squeeze(axis=0).
      self.assertAllClose(
          sess_results["step_next_inputs"][batch_where_not_sampling],
          np.concatenate(
              (inputs[:, 1][batch_where_not_sampling],
               auxiliary_inputs_to_concat[batch_where_not_sampling]),
              axis=-1))
    def testBahndahauNotNormalized(self):
        create_attention_mechanism = wrapper.BahdanauAttention

        array = np.array
        float32 = np.float32
        int32 = np.int32

        expected_final_outputs = basic_decoder.BasicDecoderOutput(
            rnn_output=array(
                [[[
                    1.25166783e-02, -6.88887993e-03, 3.17239435e-03,
                    -1.98234897e-03, 4.77387803e-03, -1.38330357e-02
                ],
                  [
                      1.28883058e-02, -6.76271692e-03, 3.13419267e-03,
                      -2.02183682e-03, 5.62057737e-03, -1.35373026e-02
                  ],
                  [
                      1.24917831e-02, -6.71574520e-03, 3.42238229e-03,
                      -1.79501204e-03, 5.33161033e-03, -1.36620644e-02
                  ]],
                 [[
                     1.55150667e-02, -1.07274549e-02, 4.44198400e-03,
                     -9.73310322e-04, 1.27242506e-02, -1.21861566e-02
                 ],
                  [
                      1.57585666e-02, -1.07965544e-02, 4.61554807e-03,
                      -1.01510016e-03, 1.22341057e-02, -1.27029382e-02
                  ],
                  [
                      1.58304181e-02, -1.09712025e-02, 4.67861444e-03,
                      -1.03920139e-03, 1.23004699e-02, -1.25949886e-02
                  ]],
                 [[
                     9.26700700e-03, -9.75431874e-03, -9.95740294e-04,
                     -1.27463136e-06, 3.81659716e-03, -1.64887272e-02
                 ],
                  [
                      9.25191958e-03, -9.80092678e-03, -8.48566880e-04,
                      5.02091134e-05, 3.46567202e-03, -1.67435352e-02
                  ],
                  [
                      9.48173273e-03, -9.52653307e-03, -8.79382715e-04,
                      -3.07094306e-05, 4.05955408e-03, -1.67226996e-02
                  ]],
                 [[
                     1.21462569e-02, -1.27578378e-02, 1.54045003e-04,
                     2.70257704e-03, 7.79421115e-03, -8.14041123e-04
                 ],
                  [
                      1.18412934e-02, -1.33513296e-02, 3.54760559e-05,
                      2.67801876e-03, 6.99122995e-03, -9.46014654e-04
                  ],
                  [
                      1.16087487e-02, -1.31632648e-02, -2.98853614e-04,
                      2.49515846e-03, 6.92677684e-03, -6.92734495e-04
                  ]],
                 [[
                     1.02377674e-02, -8.72955937e-03, 1.22555892e-03,
                     2.03830865e-03, 8.93574394e-03, -7.28237582e-03
                 ],
                  [
                      1.05115287e-02, -8.92531779e-03, 1.14568521e-03,
                      1.91635895e-03, 8.94328393e-03, -7.39541650e-03
                  ],
                  [
                      1.07398070e-02, -8.56867433e-03, 1.52354129e-03,
                      2.06834078e-03, 9.36511997e-03, -7.64556089e-03
                  ]]],
                dtype=float32),
            sample_id=array(
                [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
                dtype=int32))

        expected_final_state = wrapper.DynamicAttentionWrapperState(
            cell_state=core_rnn_cell.LSTMStateTuple(
                c=array([[
                    -0.0220502, -0.008058, -0.00160266, 0.01609341,
                    -0.01380513, -0.00749483, -0.00816989, -0.01210028,
                    0.01795324
                ],
                         [
                             0.01727026, -0.0142065, -0.00399991, 0.03195379,
                             -0.03547479, -0.02138772, -0.00610318,
                             -0.00191625, -0.01937846
                         ],
                         [
                             -0.0116077, 0.00876439, -0.01641787, -0.01400803,
                             0.01347527, -0.01036386, 0.00627491, -0.0096361,
                             -0.00650565
                         ],
                         [
                             -0.04763387, -0.01192631, -0.00019412, 0.04103886,
                             -0.00137999, 0.02126684, -0.02793711, -0.05467696,
                             -0.02912051
                         ],
                         [
                             0.02241185, -0.00141741, 0.01911988, 0.00547728,
                             -0.01280068, -0.00307024, -0.00494239, 0.02169247,
                             0.01631995
                         ]],
                        dtype=float32),
                h=array([[
                    -1.10613741e-02, -3.98175791e-03, -8.15514475e-04,
                    7.90482666e-03, -7.02390168e-03, -3.76394135e-03,
                    -4.16183751e-03, -6.17114361e-03, 8.95532221e-03
                ],
                         [
                             8.60657450e-03, -7.17655150e-03, -1.94156705e-03,
                             1.62583217e-02, -1.76821016e-02, -1.06200138e-02,
                             -3.01904045e-03, -9.57608980e-04, -9.95732192e-03
                         ],
                         [
                             -5.78935863e-03, 4.49362956e-03, -8.13615043e-03,
                             -6.95384294e-03, 6.75151078e-03, -5.07845683e-03,
                             3.11869266e-03, -4.72904649e-03, -3.20469099e-03
                         ],
                         [
                             -2.38025561e-02, -5.89242764e-03, -9.76260417e-05,
                             2.01697368e-02, -6.82076614e-04, 1.07111251e-02,
                             -1.42077375e-02, -2.70790439e-02, -1.44685479e-02
                         ],
                         [
                             1.11825848e-02, -6.99267141e-04, 9.82748345e-03,
                             2.74566701e-03, -6.56377291e-03, -1.53681310e-03,
                             -2.48806458e-03, 1.10462429e-02, 7.97568541e-03
                         ]],
                        dtype=float32)),
            attention=array(
                [[
                    1.24917831e-02, -6.71574520e-03, 3.42238229e-03,
                    -1.79501204e-03, 5.33161033e-03, -1.36620644e-02
                ],
                 [
                     1.58304181e-02, -1.09712025e-02, 4.67861444e-03,
                     -1.03920139e-03, 1.23004699e-02, -1.25949886e-02
                 ],
                 [
                     9.48173273e-03, -9.52653307e-03, -8.79382715e-04,
                     -3.07094306e-05, 4.05955408e-03, -1.67226996e-02
                 ],
                 [
                     1.16087487e-02, -1.31632648e-02, -2.98853614e-04,
                     2.49515846e-03, 6.92677684e-03, -6.92734495e-04
                 ],
                 [
                     1.07398070e-02, -8.56867433e-03, 1.52354129e-03,
                     2.06834078e-03, 9.36511997e-03, -7.64556089e-03
                 ]],
                dtype=float32))
        self._testWithAttention(create_attention_mechanism,
                                expected_final_outputs, expected_final_state)
Beispiel #23
0
  def testStepWithScheduledEmbeddingTrainingHelper(self):
    """Runs one BasicDecoder step driven by a ScheduledEmbeddingTrainingHelper.

    With a sampling probability of 0.5, each batch entry either samples an id
    (its next input must then be that id's embedding row) or does not sample
    (sample_id == -1, and its next input must be the ground-truth input at
    time step 1).
    """
    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    vocabulary_size = 10

    with self.session(use_gpu=True) as sess:
      # Random teacher-forcing inputs and an embedding table of equal depth.
      inputs = np.random.randn(
          batch_size, max_time, input_depth).astype(np.float32)
      embeddings = np.random.randn(
          vocabulary_size, input_depth).astype(np.float32)
      sampling_probability = constant_op.constant(0.5)
      cell = rnn_cell.LSTMCell(vocabulary_size)
      helper = helper_py.ScheduledEmbeddingTrainingHelper(
          inputs=inputs,
          sequence_length=sequence_length,
          embedding=embeddings,
          sampling_probability=sampling_probability,
          time_major=False)
      initial_state = cell.zero_state(
          dtype=dtypes.float32, batch_size=batch_size)
      my_decoder = basic_decoder.BasicDecoder(
          cell=cell, helper=helper, initial_state=initial_state)
      output_size = my_decoder.output_size
      output_dtype = my_decoder.output_dtype

      # Static structure: vocabulary-sized float logits and scalar int32 ids.
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(vocabulary_size,
                                           tensor_shape.TensorShape([])),
          output_size)
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
          output_dtype)

      first_finished, first_inputs, first_state = my_decoder.initialize()
      step_outputs, step_state, step_next_inputs, step_finished = (
          my_decoder.step(constant_op.constant(0), first_inputs, first_state))
      batch_size_t = my_decoder.batch_size

      self.assertIsInstance(first_state, rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_state, rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
      self.assertEqual((batch_size, vocabulary_size),
                       step_outputs[0].get_shape())
      self.assertEqual((batch_size,), step_outputs[1].get_shape())
      # Both components of both LSTM state tuples are [batch, num_units].
      state_shape = (batch_size, vocabulary_size)
      for lstm_state in (first_state, step_state):
        for component in (lstm_state[0], lstm_state[1]):
          self.assertEqual(state_shape, component.get_shape())
      self.assertEqual((batch_size, input_depth),
                       step_next_inputs.get_shape())

      sess.run(variables.global_variables_initializer())
      fetches = {
          "batch_size": batch_size_t,
          "first_finished": first_finished,
          "first_inputs": first_inputs,
          "first_state": first_state,
          "step_outputs": step_outputs,
          "step_state": step_state,
          "step_next_inputs": step_next_inputs,
          "step_finished": step_finished
      }
      sess_results = sess.run(fetches)

      self.assertAllEqual([False, False, False, False, True],
                          sess_results["first_finished"])
      self.assertAllEqual([False, False, False, True, True],
                          sess_results["step_finished"])

      sample_ids = sess_results["step_outputs"].sample_id
      self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
      sampled = np.where(sample_ids > -1)
      not_sampled = np.where(sample_ids == -1)
      next_inputs_value = sess_results["step_next_inputs"]
      # Entries that sampled feed the embedding of the sampled id...
      self.assertAllClose(next_inputs_value[sampled],
                          embeddings[sample_ids[sampled]])
      # ...the others feed the ground-truth input at time step 1.
      self.assertAllClose(next_inputs_value[not_sampled],
                          np.squeeze(inputs[not_sampled, 1]))
Beispiel #24
0
    def step(self, time, inputs, state, name=None):
        """Perform a trie-constrained decoding step.

        Steps:
          1. Run the cell on `inputs` and apply the optional output layer.
          2. Pass the resulting scores through `_trie_scores_py_func`, along
             with the current trie keys and exclusions.
          3. Let the helper sample ids and compute the next inputs from those
             scores.
          4. Amend the trie keys with the sampled ids via
             `_amend_trie_keys_py_func`.
          5. Return the new keys in the next state, with `trie_exclude`
             carried over unchanged.

        Three state layouts are handled:
          * attention state (`_is_attention_state`): the whole state is fed to
            the cell and holds `trie_keys` / `trie_exclude` itself;
          * GNMT-style tuple (`_is_gnmt_state`): the whole tuple is fed to the
            cell and its first element holds the trie fields;
          * otherwise: `state.cell_state` is fed to the cell and `state`
            holds the trie fields.

        Args:
          time: scalar `int32` tensor.
          inputs: A (structure of) input tensors.
          state: A (structure of) state tensors and TensorArrays.
          name: Name scope for any created operations.

        Returns:
          `(outputs, next_state, next_inputs, finished)`.
        """
        with ops.name_scope(name, 'TrieSamplerDecoderStep',
                            (time, inputs, state)):
            # Select what the cell consumes and which object carries the trie
            # bookkeeping fields for this state layout.
            if _is_attention_state(state):
                cell_input_state, trie_holder = state, state
            elif _is_gnmt_state(state):
                cell_input_state, trie_holder = state, state[0]
            else:
                cell_input_state, trie_holder = state.cell_state, state

            cell_outputs, cell_state = self._cell(inputs, cell_input_state)
            state_trie_keys = trie_holder.trie_keys
            state_trie_exclude = trie_holder.trie_exclude

            if self._output_layer is not None:
                cell_outputs = self._output_layer(cell_outputs)

            # tf.py_func does not propagate static shapes, so the shape of the
            # scores is restored explicitly after the trie rescoring.
            cell_outputs_shape = cell_outputs.get_shape()
            cell_outputs = tf.py_func(
                _trie_scores_py_func(self.trie, beam_search=False),
                [cell_outputs, state_trie_keys, state_trie_exclude],
                tf.float32,
                stateful=False)
            cell_outputs.set_shape(cell_outputs_shape)

            sample_ids = self._helper.sample(time=time,
                                             outputs=cell_outputs,
                                             state=cell_state)
            (finished, next_inputs,
             next_cell_state) = self._helper.next_inputs(time=time,
                                                         outputs=cell_outputs,
                                                         state=cell_state,
                                                         sample_ids=sample_ids)

            # Created inside the name scope so that, as documented, every op
            # this step creates is grouped under `name`.
            trie_keys = tf.py_func(
                _amend_trie_keys_py_func(beam_search=False),
                [state_trie_keys, sample_ids],
                tf.string,
                stateful=False)
            trie_keys.set_shape(state_trie_keys.get_shape())

        # Re-wrap the cell's next state in the layout matching
        # `next_cell_state`, attaching the amended keys and the unchanged
        # exclusions.
        if _is_attention_state(next_cell_state):
            next_state = TrieSamplerAttentionState(
                *next_cell_state,
                trie_keys=trie_keys,
                trie_exclude=state_trie_exclude)
        elif _is_gnmt_state(next_cell_state):
            next_state = (TrieSamplerAttentionState(
                *next_cell_state[0],
                trie_keys=trie_keys,
                trie_exclude=state_trie_exclude), ) + next_cell_state[1:]
        else:
            next_state = TrieSamplerState(cell_state=next_cell_state,
                                          trie_keys=trie_keys,
                                          trie_exclude=state_trie_exclude)

        outputs = basic_decoder.BasicDecoderOutput(cell_outputs, sample_ids)
        return outputs, next_state, next_inputs, finished
Beispiel #25
0
  def _testStepWithTrainingHelper(self, use_output_layer):
    """Runs one BasicDecoder step driven by a TrainingHelper.

    Args:
      use_output_layer: If True, a bias-free Dense layer of depth 3 projects
        the cell output, so `rnn_output` has depth 3 instead of the cell depth.
    """
    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    cell_depth = 10
    output_layer_depth = 3

    with self.session(use_gpu=True) as sess:
      inputs = np.random.randn(batch_size, max_time,
                               input_depth).astype(np.float32)
      cell = rnn_cell.LSTMCell(cell_depth)
      helper = helper_py.TrainingHelper(
          inputs, sequence_length, time_major=False)
      output_layer = (
          layers_core.Dense(output_layer_depth, use_bias=False)
          if use_output_layer else None)
      expected_output_depth = (
          output_layer_depth if use_output_layer else cell_depth)
      initial_state = cell.zero_state(
          dtype=dtypes.float32, batch_size=batch_size)
      my_decoder = basic_decoder.BasicDecoder(
          cell=cell,
          helper=helper,
          initial_state=initial_state,
          output_layer=output_layer)
      output_size = my_decoder.output_size
      output_dtype = my_decoder.output_dtype

      # Static structure: projected (or raw) float outputs, scalar int32 ids.
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(expected_output_depth,
                                           tensor_shape.TensorShape([])),
          output_size)
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
          output_dtype)

      first_finished, first_inputs, first_state = my_decoder.initialize()
      step_outputs, step_state, step_next_inputs, step_finished = (
          my_decoder.step(constant_op.constant(0), first_inputs, first_state))
      batch_size_t = my_decoder.batch_size

      self.assertIsInstance(first_state, rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_state, rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
      self.assertEqual((batch_size, expected_output_depth),
                       step_outputs[0].get_shape())
      self.assertEqual((batch_size,), step_outputs[1].get_shape())
      # The LSTM state keeps the cell depth regardless of the output layer.
      state_shape = (batch_size, cell_depth)
      for lstm_state in (first_state, step_state):
        for component in (lstm_state[0], lstm_state[1]):
          self.assertEqual(state_shape, component.get_shape())

      if use_output_layer:
        # Building the step used the Dense layer, which now owns one variable.
        self.assertEqual(len(output_layer.variables), 1)

      sess.run(variables.global_variables_initializer())
      fetches = {
          "batch_size": batch_size_t,
          "first_finished": first_finished,
          "first_inputs": first_inputs,
          "first_state": first_state,
          "step_outputs": step_outputs,
          "step_state": step_state,
          "step_next_inputs": step_next_inputs,
          "step_finished": step_finished
      }
      sess_results = sess.run(fetches)

      self.assertAllEqual([False, False, False, False, True],
                          sess_results["first_finished"])
      self.assertAllEqual([False, False, False, True, True],
                          sess_results["step_finished"])
      step_result = sess_results["step_outputs"]
      self.assertEqual(output_dtype.sample_id, step_result.sample_id.dtype)
      # The reported sample id is the argmax over the last output axis.
      self.assertAllEqual(np.argmax(step_result.rnn_output, -1),
                          step_result.sample_id)
Beispiel #26
0
    def testStepWithGreedyEmbeddingHelper(self):
        """Runs one BasicDecoder step driven by a GreedyEmbeddingHelper.

        Checks three things after the step:
          * the sampled ids are the argmax of the step logits;
          * an entry is finished exactly when it emitted `end_token`;
          * the next inputs are the embeddings of the sampled ids.
        """
        batch_size = 5
        vocabulary_size = 7
        cell_depth = vocabulary_size  # cell's logits must match vocabulary size
        input_depth = 10
        start_tokens = [0] * batch_size
        end_token = 1

        with self.test_session() as sess:
            embeddings = np.random.randn(vocabulary_size,
                                         input_depth).astype(np.float32)
            cell = core_rnn_cell.LSTMCell(vocabulary_size)
            helper = helper_py.GreedyEmbeddingHelper(embeddings, start_tokens,
                                                     end_token)
            initial_state = cell.zero_state(dtype=dtypes.float32,
                                            batch_size=batch_size)
            my_decoder = basic_decoder.BasicDecoder(
                cell=cell, helper=helper, initial_state=initial_state)
            output_size = my_decoder.output_size
            output_dtype = my_decoder.output_dtype
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(cell_depth,
                                                 tensor_shape.TensorShape([])),
                output_size)
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
                output_dtype)

            first_finished, first_inputs, first_state = (
                my_decoder.initialize())
            step_outputs, step_state, step_next_inputs, step_finished = (
                my_decoder.step(constant_op.constant(0), first_inputs,
                                first_state))
            batch_size_t = my_decoder.batch_size

            self.assertIsInstance(first_state, core_rnn_cell.LSTMStateTuple)
            self.assertIsInstance(step_state, core_rnn_cell.LSTMStateTuple)
            self.assertIsInstance(step_outputs,
                                  basic_decoder.BasicDecoderOutput)
            depth_shape = (batch_size, cell_depth)
            self.assertEqual(depth_shape, step_outputs[0].get_shape())
            self.assertEqual((batch_size, ), step_outputs[1].get_shape())
            # c and h of both LSTM state tuples are [batch, cell_depth].
            for lstm_state in (first_state, step_state):
                for component in (lstm_state[0], lstm_state[1]):
                    self.assertEqual(depth_shape, component.get_shape())

            sess.run(variables.global_variables_initializer())
            fetches = {
                "batch_size": batch_size_t,
                "first_finished": first_finished,
                "first_inputs": first_inputs,
                "first_state": first_state,
                "step_outputs": step_outputs,
                "step_state": step_state,
                "step_next_inputs": step_next_inputs,
                "step_finished": step_finished
            }
            sess_results = sess.run(fetches)

            step_result = sess_results["step_outputs"]
            # Greedy decoding: the sampled id is the argmax of the logits.
            expected_sample_ids = np.argmax(step_result.rnn_output, -1)
            self.assertAllEqual([False] * batch_size,
                                sess_results["first_finished"])
            self.assertAllEqual(expected_sample_ids == end_token,
                                sess_results["step_finished"])
            self.assertAllEqual(expected_sample_ids, step_result.sample_id)
            self.assertAllEqual(embeddings[expected_sample_ids],
                                sess_results["step_next_inputs"])
Beispiel #27
0
  def testStepWithInferenceHelperMultilabel(self):
    """Runs one BasicDecoder step driven by a multilabel InferenceHelper.

    Each step samples one boolean per label. The next input is the float cast
    of those booleans, and an entry is finished when its `end_token` label
    was sampled.
    """
    batch_size = 5
    vocabulary_size = 7
    cell_depth = vocabulary_size
    start_token = 0
    end_token = 6

    start_inputs = array_ops.one_hot(
        np.ones(batch_size) * start_token,
        vocabulary_size)

    def sample_fn(x):
      # Independent Bernoulli draw per label, returned as booleans.
      return helper_py.bernoulli_sample(logits=x, dtype=dtypes.bool)

    # The next inputs are the sampled labels cast to float (multi-hot).
    next_inputs_fn = math_ops.to_float

    def end_fn(sample_ids):
      # An entry is done once its end-token label has been sampled.
      return sample_ids[:, end_token]

    with self.session(use_gpu=True) as sess:
      with variable_scope.variable_scope(
          "testStepWithInferenceHelper",
          initializer=init_ops.constant_initializer(0.01)):
        cell = rnn_cell.LSTMCell(vocabulary_size)
        helper = helper_py.InferenceHelper(
            sample_fn, sample_shape=[cell_depth], sample_dtype=dtypes.bool,
            start_inputs=start_inputs, end_fn=end_fn,
            next_inputs_fn=next_inputs_fn)
        initial_state = cell.zero_state(
            dtype=dtypes.float32, batch_size=batch_size)
        my_decoder = basic_decoder.BasicDecoder(
            cell=cell, helper=helper, initial_state=initial_state)
        output_size = my_decoder.output_size
        output_dtype = my_decoder.output_dtype
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(cell_depth, cell_depth),
            output_size)
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.bool),
            output_dtype)

        first_finished, first_inputs, first_state = my_decoder.initialize()
        step_outputs, step_state, step_next_inputs, step_finished = (
            my_decoder.step(constant_op.constant(0), first_inputs,
                            first_state))
        batch_size_t = my_decoder.batch_size

        self.assertIsInstance(first_state, rnn_cell.LSTMStateTuple)
        self.assertIsInstance(step_state, rnn_cell.LSTMStateTuple)
        self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
        # Outputs, sample ids and both LSTM state components are all
        # [batch, cell_depth] here.
        expected_shape = (batch_size, cell_depth)
        for tensor in (step_outputs[0], step_outputs[1],
                       first_state[0], first_state[1],
                       step_state[0], step_state[1]):
          self.assertEqual(expected_shape, tensor.get_shape())

        sess.run(variables.global_variables_initializer())
        fetches = {
            "batch_size": batch_size_t,
            "first_finished": first_finished,
            "first_inputs": first_inputs,
            "first_state": first_state,
            "step_outputs": step_outputs,
            "step_state": step_state,
            "step_next_inputs": step_next_inputs,
            "step_finished": step_finished
        }
        sess_results = sess.run(fetches)

        sample_ids = sess_results["step_outputs"].sample_id
        self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
        self.assertAllEqual(sample_ids[:, end_token],
                            sess_results["step_finished"])
        self.assertAllEqual(sample_ids.astype(np.float32),
                            sess_results["step_next_inputs"])
    def DISABLED_testStepWithGreedyEmbeddingHelper(self):
        """Runs one BasicDecoderV2 step with a GreedyEmbeddingSampler.

        The method name does not start with ``test``, so the default unittest
        loader does not collect it.

        The embedding table, start tokens, end token and initial state are
        passed to ``BasicDecoderV2.initialize``. After one step, the test
        checks three things:
          * the sampled ids are the argmax of the logits;
          * an entry is finished exactly when it emitted ``end_token``;
          * the next inputs are the embeddings of the sampled ids.
        """
        batch_size = 5
        vocabulary_size = 7
        cell_depth = vocabulary_size  # cell's logits must match vocabulary size
        input_depth = 10
        # Random start tokens per batch entry, drawn from the vocabulary.
        start_tokens = np.random.randint(0, vocabulary_size, size=batch_size)
        end_token = 1

        with self.cached_session(use_gpu=True):
            embeddings = np.random.randn(vocabulary_size,
                                         input_depth).astype(np.float32)
            embeddings_t = constant_op.constant(embeddings)
            cell = rnn_cell.LSTMCell(vocabulary_size)
            sampler = sampler_py.GreedyEmbeddingSampler()
            initial_state = cell.zero_state(dtype=dtypes.float32,
                                            batch_size=batch_size)
            my_decoder = basic_decoder.BasicDecoderV2(cell=cell,
                                                      sampler=sampler)
            # The V2 decoder receives embeddings/tokens/state at initialize().
            (first_finished, first_inputs,
             first_state) = my_decoder.initialize(embeddings_t,
                                                  start_tokens=start_tokens,
                                                  end_token=end_token,
                                                  initial_state=initial_state)
            output_size = my_decoder.output_size
            output_dtype = my_decoder.output_dtype
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(cell_depth,
                                                 tensor_shape.TensorShape([])),
                output_size)
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
                output_dtype)

            (step_outputs, step_state, step_next_inputs,
             step_finished) = my_decoder.step(constant_op.constant(0),
                                              first_inputs, first_state)
            batch_size_t = my_decoder.batch_size

            # Static structure / shape checks before running the graph.
            self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
            self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
            self.assertTrue(
                isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
            self.assertEqual((batch_size, cell_depth),
                             step_outputs[0].get_shape())
            self.assertEqual((batch_size, ), step_outputs[1].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             first_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             first_state[1].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             step_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             step_state[1].get_shape())

            self.evaluate(variables.global_variables_initializer())
            eval_result = self.evaluate({
                "batch_size": batch_size_t,
                "first_finished": first_finished,
                "first_inputs": first_inputs,
                "first_state": first_state,
                "step_outputs": step_outputs,
                "step_state": step_state,
                "step_next_inputs": step_next_inputs,
                "step_finished": step_finished
            })

            # Greedy decoding: expected ids are the argmax of the logits.
            expected_sample_ids = np.argmax(
                eval_result["step_outputs"].rnn_output, -1)
            expected_step_finished = (expected_sample_ids == end_token)
            expected_step_next_inputs = embeddings[expected_sample_ids]
            self.assertAllEqual([False, False, False, False, False],
                                eval_result["first_finished"])
            self.assertAllEqual(expected_step_finished,
                                eval_result["step_finished"])
            self.assertEqual(output_dtype.sample_id,
                             eval_result["step_outputs"].sample_id.dtype)
            self.assertAllEqual(expected_sample_ids,
                                eval_result["step_outputs"].sample_id)
            self.assertAllEqual(expected_step_next_inputs,
                                eval_result["step_next_inputs"])