Example #1
    def _testDynamicDecodeRNNWithBasicTrainingSamplerMatchesDynamicRNN(
            self, use_sequence_length):
        sequence_length = [3, 4, 3, 1, 0]
        batch_size = 5
        max_time = 8
        input_depth = 7
        cell_depth = 10
        max_out = max(sequence_length)

        with self.test_session() as sess:
            inputs = np.random.randn(batch_size, max_time,
                                     input_depth).astype(np.float32)

            cell = core_rnn_cell.LSTMCell(cell_depth)
            zero_state = cell.zero_state(dtype=dtypes.float32,
                                         batch_size=batch_size)
            sampler = sampling_decoder.BasicTrainingSampler(
                inputs, sequence_length)
            my_decoder = sampling_decoder.BasicSamplingDecoder(
                cell=cell, sampler=sampler, initial_state=zero_state)

            # Match the variable scope of dynamic_rnn below so we end up
            # using the same variables
            with vs.variable_scope("root") as scope:
                final_decoder_outputs, final_decoder_state = decoder.dynamic_decode_rnn(
                    my_decoder,
                    # When use_sequence_length is True, impute_finished=True
                    # ensures that outputs and final state match those of
                    # dynamic_rnn called with a non-None sequence_length.
                    impute_finished=use_sequence_length,
                    scope=scope)

            with vs.variable_scope(scope, reuse=True) as scope:
                final_rnn_outputs, final_rnn_state = rnn.dynamic_rnn(
                    cell,
                    inputs,
                    sequence_length=sequence_length
                    if use_sequence_length else None,
                    initial_state=zero_state,
                    scope=scope)

            sess.run(variables.global_variables_initializer())
            sess_results = sess.run({
                "final_decoder_outputs": final_decoder_outputs,
                "final_decoder_state": final_decoder_state,
                "final_rnn_outputs": final_rnn_outputs,
                "final_rnn_state": final_rnn_state
            })

            # Decoder only runs out to max_out; ensure values are identical
            # to dynamic_rnn, which also zeros out outputs and passes along state.
            self.assertAllClose(
                sess_results["final_decoder_outputs"].rnn_output,
                sess_results["final_rnn_outputs"][:, 0:max_out, :])
            if use_sequence_length:
                self.assertAllClose(sess_results["final_decoder_state"],
                                    sess_results["final_rnn_state"])
Example #2
    def _testDynamicDecodeRNN(self, time_major):
        sequence_length = [3, 4, 3, 1, 0]
        batch_size = 5
        max_time = 8
        input_depth = 7
        cell_depth = 10
        max_out = max(sequence_length)

        with self.test_session() as sess:
            if time_major:
                inputs = np.random.randn(max_time, batch_size,
                                         input_depth).astype(np.float32)
            else:
                inputs = np.random.randn(batch_size, max_time,
                                         input_depth).astype(np.float32)
            cell = core_rnn_cell.LSTMCell(cell_depth)
            sampler = sampling_decoder.BasicTrainingSampler(
                inputs, sequence_length, time_major=time_major)
            my_decoder = sampling_decoder.BasicSamplingDecoder(
                cell=cell,
                sampler=sampler,
                initial_state=cell.zero_state(dtype=dtypes.float32,
                                              batch_size=batch_size))

            final_outputs, final_state = decoder.dynamic_decode_rnn(
                my_decoder, output_time_major=time_major)

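            # Expected static shapes below are written batch-major; _t swaps the
            # batch and time dimensions when the decoder runs time-major.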
            def _t(shape):
                if time_major:
                    return (shape[1], shape[0]) + shape[2:]
                return shape

            self.assertTrue(
                isinstance(final_outputs,
                           sampling_decoder.SamplingDecoderOutput))
            self.assertTrue(
                isinstance(final_state, core_rnn_cell.LSTMStateTuple))

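            # The decode length is dynamic, so the time dimension is statically
            # unknown (None); at run time it resolves to max_out (checked below).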
            self.assertEqual(
                _t((batch_size, None, cell_depth)),
                tuple(final_outputs.rnn_output.get_shape().as_list()))
            self.assertEqual(
                _t((batch_size, None)),
                tuple(final_outputs.sample_id.get_shape().as_list()))

            sess.run(variables.global_variables_initializer())
            sess_results = sess.run({
                "final_outputs": final_outputs,
                "final_state": final_state
            })

            self.assertEqual(_t((batch_size, max_out, cell_depth)),
                             sess_results["final_outputs"].rnn_output.shape)
            self.assertEqual(_t((batch_size, max_out)),
                             sess_results["final_outputs"].sample_id.shape)
Example #3
  def testStepWithBasicTrainingSampler(self):
    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    cell_depth = 10

    with self.test_session() as sess:
      inputs = np.random.randn(batch_size, max_time,
                               input_depth).astype(np.float32)
      cell = core_rnn_cell.LSTMCell(cell_depth)
      sampler = sampling_decoder.BasicTrainingSampler(
          inputs, sequence_length, time_major=False)
      my_decoder = sampling_decoder.BasicSamplingDecoder(
          cell=cell,
          sampler=sampler,
          initial_state=cell.zero_state(
              dtype=dtypes.float32, batch_size=batch_size))
      output_size = my_decoder.output_size
      output_dtype = my_decoder.output_dtype
      self.assertEqual(
          sampling_decoder.SamplingDecoderOutput(cell_depth,
                                                 tensor_shape.TensorShape([])),
          output_size)
      self.assertEqual(
          sampling_decoder.SamplingDecoderOutput(dtypes.float32, dtypes.int32),
          output_dtype)

      (first_finished, first_inputs, first_state) = my_decoder.initialize()
      (step_outputs, step_state, step_next_inputs,
       step_finished) = my_decoder.step(
           constant_op.constant(0), first_inputs, first_state)
      batch_size_t = my_decoder.batch_size

      self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple))
      self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple))
      self.assertTrue(
          isinstance(step_outputs, sampling_decoder.SamplingDecoderOutput))
      self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
      self.assertEqual((batch_size,), step_outputs[1].get_shape())
      self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
      self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
      self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
      self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())

      sess.run(variables.global_variables_initializer())
      sess_results = sess.run({
          "batch_size": batch_size_t,
          "first_finished": first_finished,
          "first_inputs": first_inputs,
          "first_state": first_state,
          "step_outputs": step_outputs,
          "step_state": step_state,
          "step_next_inputs": step_next_inputs,
          "step_finished": step_finished
      })

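      # sequence_length[4] == 0, so batch entry 4 is finished before the first
      # step; sequence_length[3] == 1, so entry 3 becomes finished after one step.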
      self.assertAllEqual([False, False, False, False, True],
                          sess_results["first_finished"])
      self.assertAllEqual([False, False, False, True, True],
                          sess_results["step_finished"])
      self.assertAllEqual(
          np.argmax(sess_results["step_outputs"].rnn_output, -1),
          sess_results["step_outputs"].sample_id)