예제 #1
0
    def testLuongScaledDType(self, dtype):
        """Runs a scaled Luong attention decoder end to end in `dtype`.

        Regression test for GitHub issue 18099. The encoder/decoder fixtures
        held on `self` are cast to `dtype`, and the attention mechanism, the
        LSTM cell, the attention wrapper and the decoder are all constructed
        with that same dtype. The decoded `rnn_output` must come back in it.

        Args:
          dtype: dtype applied to the fixtures and passed to every layer
            built by this test.
        """
        memory = self.encoder_outputs.astype(dtype)
        typed_inputs = self.decoder_inputs.astype(dtype)

        luong_attention = wrapper.LuongAttentionV2(
            units=self.units,
            memory=memory,
            memory_sequence_length=self.encoder_sequence_length,
            scale=True,
            dtype=dtype,
        )
        lstm_cell = keras.layers.LSTMCell(
            self.units, recurrent_activation="sigmoid", dtype=dtype)
        attention_cell = wrapper.AttentionWrapper(
            lstm_cell, luong_attention, dtype=dtype)

        decoder = basic_decoder.BasicDecoderV2(
            cell=attention_cell,
            sampler=sampler_py.TrainingSampler(),
            dtype=dtype)

        start_state = attention_cell.zero_state(
            dtype=dtype, batch_size=self.batch)
        outputs, state, _ = decoder(
            typed_inputs,
            initial_state=start_state,
            sequence_length=self.decoder_sequence_length)

        self.assertIsInstance(outputs, basic_decoder.BasicDecoderOutput)
        self.assertEqual(outputs.rnn_output.dtype, dtype)
        self.assertIsInstance(state, wrapper.AttentionWrapperState)
예제 #2
0
    def _testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN(
            self, use_sequence_length):
        """Compares BasicDecoderV2 + TrainingSampler with `rnn.dynamic_rnn`.

        Both are driven by the same LSTM cell, the same random inputs and
        the same zero initial state. The decoder output is compared with the
        first `max(sequence_length)` time steps of the dynamic_rnn output.

        Args:
          use_sequence_length: forwarded to the decoder as `impute_finished`.
            When true, dynamic_rnn also receives the sequence lengths and
            the final states are compared; when false, dynamic_rnn gets
            `sequence_length=None` and only the outputs are compared.
        """
        lengths = [3, 4, 3, 1, 0]
        num_batches = 5
        num_steps = 8
        feature_depth = 7
        num_units = 10
        longest = max(lengths)

        with self.cached_session(use_gpu=True):
            raw_inputs = np.random.randn(
                num_batches, num_steps, feature_depth).astype(np.float32)
            inputs = constant_op.constant(raw_inputs)

            cell = rnn_cell.LSTMCell(num_units)
            zero_state = cell.zero_state(
                dtype=dtypes.float32, batch_size=num_batches)
            decoder = basic_decoder.BasicDecoderV2(
                cell=cell,
                sampler=sampler_py.TrainingSampler(),
                impute_finished=use_sequence_length)

            decoder_outputs, decoder_state, _ = decoder(
                inputs,
                initial_state=zero_state,
                sequence_length=lengths)

            # dynamic_rnn only sees the lengths when the caller asks for it.
            if use_sequence_length:
                rnn_lengths = lengths
            else:
                rnn_lengths = None
            rnn_outputs, rnn_state = rnn.dynamic_rnn(
                cell,
                inputs,
                sequence_length=rnn_lengths,
                initial_state=zero_state)

            self.evaluate(variables.global_variables_initializer())
            fetches = {
                "final_decoder_outputs": decoder_outputs,
                "final_decoder_state": decoder_state,
                "final_rnn_outputs": rnn_outputs,
                "final_rnn_state": rnn_state
            }
            results = self.evaluate(fetches)

            # The decoder stops after `longest` steps, so only that prefix of
            # the dynamic_rnn output (which zeros out finished entries and
            # passes the state along) is comparable.
            decoded = results["final_decoder_outputs"].rnn_output
            reference = results["final_rnn_outputs"][:, 0:longest, :]
            self.assertAllClose(decoded, reference)
            if use_sequence_length:
                self.assertAllClose(results["final_decoder_state"],
                                    results["final_rnn_state"])
예제 #3
0
  def testStepWithInferenceHelperMultilabel(self):
    """Runs one BasicDecoderV2 step with a multilabel InferenceSampler.

    The sampler draws independent boolean samples from the cell logits,
    feeds them back (converted to float) as the next inputs, and reads the
    `end_token` column of the samples as the finished flag. The test checks
    the declared output sizes/dtypes, the static shapes, and that the
    evaluated `step_finished` / `step_next_inputs` agree with the samples.
    """
    batch_size = 5
    vocabulary_size = 7
    cell_depth = vocabulary_size
    start_token = 0
    end_token = 6

    start_inputs = array_ops.one_hot(
        np.ones(batch_size, dtype=np.int32) * start_token,
        vocabulary_size)

    def sample_fn(x):
      # Samples independent bernoullis from the logits.
      return sampler_py.bernoulli_sample(logits=x, dtype=dtypes.bool)

    def end_fn(sample_ids):
      # A batch entry is finished once its `end_token` label is sampled.
      return sample_ids[:, end_token]

    # The next inputs are the sampled labels converted to float.
    next_inputs_fn = math_ops.to_float

    with self.cached_session(use_gpu=True):
      cell = rnn_cell.LSTMCell(vocabulary_size)
      sampler = sampler_py.InferenceSampler(
          sample_fn,
          sample_shape=[cell_depth],
          sample_dtype=dtypes.bool,
          end_fn=end_fn,
          next_inputs_fn=next_inputs_fn)
      initial_state = cell.zero_state(
          dtype=dtypes.float32, batch_size=batch_size)
      decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler)
      first_finished, first_inputs, first_state = decoder.initialize(
          start_inputs, initial_state=initial_state)

      output_size = decoder.output_size
      output_dtype = decoder.output_dtype
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(cell_depth, cell_depth),
          output_size)
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.bool),
          output_dtype)

      step_result = decoder.step(
          constant_op.constant(0), first_inputs, first_state)
      step_outputs, step_state, step_next_inputs, step_finished = step_result
      batch_size_t = decoder.batch_size

      self.assertIsInstance(first_state, rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_state, rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)

      # Outputs and both halves of each LSTM state share one static shape.
      expected_shape = (batch_size, cell_depth)
      for structure in (step_outputs, first_state, step_state):
        for position in (0, 1):
          self.assertEqual(expected_shape, structure[position].get_shape())

      self.evaluate(variables.global_variables_initializer())
      fetches = {
          "batch_size": batch_size_t,
          "first_finished": first_finished,
          "first_inputs": first_inputs,
          "first_state": first_state,
          "step_outputs": step_outputs,
          "step_state": step_state,
          "step_next_inputs": step_next_inputs,
          "step_finished": step_finished
      }
      results = self.evaluate(fetches)

      sampled = results["step_outputs"].sample_id
      self.assertEqual(output_dtype.sample_id, sampled.dtype)
      self.assertAllEqual(sampled[:, end_token], results["step_finished"])
      self.assertAllEqual(sampled.astype(np.float32),
                          results["step_next_inputs"])
예제 #4
0
  def testStepWithTrainingHelperOutputLayer(self, use_output_layer):
    """Runs one BasicDecoderV2 step with a TrainingSampler.

    Checks the declared output sizes/dtypes, the static shapes of the step
    outputs and states, the finished flags before and after the step, and
    that the sampled ids equal the argmax of the RNN output.

    Args:
      use_output_layer: when true, a bias-free Dense layer of depth 3 is
        passed to the decoder as `output_layer` and the expected output
        depth is that of the layer; otherwise no output layer is used and
        the expected output depth is the cell depth.
    """
    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    cell_depth = 10
    output_layer_depth = 3

    with self.cached_session(use_gpu=True):
      inputs = np.random.randn(
          batch_size, max_time, input_depth).astype(np.float32)
      input_t = constant_op.constant(inputs)
      cell = rnn_cell.LSTMCell(cell_depth)
      sampler = sampler_py.TrainingSampler(time_major=False)

      output_layer = (
          layers_core.Dense(output_layer_depth, use_bias=False)
          if use_output_layer else None)
      expected_output_depth = (
          output_layer_depth if use_output_layer else cell_depth)

      initial_state = cell.zero_state(
          dtype=dtypes.float32, batch_size=batch_size)
      decoder = basic_decoder.BasicDecoderV2(
          cell=cell, sampler=sampler, output_layer=output_layer)

      first_finished, first_inputs, first_state = decoder.initialize(
          input_t,
          initial_state=initial_state,
          sequence_length=sequence_length)

      output_size = decoder.output_size
      output_dtype = decoder.output_dtype
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(expected_output_depth,
                                           tensor_shape.TensorShape([])),
          output_size)
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
          output_dtype)

      step_result = decoder.step(
          constant_op.constant(0), first_inputs, first_state)
      step_outputs, step_state, step_next_inputs, step_finished = step_result
      batch_size_t = decoder.batch_size

      self.assertIsInstance(first_state, rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_state, rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
      self.assertEqual((batch_size, expected_output_depth),
                       step_outputs[0].get_shape())
      self.assertEqual((batch_size,), step_outputs[1].get_shape())
      # Both halves of each LSTM state keep the cell depth.
      for state in (first_state, step_state):
        for position in (0, 1):
          self.assertEqual((batch_size, cell_depth),
                           state[position].get_shape())

      if use_output_layer:
        # The output layer was accessed
        self.assertEqual(len(output_layer.variables), 1)

      self.evaluate(variables.global_variables_initializer())
      fetches = {
          "batch_size": batch_size_t,
          "first_finished": first_finished,
          "first_inputs": first_inputs,
          "first_state": first_state,
          "step_outputs": step_outputs,
          "step_state": step_state,
          "step_next_inputs": step_next_inputs,
          "step_finished": step_finished
      }
      results = self.evaluate(fetches)

      self.assertAllEqual([False, False, False, False, True],
                          results["first_finished"])
      self.assertAllEqual([False, False, False, True, True],
                          results["step_finished"])

      evaluated_outputs = results["step_outputs"]
      self.assertEqual(output_dtype.sample_id,
                       evaluated_outputs.sample_id.dtype)
      self.assertAllEqual(np.argmax(evaluated_outputs.rnn_output, -1),
                          evaluated_outputs.sample_id)
예제 #5
0
  def _testStepWithScheduledOutputTrainingHelper(
      self, sampling_probability, use_next_inputs_fn, use_auxiliary_inputs):
    """Runs one BasicDecoderV2 step with ScheduledOutputTrainingSampler.

    After the step, the batch entries are split by `sample_id` into the
    ones that were sampled and the ones that were not. The next inputs of
    sampled entries must equal the RNN output (optionally passed through
    `next_inputs_fn`), and those of the other entries must equal the
    time-step-1 slice of the original inputs. In both cases the matching
    auxiliary inputs are concatenated on the last axis.

    Args:
      sampling_probability: value wrapped in a constant and handed to the
        sampler as its sampling probability.
      use_next_inputs_fn: when true, the sampler gets a deterministic
        `next_inputs_fn` (one-hot of the argmax of the outputs).
      use_auxiliary_inputs: when true, random auxiliary inputs of depth 4
        are passed to `initialize`; otherwise `None` is passed.
    """
    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    cell_depth = input_depth
    auxiliary_inputs = None
    if use_auxiliary_inputs:
      auxiliary_input_depth = 4
      auxiliary_inputs = np.random.randn(
          batch_size, max_time, auxiliary_input_depth).astype(np.float32)

    with self.cached_session(use_gpu=True):
      inputs = np.random.randn(
          batch_size, max_time, input_depth).astype(np.float32)
      input_t = constant_op.constant(inputs)
      cell = rnn_cell.LSTMCell(cell_depth)
      sampling_probability = constant_op.constant(sampling_probability)

      def _argmax_one_hot(outputs):
        # Use deterministic function for test.
        winners = math_ops.argmax(outputs, axis=1)
        return array_ops.one_hot(winners, cell_depth, dtype=dtypes.float32)

      next_inputs_fn = _argmax_one_hot if use_next_inputs_fn else None

      sampler = sampler_py.ScheduledOutputTrainingSampler(
          sampling_probability=sampling_probability,
          time_major=False,
          next_inputs_fn=next_inputs_fn)
      initial_state = cell.zero_state(
          dtype=dtypes.float32, batch_size=batch_size)

      decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler)

      first_finished, first_inputs, first_state = decoder.initialize(
          input_t,
          sequence_length=sequence_length,
          initial_state=initial_state,
          auxiliary_inputs=auxiliary_inputs)

      output_size = decoder.output_size
      output_dtype = decoder.output_dtype
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(cell_depth,
                                           tensor_shape.TensorShape([])),
          output_size)
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
          output_dtype)

      step_result = decoder.step(
          constant_op.constant(0), first_inputs, first_state)
      step_outputs, step_state, step_next_inputs, step_finished = step_result

      if use_next_inputs_fn:
        output_after_next_inputs_fn = next_inputs_fn(step_outputs.rnn_output)

      batch_size_t = decoder.batch_size

      self.assertIsInstance(first_state, rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_state, rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
      self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
      self.assertEqual((batch_size,), step_outputs[1].get_shape())
      for state in (first_state, step_state):
        for position in (0, 1):
          self.assertEqual((batch_size, cell_depth),
                           state[position].get_shape())

      self.evaluate(variables.global_variables_initializer())

      fetches = {
          "batch_size": batch_size_t,
          "first_finished": first_finished,
          "first_inputs": first_inputs,
          "first_state": first_state,
          "step_outputs": step_outputs,
          "step_state": step_state,
          "step_next_inputs": step_next_inputs,
          "step_finished": step_finished
      }
      if use_next_inputs_fn:
        fetches["output_after_next_inputs_fn"] = output_after_next_inputs_fn

      results = self.evaluate(fetches)

      self.assertAllEqual([False, False, False, False, True],
                          results["first_finished"])
      self.assertAllEqual([False, False, False, True, True],
                          results["step_finished"])

      sample_ids = results["step_outputs"].sample_id
      self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
      not_sampled = np.where(np.logical_not(sample_ids))
      sampled = np.where(sample_ids)

      # Without auxiliary inputs, concatenate a zero-width placeholder so the
      # same code path covers both configurations.
      if use_auxiliary_inputs:
        auxiliary_slice = auxiliary_inputs[:, 1]
      else:
        auxiliary_slice = np.array([]).reshape(
            batch_size, 0).astype(np.float32)

      if use_next_inputs_fn:
        sampled_source = results["output_after_next_inputs_fn"][sampled]
      else:
        sampled_source = results["step_outputs"].rnn_output[sampled]
      expected_sampled_inputs = np.concatenate(
          (sampled_source, auxiliary_slice[sampled]), axis=-1)
      self.assertAllClose(results["step_next_inputs"][sampled],
                          expected_sampled_inputs)

      # Indexing with the `np.where` tuple adds a leading axis of size one,
      # which is squeezed away before concatenating.
      unsampled_source = np.squeeze(inputs[not_sampled, 1], axis=0)
      expected_unsampled_inputs = np.concatenate(
          (unsampled_source, auxiliary_slice[not_sampled]), axis=-1)
      self.assertAllClose(results["step_next_inputs"][not_sampled],
                          expected_unsampled_inputs)
예제 #6
0
  def testStepWithScheduledEmbeddingTrainingHelper(self):
    """Runs one BasicDecoderV2 step with ScheduledEmbeddingTrainingSampler.

    The sampler is built with a sampling probability of 0.5. After the
    step, batch entries are split on `sample_id`: entries with an id above
    -1 must get the embedding row of that id as their next input, and
    entries with id -1 must get the time-step-1 slice of the original
    inputs.
    """
    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    vocabulary_size = 10

    with self.cached_session(use_gpu=True):
      inputs = np.random.randn(
          batch_size, max_time, input_depth).astype(np.float32)
      input_t = constant_op.constant(inputs)
      embeddings = np.random.randn(
          vocabulary_size, input_depth).astype(np.float32)
      half = constant_op.constant(0.5)
      cell = rnn_cell.LSTMCell(vocabulary_size)
      sampler = sampler_py.ScheduledEmbeddingTrainingSampler(
          sampling_probability=half, time_major=False)
      initial_state = cell.zero_state(
          dtype=dtypes.float32, batch_size=batch_size)
      decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler)
      first_finished, first_inputs, first_state = decoder.initialize(
          input_t,
          sequence_length=sequence_length,
          embedding=embeddings,
          initial_state=initial_state)

      output_size = decoder.output_size
      output_dtype = decoder.output_dtype
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(vocabulary_size,
                                           tensor_shape.TensorShape([])),
          output_size)
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
          output_dtype)

      step_result = decoder.step(
          constant_op.constant(0), first_inputs, first_state)
      step_outputs, step_state, step_next_inputs, step_finished = step_result
      batch_size_t = decoder.batch_size

      self.assertIsInstance(first_state, rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_state, rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
      self.assertEqual((batch_size, vocabulary_size),
                       step_outputs[0].get_shape())
      self.assertEqual((batch_size,), step_outputs[1].get_shape())
      for state in (first_state, step_state):
        for position in (0, 1):
          self.assertEqual((batch_size, vocabulary_size),
                           state[position].get_shape())
      self.assertEqual((batch_size, input_depth),
                       step_next_inputs.get_shape())

      self.evaluate(variables.global_variables_initializer())
      fetches = {
          "batch_size": batch_size_t,
          "first_finished": first_finished,
          "first_inputs": first_inputs,
          "first_state": first_state,
          "step_outputs": step_outputs,
          "step_state": step_state,
          "step_next_inputs": step_next_inputs,
          "step_finished": step_finished
      }
      results = self.evaluate(fetches)

      self.assertAllEqual([False, False, False, False, True],
                          results["first_finished"])
      self.assertAllEqual([False, False, False, True, True],
                          results["step_finished"])

      sample_ids = results["step_outputs"].sample_id
      self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
      # The test treats an id of -1 as "not sampled at this step".
      not_sampled = np.where(sample_ids == -1)
      sampled = np.where(sample_ids > -1)

      next_inputs = results["step_next_inputs"]
      self.assertAllClose(next_inputs[sampled],
                          embeddings[sample_ids[sampled]])
      # Indexing with the `np.where` tuple adds a leading axis of size one.
      self.assertAllClose(next_inputs[not_sampled],
                          np.squeeze(inputs[not_sampled, 1], axis=0))
예제 #7
0
  def testStepWithSampleEmbeddingHelper(self):
    """Runs one BasicDecoderV2 step with a seeded SampleEmbeddingSampler.

    Checks the declared output sizes/dtypes and static shapes, then
    verifies that an entry is finished exactly when its sampled id equals
    `end_token`, and that the next inputs are the embedding rows of the
    sampled ids.
    """
    batch_size = 5
    vocabulary_size = 7
    cell_depth = vocabulary_size  # cell's logits must match vocabulary size
    input_depth = 10
    np.random.seed(0)
    start_tokens = np.random.randint(0, vocabulary_size, size=batch_size)
    end_token = 1

    with self.cached_session(use_gpu=True):
      embeddings = np.random.randn(
          vocabulary_size, input_depth).astype(np.float32)
      embeddings_t = constant_op.constant(embeddings)
      cell = rnn_cell.LSTMCell(vocabulary_size)
      sampler = sampler_py.SampleEmbeddingSampler(seed=0)
      initial_state = cell.zero_state(
          dtype=dtypes.float32, batch_size=batch_size)
      decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler)
      first_finished, first_inputs, first_state = decoder.initialize(
          embeddings_t,
          start_tokens=start_tokens,
          end_token=end_token,
          initial_state=initial_state)

      output_size = decoder.output_size
      output_dtype = decoder.output_dtype
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(cell_depth,
                                           tensor_shape.TensorShape([])),
          output_size)
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
          output_dtype)

      step_result = decoder.step(
          constant_op.constant(0), first_inputs, first_state)
      step_outputs, step_state, step_next_inputs, step_finished = step_result
      batch_size_t = decoder.batch_size

      self.assertIsInstance(first_state, rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_state, rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
      self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
      self.assertEqual((batch_size,), step_outputs[1].get_shape())
      for state in (first_state, step_state):
        for position in (0, 1):
          self.assertEqual((batch_size, cell_depth),
                           state[position].get_shape())

      self.evaluate(variables.global_variables_initializer())
      fetches = {
          "batch_size": batch_size_t,
          "first_finished": first_finished,
          "first_inputs": first_inputs,
          "first_state": first_state,
          "step_outputs": step_outputs,
          "step_state": step_state,
          "step_next_inputs": step_next_inputs,
          "step_finished": step_finished
      }
      results = self.evaluate(fetches)

      sampled = results["step_outputs"].sample_id
      self.assertEqual(output_dtype.sample_id, sampled.dtype)
      self.assertAllEqual(sampled == end_token, results["step_finished"])
      self.assertAllEqual(embeddings[sampled], results["step_next_inputs"])
예제 #8
0
    def _testWithMaybeMultiAttention(self,
                                     is_multi,
                                     create_attention_mechanisms,
                                     expected_final_output,
                                     expected_final_state,
                                     attention_mechanism_depths,
                                     alignment_history=False,
                                     expected_final_alignment_history=None,
                                     attention_layer_sizes=None,
                                     attention_layers=None,
                                     create_query_layer=False,
                                     create_memory_layer=True,
                                     create_attention_kwargs=None):
        """Decodes with an AttentionWrapper and checks shapes and summaries.

        Builds one attention mechanism per creator, wraps an LSTM cell with
        them, decodes random inputs with BasicDecoderV2 + TrainingSampler,
        and compares summaries (via `get_result_summary`) of the evaluated
        outputs and state against the expected structures.

        Args:
          is_multi: when true, the list of mechanisms (and the full
            `attention_layer_sizes` / `attention_layers` lists) is passed to
            the wrapper; when false, only the first element of each is.
          create_attention_mechanisms: sequence of callables, each invoked
            with `units`, `memory`, `memory_sequence_length` and the extra
            keyword arguments. Must have length 1 unless `is_multi`.
          expected_final_output: structure compared element-wise against the
            summary of the evaluated decoder outputs.
          expected_final_state: structure compared element-wise against the
            summary of the evaluated final state. Its `time` field is the
            expected time dimension when executing eagerly.
          attention_mechanism_depths: per-mechanism `units`, zipped with
            `create_attention_mechanisms`.
          alignment_history: forwarded to the wrapper. When true, the stacked
            alignment history is shape-checked and compared too.
          expected_final_alignment_history: expected summary of the stacked
            alignment history; only used when `alignment_history` is true.
          attention_layer_sizes: optional list forwarded to the wrapper as
            `attention_layer_size`; falsy entries count as the encoder
            output depth when computing the expected attention depth.
          attention_layers: optional list of layers forwarded to the wrapper
            as `attention_layer`.
          create_query_layer: when true, a bias-free Dense layer with a
            "ones" kernel is passed to each creator as `query_layer`.
          create_memory_layer: when true, a bias-free Dense layer with a
            "ones" kernel is passed to each creator as `memory_layer`.
          create_attention_kwargs: optional dict of extra keyword arguments
            for every creator. It is copied, so the caller's dict is never
            modified.
        """
        # Allow is_multi to be True with a single mechanism to enable test for
        # passing in a single mechanism in a list.
        assert len(create_attention_mechanisms) == 1 or is_multi
        encoder_sequence_length = [3, 2, 3, 1, 1]
        decoder_sequence_length = [2, 0, 1, 2, 3]
        batch_size = 5
        encoder_max_time = 8
        decoder_max_time = 4
        input_depth = 7
        encoder_output_depth = 10
        cell_depth = 9
        # Work on a private copy: the loop below injects "query_layer" and
        # "memory_layer" entries, and writing them into the caller's dict
        # would leak those layers into any later call reusing the same dict.
        create_attention_kwargs = dict(create_attention_kwargs or {})

        if attention_layer_sizes is not None:
            # Compute sum of attention_layer_sizes. Use encoder_output_depth if None.
            attention_depth = sum(
                attention_layer_size or encoder_output_depth
                for attention_layer_size in attention_layer_sizes)
        elif attention_layers is not None:
            # Compute sum of attention_layers output depth.
            attention_depth = sum(
                attention_layer.compute_output_shape(
                    [batch_size, cell_depth +
                     encoder_output_depth]).dims[-1].value
                for attention_layer in attention_layers)
        else:
            attention_depth = encoder_output_depth * len(
                create_attention_mechanisms)

        decoder_inputs = np.random.randn(batch_size, decoder_max_time,
                                         input_depth).astype(np.float32)
        encoder_outputs = np.random.randn(batch_size, encoder_max_time,
                                          encoder_output_depth).astype(
                                              np.float32)

        attention_mechanisms = []
        for creator, depth in zip(create_attention_mechanisms,
                                  attention_mechanism_depths):
            # Create query/memory layers with a deterministic initializer to
            # avoid randomness in the test between graph and eager.
            if create_query_layer:
                create_attention_kwargs["query_layer"] = keras.layers.Dense(
                    depth, kernel_initializer="ones", use_bias=False)
            if create_memory_layer:
                create_attention_kwargs["memory_layer"] = keras.layers.Dense(
                    depth, kernel_initializer="ones", use_bias=False)

            attention_mechanisms.append(
                creator(units=depth,
                        memory=encoder_outputs,
                        memory_sequence_length=encoder_sequence_length,
                        **create_attention_kwargs))

        with self.cached_session(use_gpu=True):
            attention_layer_size = attention_layer_sizes
            attention_layer = attention_layers
            if not is_multi:
                # A single mechanism takes scalar arguments, not lists.
                if attention_layer_size is not None:
                    attention_layer_size = attention_layer_size[0]
                if attention_layer is not None:
                    attention_layer = attention_layer[0]
            cell = keras.layers.LSTMCell(cell_depth,
                                         recurrent_activation="sigmoid",
                                         kernel_initializer="ones",
                                         recurrent_initializer="ones")
            cell = wrapper.AttentionWrapper(
                cell,
                attention_mechanisms if is_multi else attention_mechanisms[0],
                attention_layer_size=attention_layer_size,
                alignment_history=alignment_history,
                attention_layer=attention_layer)
            if cell._attention_layers is not None:
                # Give attention layers lacking an initializer a seeded one.
                for layer in cell._attention_layers:
                    if layer.kernel_initializer is None:
                        layer.kernel_initializer = initializers.glorot_uniform(
                            seed=1337)

            sampler = sampler_py.TrainingSampler()
            my_decoder = basic_decoder.BasicDecoderV2(cell=cell,
                                                      sampler=sampler)
            initial_state = cell.get_initial_state(dtype=dtypes.float32,
                                                   batch_size=batch_size)
            final_outputs, final_state, _ = my_decoder(
                decoder_inputs,
                initial_state=initial_state,
                sequence_length=decoder_sequence_length)

            self.assertIsInstance(final_outputs,
                                  basic_decoder.BasicDecoderOutput)
            self.assertIsInstance(final_state, wrapper.AttentionWrapperState)

            # The time dimension is only statically known in eager mode.
            expected_time = (expected_final_state.time
                             if context.executing_eagerly() else None)
            # By default, the wrapper emits attention as output, so the
            # output depth is the attention depth.
            self.assertEqual(
                (batch_size, expected_time, attention_depth),
                tuple(final_outputs.rnn_output.get_shape().as_list()))
            self.assertEqual(
                (batch_size, expected_time),
                tuple(final_outputs.sample_id.get_shape().as_list()))

            self.assertEqual(
                (batch_size, attention_depth),
                tuple(final_state.attention.get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state[0].get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state[1].get_shape().as_list()))

            if alignment_history:
                # Outputs are batch major but the stacked TensorArray is
                # time major, hence (time, batch, encoder_max_time) below.
                if is_multi:
                    state_alignment_history = []
                    for history_array in final_state.alignment_history:
                        history = history_array.stack()
                        self.assertEqual(
                            (expected_time, batch_size, encoder_max_time),
                            tuple(history.get_shape().as_list()))
                        state_alignment_history.append(history)
                    state_alignment_history = tuple(state_alignment_history)
                else:
                    state_alignment_history = (
                        final_state.alignment_history.stack())
                    self.assertEqual(
                        (expected_time, batch_size, encoder_max_time),
                        tuple(state_alignment_history.get_shape().as_list()))
                nest.assert_same_structure(
                    cell.state_size, cell.zero_state(batch_size,
                                                     dtypes.float32))
                # Remove the history from final_state for purposes of the
                # remainder of the tests.
                final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
            else:
                state_alignment_history = ()

            self.evaluate(variables.global_variables_initializer())
            eval_result = self.evaluate({
                "final_outputs":
                final_outputs,
                "final_state":
                final_state,
                "state_alignment_history":
                state_alignment_history,
            })

            final_output_info = nest.map_structure(
                get_result_summary, eval_result["final_outputs"])
            final_state_info = nest.map_structure(get_result_summary,
                                                  eval_result["final_state"])
            print("final_output_info: ", final_output_info)
            print("final_state_info: ", final_state_info)

            nest.map_structure(self.assertAllCloseOrEqual,
                               expected_final_output, final_output_info)
            nest.map_structure(self.assertAllCloseOrEqual,
                               expected_final_state, final_state_info)
            if alignment_history:
                final_alignment_history_info = nest.map_structure(
                    get_result_summary, eval_result["state_alignment_history"])
                print("final_alignment_history_info: ",
                      final_alignment_history_info)
                nest.map_structure(
                    self.assertAllCloseOrEqual,
                    expected_final_alignment_history,
                    final_alignment_history_info)
예제 #9
0
    def _testDecodeRNN(self, time_major, maximum_iterations=None):
        """Smoke-tests full decoding with BasicDecoderV2 + TrainingSampler.

        Args:
          time_major: when true, the random inputs are laid out as
            (time, batch, depth), and both the sampler and the decoder
            output are configured as time major; expected shapes have their
            first two axes swapped accordingly.
          maximum_iterations: optional cap forwarded to the decoder. When
            set, the expected number of time steps and the expected
            per-entry lengths are clipped to it.
        """
        sequence_length = [3, 4, 3, 1, 0]
        batch_size = 5
        max_time = 8
        input_depth = 7
        cell_depth = 10
        max_out = max(sequence_length)

        def _layout(shape):
            # Swap the leading (batch, time) axes for time-major layouts.
            return (shape[1], shape[0]) + shape[2:] if time_major else shape

        with self.cached_session(use_gpu=True):
            input_shape = _layout((batch_size, max_time, input_depth))
            inputs = np.random.randn(*input_shape).astype(np.float32)
            input_t = constant_op.constant(inputs)
            cell = rnn_cell.LSTMCell(cell_depth)
            decoder = basic_decoder.BasicDecoderV2(
                cell=cell,
                sampler=sampler_py.TrainingSampler(time_major=time_major),
                output_time_major=time_major,
                maximum_iterations=maximum_iterations)

            initial_state = cell.zero_state(dtype=dtypes.float32,
                                            batch_size=batch_size)
            final_outputs, _, final_sequence_length = decoder(
                input_t,
                initial_state=initial_state,
                sequence_length=sequence_length)

            if not context.executing_eagerly():
                # In graph mode the time dimension is not statically known.
                self.assertEqual(
                    (batch_size, ),
                    tuple(final_sequence_length.get_shape().as_list()))
                self.assertEqual(
                    _layout((batch_size, None, cell_depth)),
                    tuple(final_outputs.rnn_output.get_shape().as_list()))
                self.assertEqual(
                    _layout((batch_size, None)),
                    tuple(final_outputs.sample_id.get_shape().as_list()))

            self.evaluate(variables.global_variables_initializer())
            final_outputs = self.evaluate(final_outputs)
            final_sequence_length = self.evaluate(final_sequence_length)

            # Mostly a smoke test
            if maximum_iterations is None:
                time_steps = max_out
                expected_length = sequence_length
            else:
                time_steps = min(max_out, maximum_iterations)
                expected_length = [
                    min(length, maximum_iterations)
                    for length in sequence_length
                ]
            if context.executing_eagerly() and maximum_iterations != 0:
                # Only check the shape of output when maximum_iterations > 0, see
                # b/123431432 for more details.
                self.assertEqual(
                    _layout((batch_size, time_steps, cell_depth)),
                    final_outputs.rnn_output.shape)
                self.assertEqual(_layout((batch_size, time_steps)),
                                 final_outputs.sample_id.shape)
            self.assertItemsEqual(expected_length, final_sequence_length)