Esempio n. 1
0
  def testLuongScaledDType(self):
    """Builds a scaled Luong attention decoder in float16/32/64.

    Regression test for GitHub issue 18099. For each dtype the decoding graph
    is constructed (it is never run here) and the decoder's `rnn_output` is
    required to carry that same dtype.
    """
    num_units = 128
    batch_size = 64
    for dt in (np.float16, np.float32, np.float64):
      memory = array_ops.placeholder(dt, shape=[64, None, 256])
      memory_lengths = array_ops.placeholder(dtypes.int32, shape=[64])
      target_inputs = array_ops.placeholder(dt, shape=[64, None, 128])
      target_lengths = array_ops.placeholder(dtypes.int32, shape=[64])

      luong = wrapper.LuongAttention(
          num_units=num_units,
          memory=memory,
          memory_sequence_length=memory_lengths,
          scale=True,
          dtype=dt)
      attention_cell = wrapper.AttentionWrapper(
          rnn_cell.LSTMCell(num_units), luong)

      training_helper = helper_py.TrainingHelper(target_inputs,
                                                 target_lengths)
      start_state = attention_cell.zero_state(
          dtype=dt, batch_size=batch_size)
      my_decoder = basic_decoder.BasicDecoder(
          cell=attention_cell,
          helper=training_helper,
          initial_state=start_state)

      outputs, state, _ = decoder.dynamic_decode(my_decoder)

      # Structural checks on what dynamic_decode hands back.
      self.assertTrue(isinstance(outputs, basic_decoder.BasicDecoderOutput))
      # The actual regression check: output dtype follows the requested one.
      self.assertEqual(outputs.rnn_output.dtype, dt)
      self.assertTrue(isinstance(state, wrapper.AttentionWrapperState))
      self.assertTrue(
          isinstance(state.cell_state, rnn_cell.LSTMStateTuple))
Esempio n. 2
0
  def _testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN(
      self, use_sequence_length):
    """Compares `dynamic_decode` + `TrainingHelper` against `dynamic_rnn`.

    Both paths share one LSTM cell, the same random inputs, the same zero
    initial state and the same variable scope, so their outputs are expected
    to agree.

    Args:
      use_sequence_length: If true, `dynamic_rnn` receives the sequence
        lengths and `dynamic_decode` runs with `impute_finished=True`; the
        final states are then compared as well. If false, only the outputs
        are compared.
    """
    lengths = [3, 4, 3, 1, 0]
    batch_size, max_time = 5, 8
    input_depth, cell_depth = 7, 10
    longest = max(lengths)
    rnn_lengths = lengths if use_sequence_length else None

    with self.session(use_gpu=True) as sess:
      inputs = np.random.randn(
          batch_size, max_time, input_depth).astype(np.float32)

      cell = rnn_cell.LSTMCell(cell_depth)
      zero_state = cell.zero_state(dtype=dtypes.float32, batch_size=batch_size)
      my_decoder = basic_decoder.BasicDecoder(
          cell=cell,
          helper=helper_py.TrainingHelper(inputs, lengths),
          initial_state=zero_state)

      # Both graphs are built under "root" so that dynamic_rnn below reuses
      # the variables created by the decoder.
      with vs.variable_scope("root") as scope:
        # impute_finished=True makes outputs and final state line up with
        # dynamic_rnn when the latter is given a sequence_length.
        decoded = decoder.dynamic_decode(
            my_decoder, impute_finished=use_sequence_length, scope=scope)
        final_decoder_outputs, final_decoder_state, _ = decoded

      with vs.variable_scope(scope, reuse=True) as rnn_scope:
        final_rnn_outputs, final_rnn_state = rnn.dynamic_rnn(
            cell,
            inputs,
            sequence_length=rnn_lengths,
            initial_state=zero_state,
            scope=rnn_scope)

      sess.run(variables.global_variables_initializer())
      fetches = {
          "final_decoder_outputs": final_decoder_outputs,
          "final_decoder_state": final_decoder_state,
          "final_rnn_outputs": final_rnn_outputs,
          "final_rnn_state": final_rnn_state
      }
      results = sess.run(fetches)

      # The decoder stops after `longest` steps, so only that prefix of the
      # dynamic_rnn output (which zeros outputs and carries state forward) is
      # compared.
      rnn_prefix = results["final_rnn_outputs"][:, 0:longest, :]
      self.assertAllClose(results["final_decoder_outputs"].rnn_output,
                          rnn_prefix)
      if not use_sequence_length:
        return
      self.assertAllClose(results["final_decoder_state"],
                          results["final_rnn_state"])
Esempio n. 3
0
  def _testWithMaybeMultiAttention(self,
                                   is_multi,
                                   create_attention_mechanisms,
                                   expected_final_output,
                                   expected_final_state,
                                   attention_mechanism_depths,
                                   alignment_history=False,
                                   expected_final_alignment_history=None,
                                   attention_layer_sizes=None,
                                   attention_layers=None,
                                   name=''):
    """Decodes through an `AttentionWrapper` and checks results vs. expectations.

    Wraps an `LSTMCell` in an `AttentionWrapper` using one or several
    attention mechanisms, decodes random inputs with a `TrainingHelper` via
    `dynamic_decode`, verifies static shapes and result types, and finally
    compares `get_result_summary` of the fetched values against the expected
    structures with `self.assertAllCloseOrEqual`.

    Args:
      is_multi: If true, the mechanisms (and attention layer sizes / layers)
        are passed to `AttentionWrapper` as lists; otherwise only the first
        element of each is passed.
      create_attention_mechanisms: Sequence of callables invoked with
        `num_units`, `memory` and `memory_sequence_length` keyword arguments.
        Must have exactly one element unless `is_multi` is true.
      expected_final_output: Structure compared against the summary of the
        fetched final outputs.
      expected_final_state: Structure compared against the summary of the
        fetched final state (with the alignment history removed).
      attention_mechanism_depths: `num_units` values, zipped with
        `create_attention_mechanisms`.
      alignment_history: Forwarded to `AttentionWrapper`; when true the
        stacked alignment history is shape-checked and compared too.
      expected_final_alignment_history: Expected summary of the stacked
        alignment history; only used when `alignment_history` is true.
      attention_layer_sizes: Optional list of attention layer sizes. A falsy
        entry counts as `encoder_output_depth` when computing the expected
        attention depth.
      attention_layers: Optional list of layers; their `compute_output_shape`
        is used to derive the expected attention depth.
      name: Label printed before the copy/paste summaries.
    """
    # Allow is_multi to be True with a single mechanism to enable test for
    # passing in a single mechanism in a list.
    assert len(create_attention_mechanisms) == 1 or is_multi
    encoder_sequence_length = [3, 2, 3, 1, 1]
    decoder_sequence_length = [2, 0, 1, 2, 3]
    batch_size = 5
    encoder_max_time = 8
    decoder_max_time = 4
    input_depth = 7
    encoder_output_depth = 10
    cell_depth = 9

    # Work out the depth of the attention vector that the wrapper is expected
    # to emit; it is used in the static shape assertions below.
    if attention_layer_sizes is not None:
      # Compute sum of attention_layer_sizes. Use encoder_output_depth if None.
      attention_depth = sum(attention_layer_size or encoder_output_depth
                            for attention_layer_size in attention_layer_sizes)
    elif attention_layers is not None:
      # Compute sum of attention_layers output depth.
      attention_depth = sum(
          attention_layer.compute_output_shape(
              [batch_size, cell_depth + encoder_output_depth]).dims[-1].value
          for attention_layer in attention_layers)
    else:
      attention_depth = encoder_output_depth * len(create_attention_mechanisms)

    # Random data wrapped in placeholders whose batch and time dimensions are
    # left unknown in the static shape.
    decoder_inputs = array_ops.placeholder_with_default(
        np.random.randn(batch_size, decoder_max_time,
                        input_depth).astype(np.float32),
        shape=(None, None, input_depth))
    encoder_outputs = array_ops.placeholder_with_default(
        np.random.randn(batch_size, encoder_max_time,
                        encoder_output_depth).astype(np.float32),
        shape=(None, None, encoder_output_depth))

    attention_mechanisms = [
        creator(num_units=depth,
                memory=encoder_outputs,
                memory_sequence_length=encoder_sequence_length)
        for creator, depth in zip(create_attention_mechanisms,
                                  attention_mechanism_depths)]

    with self.session(use_gpu=True) as sess:
      with vs.variable_scope(
          'root',
          initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
        attention_layer_size = attention_layer_sizes
        attention_layer = attention_layers
        # In the single-mechanism case the wrapper is given scalars rather
        # than one-element lists.
        if not is_multi:
          if attention_layer_size is not None:
            attention_layer_size = attention_layer_size[0]
          if attention_layer is not None:
            attention_layer = attention_layer[0]
        cell = rnn_cell.LSTMCell(cell_depth)
        cell = wrapper.AttentionWrapper(
            cell,
            attention_mechanisms if is_multi else attention_mechanisms[0],
            attention_layer_size=attention_layer_size,
            alignment_history=alignment_history,
            attention_layer=attention_layer)
        helper = helper_py.TrainingHelper(decoder_inputs,
                                          decoder_sequence_length)
        my_decoder = basic_decoder.BasicDecoder(
            cell=cell,
            helper=helper,
            initial_state=cell.zero_state(
                dtype=dtypes.float32, batch_size=batch_size))

        final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)

      # Type checks on the values returned by dynamic_decode.
      self.assertTrue(
          isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
      self.assertTrue(
          isinstance(final_state, wrapper.AttentionWrapperState))
      self.assertTrue(
          isinstance(final_state.cell_state, rnn_cell.LSTMStateTuple))

      # Static shape checks; the time dimension is expected to be unknown.
      self.assertEqual((batch_size, None, attention_depth),
                       tuple(final_outputs.rnn_output.get_shape().as_list()))
      self.assertEqual((batch_size, None),
                       tuple(final_outputs.sample_id.get_shape().as_list()))

      self.assertEqual((batch_size, attention_depth),
                       tuple(final_state.attention.get_shape().as_list()))
      self.assertEqual((batch_size, cell_depth),
                       tuple(final_state.cell_state.c.get_shape().as_list()))
      self.assertEqual((batch_size, cell_depth),
                       tuple(final_state.cell_state.h.get_shape().as_list()))

      if alignment_history:
        if is_multi:
          # One history array per mechanism; stack each into a tensor.
          state_alignment_history = []
          for history_array in final_state.alignment_history:
            history = history_array.stack()
            self.assertEqual(
                (None, batch_size, None),
                tuple(history.get_shape().as_list()))
            state_alignment_history.append(history)
          state_alignment_history = tuple(state_alignment_history)
        else:
          state_alignment_history = final_state.alignment_history.stack()
          self.assertEqual(
              (None, batch_size, None),
              tuple(state_alignment_history.get_shape().as_list()))
        # The declared state_size must have the same nesting as an actual
        # zero state.
        nest.assert_same_structure(
            cell.state_size,
            cell.zero_state(batch_size, dtypes.float32))
        # Remove the history from final_state for purposes of the
        # remainder of the tests.
        final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
      else:
        state_alignment_history = ()

      sess.run(variables.global_variables_initializer())
      sess_results = sess.run({
          'final_outputs': final_outputs,
          'final_state': final_state,
          'state_alignment_history': state_alignment_history,
      })

      final_output_info = nest.map_structure(get_result_summary,
                                             sess_results['final_outputs'])
      final_state_info = nest.map_structure(get_result_summary,
                                            sess_results['final_state'])
      # The summaries are printed in a form that can be pasted back into the
      # calling test as its expected values.
      print(name)
      print('Copy/paste:\nexpected_final_output = %s' % str(final_output_info))
      print('expected_final_state = %s' % str(final_state_info))
      nest.map_structure(self.assertAllCloseOrEqual, expected_final_output,
                         final_output_info)
      nest.map_structure(self.assertAllCloseOrEqual, expected_final_state,
                         final_state_info)
      if alignment_history:  # by default, the wrapper emits attention as output
        final_alignment_history_info = nest.map_structure(
            get_result_summary, sess_results['state_alignment_history'])
        print('expected_final_alignment_history = %s' %
              str(final_alignment_history_info))
        nest.map_structure(
            self.assertAllCloseOrEqual,
            # outputs are batch major but the stacked TensorArray is time major
            expected_final_alignment_history,
            final_alignment_history_info)
Esempio n. 4
0
  def _testDynamicDecodeRNN(self, time_major, maximum_iterations=None):
    """Smoke-tests `dynamic_decode` on a `BasicDecoder` with a `TrainingHelper`.

    Checks the static shapes of the returned tensors, then runs the graph and
    checks the concrete output shapes and the per-example sequence lengths.

    Args:
      time_major: If true, inputs are laid out `[max_time, batch, depth]` and
        `dynamic_decode` is asked for time-major outputs; otherwise
        everything is batch-major.
      maximum_iterations: Optional value forwarded to `dynamic_decode`. When
        set, the expected number of decoded steps and every expected sequence
        length are clipped to it.
    """

    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    cell_depth = 10
    max_out = max(sequence_length)

    with self.session(use_gpu=True) as sess:
      if time_major:
        inputs = np.random.randn(max_time, batch_size,
                                 input_depth).astype(np.float32)
      else:
        inputs = np.random.randn(batch_size, max_time,
                                 input_depth).astype(np.float32)
      cell = rnn_cell.LSTMCell(cell_depth)
      helper = helper_py.TrainingHelper(
          inputs, sequence_length, time_major=time_major)
      my_decoder = basic_decoder.BasicDecoder(
          cell=cell,
          helper=helper,
          initial_state=cell.zero_state(
              dtype=dtypes.float32, batch_size=batch_size))

      final_outputs, final_state, final_sequence_length = (
          decoder.dynamic_decode(my_decoder, output_time_major=time_major,
                                 maximum_iterations=maximum_iterations))

      def _t(shape):
        # Swap the two leading dimensions of a batch-major shape when the
        # test runs in time-major mode.
        if time_major:
          return (shape[1], shape[0]) + shape[2:]
        return shape

      # assertIsInstance reports the offending type on failure, unlike
      # assertTrue(isinstance(...)).
      self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
      self.assertIsInstance(final_state, rnn_cell.LSTMStateTuple)

      # Static shapes: the time dimension is unknown at graph-build time.
      self.assertEqual(
          (batch_size,),
          tuple(final_sequence_length.get_shape().as_list()))
      self.assertEqual(
          _t((batch_size, None, cell_depth)),
          tuple(final_outputs.rnn_output.get_shape().as_list()))
      self.assertEqual(
          _t((batch_size, None)),
          tuple(final_outputs.sample_id.get_shape().as_list()))

      sess.run(variables.global_variables_initializer())
      sess_results = sess.run({
          "final_outputs": final_outputs,
          "final_state": final_state,
          "final_sequence_length": final_sequence_length,
      })

      # Mostly a smoke test
      time_steps = max_out
      expected_length = sequence_length
      if maximum_iterations is not None:
        time_steps = min(max_out, maximum_iterations)
        expected_length = [min(x, maximum_iterations) for x in expected_length]
      self.assertEqual(
          _t((batch_size, time_steps, cell_depth)),
          sess_results["final_outputs"].rnn_output.shape)
      self.assertEqual(
          _t((batch_size, time_steps)),
          sess_results["final_outputs"].sample_id.shape)
      # Sequence lengths are per batch entry, so compare them positionally; an
      # order-insensitive comparison would accept lengths assigned to the
      # wrong examples.
      self.assertAllEqual(expected_length,
                          sess_results["final_sequence_length"])
Esempio n. 5
0
    def testStepWithInferenceHelperMultilabel(self):
        """Runs one `BasicDecoder` step with a multilabel `InferenceHelper`.

        The helper samples a boolean vector per example from the cell output,
        feeds its float cast back as the next input, and marks an example as
        finished when the `end_token` column of its sample is set.
        """
        batch_size = 5
        vocabulary_size = 7
        cell_depth = vocabulary_size
        start_token = 0
        end_token = 6

        start_inputs = array_ops.one_hot(
            np.ones(batch_size) * start_token, vocabulary_size)

        def sample_fn(x):
            # Independent Bernoulli draws from the logits.
            return helper_py.bernoulli_sample(logits=x, dtype=dtypes.bool)

        def end_fn(sample_ids):
            # Finished when the end_token label was sampled.
            return sample_ids[:, end_token]

        # The next inputs are the sampled labels cast to float.
        next_inputs_fn = math_ops.to_float

        scope = variable_scope.variable_scope(
            "testStepWithInferenceHelper",
            initializer=init_ops.constant_initializer(0.01))
        with self.session(use_gpu=True) as sess, scope:
            cell = rnn_cell.LSTMCell(vocabulary_size)
            helper = helper_py.InferenceHelper(
                sample_fn,
                sample_shape=[cell_depth],
                sample_dtype=dtypes.bool,
                start_inputs=start_inputs,
                end_fn=end_fn,
                next_inputs_fn=next_inputs_fn)
            initial_state = cell.zero_state(dtype=dtypes.float32,
                                            batch_size=batch_size)
            my_decoder = basic_decoder.BasicDecoder(
                cell=cell, helper=helper, initial_state=initial_state)

            output_size = my_decoder.output_size
            output_dtype = my_decoder.output_dtype
            expected_size = basic_decoder.BasicDecoderOutput(
                cell_depth, cell_depth)
            expected_dtype = basic_decoder.BasicDecoderOutput(
                dtypes.float32, dtypes.bool)
            self.assertEqual(expected_size, output_size)
            self.assertEqual(expected_dtype, output_dtype)

            first_finished, first_inputs, first_state = (
                my_decoder.initialize())
            step_result = my_decoder.step(
                constant_op.constant(0), first_inputs, first_state)
            (step_outputs, step_state, step_next_inputs,
             step_finished) = step_result
            batch_size_t = my_decoder.batch_size

            self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
            self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
            self.assertTrue(
                isinstance(step_outputs, basic_decoder.BasicDecoderOutput))

            # Both output fields and every state component share one shape.
            same_shape_tensors = (
                step_outputs[0], step_outputs[1],
                first_state[0], first_state[1],
                step_state[0], step_state[1],
            )
            for tensor in same_shape_tensors:
                self.assertEqual((batch_size, cell_depth),
                                 tensor.get_shape())

            sess.run(variables.global_variables_initializer())
            fetches = {
                "batch_size": batch_size_t,
                "first_finished": first_finished,
                "first_inputs": first_inputs,
                "first_state": first_state,
                "step_outputs": step_outputs,
                "step_state": step_state,
                "step_next_inputs": step_next_inputs,
                "step_finished": step_finished
            }
            results = sess.run(fetches)

            sample_ids = results["step_outputs"].sample_id
            self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
            # Recompute in numpy what end_fn and next_inputs_fn should yield.
            self.assertAllEqual(sample_ids[:, end_token],
                                results["step_finished"])
            self.assertAllEqual(sample_ids.astype(np.float32),
                                results["step_next_inputs"])
Esempio n. 6
0
    def testStepWithInferenceHelperCategorical(self):
        """Runs one `BasicDecoder` step with a categorical `InferenceHelper`.

        The helper draws one id per example from the cell output, feeds its
        one-hot encoding back as the next input, and marks an example as
        finished when the drawn id equals `end_token`.
        """
        batch_size = 5
        vocabulary_size = 7
        cell_depth = vocabulary_size
        start_token = 0
        end_token = 6

        start_inputs = array_ops.one_hot(
            np.ones(batch_size) * start_token, vocabulary_size)

        def sample_fn(x):
            # Categorical draw from the logits.
            return helper_py.categorical_sample(logits=x)

        def next_inputs_fn(x):
            # One-hot encoding of the sampled ids.
            return array_ops.one_hot(x, vocabulary_size, dtype=dtypes.float32)

        def end_fn(sample_ids):
            return math_ops.equal(sample_ids, end_token)

        scope = variable_scope.variable_scope(
            "testStepWithInferenceHelper",
            initializer=init_ops.constant_initializer(0.01))
        with self.session(use_gpu=True) as sess, scope:
            cell = rnn_cell.LSTMCell(vocabulary_size)
            helper = helper_py.InferenceHelper(
                sample_fn,
                sample_shape=(),
                sample_dtype=dtypes.int32,
                start_inputs=start_inputs,
                end_fn=end_fn,
                next_inputs_fn=next_inputs_fn)
            initial_state = cell.zero_state(dtype=dtypes.float32,
                                            batch_size=batch_size)
            my_decoder = basic_decoder.BasicDecoder(
                cell=cell, helper=helper, initial_state=initial_state)

            output_size = my_decoder.output_size
            output_dtype = my_decoder.output_dtype
            expected_size = basic_decoder.BasicDecoderOutput(
                cell_depth, tensor_shape.TensorShape([]))
            expected_dtype = basic_decoder.BasicDecoderOutput(
                dtypes.float32, dtypes.int32)
            self.assertEqual(expected_size, output_size)
            self.assertEqual(expected_dtype, output_dtype)

            first_finished, first_inputs, first_state = (
                my_decoder.initialize())
            step_result = my_decoder.step(
                constant_op.constant(0), first_inputs, first_state)
            (step_outputs, step_state, step_next_inputs,
             step_finished) = step_result
            batch_size_t = my_decoder.batch_size

            self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
            self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
            self.assertTrue(
                isinstance(step_outputs, basic_decoder.BasicDecoderOutput))

            self.assertEqual((batch_size, cell_depth),
                             step_outputs[0].get_shape())
            # Sampled ids are scalars per example.
            self.assertEqual((batch_size, ), step_outputs[1].get_shape())
            state_parts = (first_state[0], first_state[1],
                           step_state[0], step_state[1])
            for part in state_parts:
                self.assertEqual((batch_size, cell_depth), part.get_shape())

            sess.run(variables.global_variables_initializer())
            fetches = {
                "batch_size": batch_size_t,
                "first_finished": first_finished,
                "first_inputs": first_inputs,
                "first_state": first_state,
                "step_outputs": step_outputs,
                "step_state": step_state,
                "step_next_inputs": step_next_inputs,
                "step_finished": step_finished
            }
            results = sess.run(fetches)

            sample_ids = results["step_outputs"].sample_id
            self.assertEqual(output_dtype.sample_id, sample_ids.dtype)

            # Rebuild in numpy the one-hot rows next_inputs_fn should produce.
            one_hot_rows = np.zeros((batch_size, vocabulary_size))
            one_hot_rows[np.arange(batch_size), sample_ids] = 1.0
            self.assertAllEqual(sample_ids == end_token,
                                results["step_finished"])
            self.assertAllEqual(one_hot_rows, results["step_next_inputs"])
Esempio n. 7
0
    def _testStepWithTrainingHelper(self, use_output_layer):
        """Runs one `BasicDecoder` step driven by a `TrainingHelper`.

        Args:
          use_output_layer: If true, a bias-free `Dense` layer of depth 3 is
            attached to the decoder and the output depth is expected to be
            that layer's depth; otherwise it is the cell depth.
        """
        sequence_length = [3, 4, 3, 1, 0]
        batch_size = 5
        max_time = 8
        input_depth = 7
        cell_depth = 10
        output_layer_depth = 3

        with self.session(use_gpu=True) as sess:
            inputs = np.random.randn(
                batch_size, max_time, input_depth).astype(np.float32)
            cell = rnn_cell.LSTMCell(cell_depth)
            helper = helper_py.TrainingHelper(
                inputs, sequence_length, time_major=False)

            # Defaults for the no-projection case, overridden below.
            output_layer = None
            expected_output_depth = cell_depth
            if use_output_layer:
                output_layer = layers_core.Dense(output_layer_depth,
                                                 use_bias=False)
                expected_output_depth = output_layer_depth

            initial_state = cell.zero_state(dtype=dtypes.float32,
                                            batch_size=batch_size)
            my_decoder = basic_decoder.BasicDecoder(
                cell=cell,
                helper=helper,
                initial_state=initial_state,
                output_layer=output_layer)

            output_size = my_decoder.output_size
            output_dtype = my_decoder.output_dtype
            expected_size = basic_decoder.BasicDecoderOutput(
                expected_output_depth, tensor_shape.TensorShape([]))
            expected_dtype = basic_decoder.BasicDecoderOutput(
                dtypes.float32, dtypes.int32)
            self.assertEqual(expected_size, output_size)
            self.assertEqual(expected_dtype, output_dtype)

            first_finished, first_inputs, first_state = (
                my_decoder.initialize())
            step_result = my_decoder.step(
                constant_op.constant(0), first_inputs, first_state)
            (step_outputs, step_state, step_next_inputs,
             step_finished) = step_result
            batch_size_t = my_decoder.batch_size

            self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
            self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
            self.assertTrue(
                isinstance(step_outputs, basic_decoder.BasicDecoderOutput))

            self.assertEqual((batch_size, expected_output_depth),
                             step_outputs[0].get_shape())
            self.assertEqual((batch_size, ), step_outputs[1].get_shape())
            state_parts = (first_state[0], first_state[1],
                           step_state[0], step_state[1])
            for part in state_parts:
                self.assertEqual((batch_size, cell_depth), part.get_shape())

            if use_output_layer:
                # The output layer was accessed
                self.assertEqual(len(output_layer.variables), 1)

            sess.run(variables.global_variables_initializer())
            fetches = {
                "batch_size": batch_size_t,
                "first_finished": first_finished,
                "first_inputs": first_inputs,
                "first_state": first_state,
                "step_outputs": step_outputs,
                "step_state": step_state,
                "step_next_inputs": step_next_inputs,
                "step_finished": step_finished
            }
            results = sess.run(fetches)
            fetched_outputs = results["step_outputs"]

            # Expected: only the zero-length example is finished up front, and
            # the length-1 example joins it after the first step.
            self.assertAllEqual([False, False, False, False, True],
                                results["first_finished"])
            self.assertAllEqual([False, False, False, True, True],
                                results["step_finished"])
            self.assertEqual(output_dtype.sample_id,
                             fetched_outputs.sample_id.dtype)
            # The emitted sample ids must equal the argmax of the outputs.
            self.assertAllEqual(
                np.argmax(fetched_outputs.rnn_output, -1),
                fetched_outputs.sample_id)
Esempio n. 8
0
    def _testStepWithScheduledOutputTrainingHelper(self, sampling_probability,
                                                   use_next_inputs_fn,
                                                   use_auxiliary_inputs):
        """Runs one decoder step with a `ScheduledOutputTrainingHelper`.

        After the step, batch rows with a nonzero `sample_id` are treated as
        the sampled ones: their next inputs must come from the step outputs
        (optionally passed through `next_inputs_fn`), while all other rows
        must receive the time-step-1 slice of the training inputs. Auxiliary
        inputs, when used, are appended to both.

        Args:
          sampling_probability: Value wrapped in a constant and handed to the
            helper as its sampling probability.
          use_next_inputs_fn: If true, the helper gets a deterministic
            argmax/one-hot `next_inputs_fn`; otherwise `None`.
          use_auxiliary_inputs: If true, random auxiliary inputs of depth 4
            are passed to the helper; otherwise `None`.
        """
        sequence_length = [3, 4, 3, 1, 0]
        batch_size = 5
        max_time = 8
        input_depth = 7
        cell_depth = input_depth

        auxiliary_inputs = None
        if use_auxiliary_inputs:
            auxiliary_input_depth = 4
            auxiliary_inputs = np.random.randn(
                batch_size, max_time, auxiliary_input_depth).astype(np.float32)

        def _argmax_one_hot(outputs):
            # Use deterministic function for test.
            samples = math_ops.argmax(outputs, axis=1)
            return array_ops.one_hot(
                samples, cell_depth, dtype=dtypes.float32)

        next_inputs_fn = _argmax_one_hot if use_next_inputs_fn else None

        with self.session(use_gpu=True) as sess:
            inputs = np.random.randn(
                batch_size, max_time, input_depth).astype(np.float32)
            cell = rnn_cell.LSTMCell(cell_depth)
            sampling_probability_t = constant_op.constant(sampling_probability)

            helper = helper_py.ScheduledOutputTrainingHelper(
                inputs=inputs,
                sequence_length=sequence_length,
                sampling_probability=sampling_probability_t,
                time_major=False,
                next_inputs_fn=next_inputs_fn,
                auxiliary_inputs=auxiliary_inputs)

            initial_state = cell.zero_state(dtype=dtypes.float32,
                                            batch_size=batch_size)
            my_decoder = basic_decoder.BasicDecoder(
                cell=cell, helper=helper, initial_state=initial_state)

            output_size = my_decoder.output_size
            output_dtype = my_decoder.output_dtype
            expected_size = basic_decoder.BasicDecoderOutput(
                cell_depth, tensor_shape.TensorShape([]))
            expected_dtype = basic_decoder.BasicDecoderOutput(
                dtypes.float32, dtypes.int32)
            self.assertEqual(expected_size, output_size)
            self.assertEqual(expected_dtype, output_dtype)

            first_finished, first_inputs, first_state = (
                my_decoder.initialize())
            step_result = my_decoder.step(
                constant_op.constant(0), first_inputs, first_state)
            (step_outputs, step_state, step_next_inputs,
             step_finished) = step_result

            if use_next_inputs_fn:
                # Reference tensor: the same transform applied to the outputs.
                output_after_next_inputs_fn = next_inputs_fn(
                    step_outputs.rnn_output)

            batch_size_t = my_decoder.batch_size

            self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
            self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
            self.assertTrue(
                isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
            self.assertEqual((batch_size, cell_depth),
                             step_outputs[0].get_shape())
            self.assertEqual((batch_size, ), step_outputs[1].get_shape())
            state_parts = (first_state[0], first_state[1],
                           step_state[0], step_state[1])
            for part in state_parts:
                self.assertEqual((batch_size, cell_depth), part.get_shape())

            sess.run(variables.global_variables_initializer())

            fetches = {
                "batch_size": batch_size_t,
                "first_finished": first_finished,
                "first_inputs": first_inputs,
                "first_state": first_state,
                "step_outputs": step_outputs,
                "step_state": step_state,
                "step_next_inputs": step_next_inputs,
                "step_finished": step_finished
            }
            if use_next_inputs_fn:
                fetches["output_after_next_inputs_fn"] = (
                    output_after_next_inputs_fn)

            results = sess.run(fetches)

            self.assertAllEqual([False, False, False, False, True],
                                results["first_finished"])
            self.assertAllEqual([False, False, False, True, True],
                                results["step_finished"])

            sample_ids = results["step_outputs"].sample_id
            self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
            # np.where returns a 1-tuple of row indices for each group.
            batch_where_sampling = np.where(sample_ids)
            batch_where_not_sampling = np.where(np.logical_not(sample_ids))

            # Auxiliary slice for time step 1, or an empty (batch, 0) array so
            # the concatenations below work unchanged.
            if use_auxiliary_inputs:
                auxiliary_inputs_to_concat = auxiliary_inputs[:, 1]
            else:
                auxiliary_inputs_to_concat = np.array([]).reshape(
                    batch_size, 0).astype(np.float32)

            if use_next_inputs_fn:
                sampled_source = results["output_after_next_inputs_fn"]
            else:
                sampled_source = results["step_outputs"].rnn_output
            expected_next_sampling_inputs = np.concatenate(
                (sampled_source[batch_where_sampling],
                 auxiliary_inputs_to_concat[batch_where_sampling]),
                axis=-1)
            self.assertAllClose(
                results["step_next_inputs"][batch_where_sampling],
                expected_next_sampling_inputs)

            # Indexing with the 1-tuple adds a leading axis of size one, which
            # the squeeze removes again.
            unsampled_inputs = np.squeeze(
                inputs[batch_where_not_sampling, 1], axis=0)
            expected_next_training_inputs = np.concatenate(
                (unsampled_inputs,
                 auxiliary_inputs_to_concat[batch_where_not_sampling]),
                axis=-1)
            self.assertAllClose(
                results["step_next_inputs"][batch_where_not_sampling],
                expected_next_training_inputs)
Esempio n. 9
0
    def testStepWithScheduledEmbeddingTrainingHelper(self):
        """Runs one `BasicDecoder` step with a scheduled-embedding helper.

        Builds an `LSTMCell` with `vocabulary_size` units, wraps it in a
        `BasicDecoder` driven by `ScheduledEmbeddingTrainingHelper` with a
        sampling probability of 0.5, and calls `initialize()` followed by a
        single `step()` at time 0.  The test then verifies:

        * the decoder's declared output size and dtype,
        * the static shapes of the outputs, states and next inputs,
        * the `finished` flags before and after the step,
        * that the next inputs are embedding rows for entries with a
          non-negative `sample_id`, and `inputs[:, 1]` for entries whose
          `sample_id` is -1.
        """
        sequence_length = [3, 4, 3, 1, 0]
        batch_size = 5
        max_time = 8
        input_depth = 7
        vocabulary_size = 10

        with self.session(use_gpu=True) as sess:
            inputs = np.random.randn(batch_size, max_time,
                                     input_depth).astype(np.float32)
            embeddings = np.random.randn(vocabulary_size,
                                         input_depth).astype(np.float32)
            half = constant_op.constant(0.5)
            # The cell width equals the vocabulary size because the cell
            # output is used directly as the decoder's rnn_output (see the
            # output_size assertion below).
            cell = rnn_cell.LSTMCell(vocabulary_size)
            helper = helper_py.ScheduledEmbeddingTrainingHelper(
                inputs=inputs,
                sequence_length=sequence_length,
                embedding=embeddings,
                sampling_probability=half,
                time_major=False)
            my_decoder = basic_decoder.BasicDecoder(
                cell=cell,
                helper=helper,
                initial_state=cell.zero_state(dtype=dtypes.float32,
                                              batch_size=batch_size))
            output_size = my_decoder.output_size
            output_dtype = my_decoder.output_dtype
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(vocabulary_size,
                                                 tensor_shape.TensorShape([])),
                output_size)
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
                output_dtype)

            (first_finished, first_inputs,
             first_state) = my_decoder.initialize()
            (step_outputs, step_state, step_next_inputs,
             step_finished) = my_decoder.step(constant_op.constant(0),
                                              first_inputs, first_state)
            batch_size_t = my_decoder.batch_size

            # Structural checks on the graph tensors before anything is run.
            self.assertIsInstance(first_state, rnn_cell.LSTMStateTuple)
            self.assertIsInstance(step_state, rnn_cell.LSTMStateTuple)
            self.assertIsInstance(step_outputs,
                                  basic_decoder.BasicDecoderOutput)
            self.assertEqual((batch_size, vocabulary_size),
                             step_outputs[0].get_shape())
            self.assertEqual((batch_size, ), step_outputs[1].get_shape())
            self.assertEqual((batch_size, vocabulary_size),
                             first_state[0].get_shape())
            self.assertEqual((batch_size, vocabulary_size),
                             first_state[1].get_shape())
            self.assertEqual((batch_size, vocabulary_size),
                             step_state[0].get_shape())
            self.assertEqual((batch_size, vocabulary_size),
                             step_state[1].get_shape())
            self.assertEqual((batch_size, input_depth),
                             step_next_inputs.get_shape())

            sess.run(variables.global_variables_initializer())
            sess_results = sess.run({
                "batch_size": batch_size_t,
                "first_finished": first_finished,
                "first_inputs": first_inputs,
                "first_state": first_state,
                "step_outputs": step_outputs,
                "step_state": step_state,
                "step_next_inputs": step_next_inputs,
                "step_finished": step_finished
            })

            # The expected flags correspond to sequence_length <= 0 before
            # the step and sequence_length <= 1 after it.
            self.assertAllEqual([False, False, False, False, True],
                                sess_results["first_finished"])
            self.assertAllEqual([False, False, False, True, True],
                                sess_results["step_finished"])
            sample_ids = sess_results["step_outputs"].sample_id
            self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
            # This test treats sample_id == -1 as "not sampled" and any
            # non-negative id as an index into `embeddings`.
            batch_where_not_sampling = np.where(sample_ids == -1)
            batch_where_sampling = np.where(sample_ids > -1)
            self.assertAllClose(
                sess_results["step_next_inputs"][batch_where_sampling],
                embeddings[sample_ids[batch_where_sampling]])
            # Indexing with the np.where tuple plus a time index yields an
            # array of shape (1, k, input_depth).  Squeeze only that leading
            # axis: a bare np.squeeze would also drop the batch axis when
            # exactly one row is selected (k == 1) and the shapes would no
            # longer match.
            self.assertAllClose(
                sess_results["step_next_inputs"][batch_where_not_sampling],
                np.squeeze(inputs[batch_where_not_sampling, 1], axis=0))
Esempio n. 10
0
    def testStepWithSampleEmbeddingHelper(self):
        """Runs one `BasicDecoder` step with `SampleEmbeddingHelper`.

        Builds an `LSTMCell` whose width equals the vocabulary size, wraps it
        in a `BasicDecoder` driven by `SampleEmbeddingHelper` (seeded with 0),
        and calls `initialize()` followed by a single `step()` at time 0.
        The test then verifies:

        * the decoder's declared output size and dtype,
        * the static shapes of the outputs and states,
        * that `step_finished` is true exactly where the sampled id equals
          `end_token`,
        * that the next inputs are the embedding rows of the sampled ids.
        """
        batch_size = 5
        vocabulary_size = 7
        cell_depth = vocabulary_size  # cell's logits must match vocabulary size
        input_depth = 10
        # Fix the NumPy seed so start_tokens and embeddings are reproducible.
        np.random.seed(0)
        start_tokens = np.random.randint(0, vocabulary_size, size=batch_size)
        end_token = 1

        with self.session(use_gpu=True) as sess:
            # The scope is opened with a constant (0.01) initializer.
            with variable_scope.variable_scope(
                    "testStepWithSampleEmbeddingHelper",
                    initializer=init_ops.constant_initializer(0.01)):
                embeddings = np.random.randn(vocabulary_size,
                                             input_depth).astype(np.float32)
                cell = rnn_cell.LSTMCell(vocabulary_size)
                helper = helper_py.SampleEmbeddingHelper(embeddings,
                                                         start_tokens,
                                                         end_token,
                                                         seed=0)
                my_decoder = basic_decoder.BasicDecoder(
                    cell=cell,
                    helper=helper,
                    initial_state=cell.zero_state(dtype=dtypes.float32,
                                                  batch_size=batch_size))
                output_size = my_decoder.output_size
                output_dtype = my_decoder.output_dtype
                self.assertEqual(
                    basic_decoder.BasicDecoderOutput(
                        cell_depth, tensor_shape.TensorShape([])), output_size)
                self.assertEqual(
                    basic_decoder.BasicDecoderOutput(dtypes.float32,
                                                     dtypes.int32),
                    output_dtype)

                (first_finished, first_inputs,
                 first_state) = my_decoder.initialize()
                (step_outputs, step_state, step_next_inputs,
                 step_finished) = my_decoder.step(constant_op.constant(0),
                                                  first_inputs, first_state)
                batch_size_t = my_decoder.batch_size

                # Structural checks on the graph tensors before running.
                self.assertIsInstance(first_state, rnn_cell.LSTMStateTuple)
                self.assertIsInstance(step_state, rnn_cell.LSTMStateTuple)
                self.assertIsInstance(step_outputs,
                                      basic_decoder.BasicDecoderOutput)
                self.assertEqual((batch_size, cell_depth),
                                 step_outputs[0].get_shape())
                self.assertEqual((batch_size, ), step_outputs[1].get_shape())
                self.assertEqual((batch_size, cell_depth),
                                 first_state[0].get_shape())
                self.assertEqual((batch_size, cell_depth),
                                 first_state[1].get_shape())
                self.assertEqual((batch_size, cell_depth),
                                 step_state[0].get_shape())
                self.assertEqual((batch_size, cell_depth),
                                 step_state[1].get_shape())

                sess.run(variables.global_variables_initializer())
                fetches = {
                    "batch_size": batch_size_t,
                    "first_finished": first_finished,
                    "first_inputs": first_inputs,
                    "first_state": first_state,
                    "step_outputs": step_outputs,
                    "step_state": step_state,
                    "step_next_inputs": step_next_inputs,
                    "step_finished": step_finished
                }
                sess_results = sess.run(fetches)

                sample_ids = sess_results["step_outputs"].sample_id
                self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
                # Expectations are derived from the ids actually sampled in
                # this run rather than from hard-coded values.
                expected_step_finished = (sample_ids == end_token)
                expected_step_next_inputs = embeddings[sample_ids]
                self.assertAllEqual(expected_step_finished,
                                    sess_results["step_finished"])
                self.assertAllEqual(expected_step_next_inputs,
                                    sess_results["step_next_inputs"])