def _testDynamicDecodeRNNWithBasicTrainingSamplerMatchesDynamicRNN(
        self, use_sequence_length):
    """Compare `dynamic_decode_rnn` with `dynamic_rnn` on identical inputs.

    A `BasicSamplingDecoder` fed by a `BasicTrainingSampler` and a plain
    `dynamic_rnn` are built around the same `LSTMCell` and the same zero
    initial state, inside one shared variable scope. Both are evaluated in
    a single `run` call and their results are compared.

    Args:
      use_sequence_length: If true, the decoder is run with
        `impute_finished=True`, `dynamic_rnn` receives `sequence_length`,
        and the final states are compared in addition to the outputs. If
        false, `impute_finished` is off, `dynamic_rnn` is called with
        `sequence_length=None`, and only the outputs are compared.
    """
    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    cell_depth = 10
    longest = max(sequence_length)
    # dynamic_rnn only sees the per-example lengths when requested.
    rnn_lengths = sequence_length if use_sequence_length else None

    with self.test_session() as session:
        features = np.random.randn(batch_size, max_time, input_depth)
        features = features.astype(np.float32)
        cell = core_rnn_cell.LSTMCell(cell_depth)
        initial_state = cell.zero_state(
            dtype=dtypes.float32, batch_size=batch_size)
        sampler = sampling_decoder.BasicTrainingSampler(
            features, sequence_length)
        training_decoder = sampling_decoder.BasicSamplingDecoder(
            cell=cell, sampler=sampler, initial_state=initial_state)

        # Build the decoder under the scope that dynamic_rnn re-enters
        # below, so that both graphs end up using the same variables.
        with vs.variable_scope("root") as scope:
            decoder_outputs, decoder_state = decoder.dynamic_decode_rnn(
                training_decoder,
                # impute_finished=True ensures outputs and final state
                # match those of dynamic_rnn called with a sequence_length.
                impute_finished=use_sequence_length,
                scope=scope)

        with vs.variable_scope(scope, reuse=True) as reused_scope:
            rnn_outputs, rnn_state = rnn.dynamic_rnn(
                cell,
                features,
                sequence_length=rnn_lengths,
                initial_state=initial_state,
                scope=reused_scope)

        session.run(variables.global_variables_initializer())
        fetches = {
            "final_decoder_outputs": decoder_outputs,
            "final_decoder_state": decoder_state,
            "final_rnn_outputs": rnn_outputs,
            "final_rnn_state": rnn_state,
        }
        results = session.run(fetches)

        # The decoder only runs out to the longest sequence, so only that
        # leading slice of the dynamic_rnn outputs is compared.
        expected_outputs = results["final_rnn_outputs"][:, :longest, :]
        self.assertAllClose(
            results["final_decoder_outputs"].rnn_output, expected_outputs)

        if use_sequence_length:
            self.assertAllClose(
                results["final_decoder_state"], results["final_rnn_state"])
def _testDynamicDecodeRNN(self, time_major):
    """Run `dynamic_decode_rnn` on a training sampler and check its shapes.

    Builds a `BasicSamplingDecoder` over an `LSTMCell`, decodes it with
    `decoder.dynamic_decode_rnn`, and verifies the Python types of the
    results, their static shapes, and the shapes of the evaluated arrays.

    Args:
      time_major: If true, the random inputs are laid out as
        `(max_time, batch_size, input_depth)`, the sampler is created with
        `time_major=True`, the decoder is asked for time-major outputs, and
        every expected shape has its first two dimensions swapped.
    """
    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    cell_depth = 10
    max_out = max(sequence_length)

    def _t(shape):
        """Swap the two leading dimensions of `shape` when time-major."""
        if time_major:
            return (shape[1], shape[0]) + shape[2:]
        return shape

    with self.test_session() as sess:
        # One expression covers both layouts: `_t` yields
        # (max_time, batch_size, input_depth) in the time-major case.
        input_shape = _t((batch_size, max_time, input_depth))
        inputs = np.random.randn(*input_shape).astype(np.float32)

        cell = core_rnn_cell.LSTMCell(cell_depth)
        sampler = sampling_decoder.BasicTrainingSampler(
            inputs, sequence_length, time_major=time_major)
        my_decoder = sampling_decoder.BasicSamplingDecoder(
            cell=cell,
            sampler=sampler,
            initial_state=cell.zero_state(
                dtype=dtypes.float32, batch_size=batch_size))

        final_outputs, final_state = decoder.dynamic_decode_rnn(
            my_decoder, output_time_major=time_major)

        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)) which only says "False is not true".
        self.assertIsInstance(
            final_outputs, sampling_decoder.SamplingDecoderOutput)
        self.assertIsInstance(final_state, core_rnn_cell.LSTMStateTuple)

        # Static shapes: the time dimension is expected to be unknown (None)
        # at graph-construction time.
        self.assertEqual(
            _t((batch_size, None, cell_depth)),
            tuple(final_outputs.rnn_output.get_shape().as_list()))
        self.assertEqual(
            _t((batch_size, None)),
            tuple(final_outputs.sample_id.get_shape().as_list()))

        sess.run(variables.global_variables_initializer())
        sess_results = sess.run({
            "final_outputs": final_outputs,
            "final_state": final_state,
        })

        # Evaluated shapes: the time dimension is expected to equal the
        # longest entry of sequence_length.
        self.assertEqual(
            _t((batch_size, max_out, cell_depth)),
            sess_results["final_outputs"].rnn_output.shape)
        self.assertEqual(
            _t((batch_size, max_out)),
            sess_results["final_outputs"].sample_id.shape)
def testStepWithBasicTrainingSampler(self):
    """Check one `initialize()` + `step()` of a `BasicSamplingDecoder`.

    The decoder wraps an `LSTMCell` and a batch-major
    `BasicTrainingSampler`. The test verifies the decoder's declared
    output size and dtype, the types and static shapes of the tensors
    returned by `initialize()` and by a single `step()` at time 0, and then
    evaluates them to check the `finished` flags and the sample ids.
    """
    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    cell_depth = 10

    with self.test_session() as sess:
        inputs = np.random.randn(
            batch_size, max_time, input_depth).astype(np.float32)
        cell = core_rnn_cell.LSTMCell(cell_depth)
        sampler = sampling_decoder.BasicTrainingSampler(
            inputs, sequence_length, time_major=False)
        my_decoder = sampling_decoder.BasicSamplingDecoder(
            cell=cell,
            sampler=sampler,
            initial_state=cell.zero_state(
                dtype=dtypes.float32, batch_size=batch_size))

        output_size = my_decoder.output_size
        output_dtype = my_decoder.output_dtype
        self.assertEqual(
            sampling_decoder.SamplingDecoderOutput(
                cell_depth, tensor_shape.TensorShape([])),
            output_size)
        self.assertEqual(
            sampling_decoder.SamplingDecoderOutput(
                dtypes.float32, dtypes.int32),
            output_dtype)

        first_finished, first_inputs, first_state = my_decoder.initialize()
        (step_outputs, step_state, step_next_inputs,
         step_finished) = my_decoder.step(
             constant_op.constant(0), first_inputs, first_state)
        batch_size_t = my_decoder.batch_size

        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)) which only says "False is not true".
        self.assertIsInstance(first_state, core_rnn_cell.LSTMStateTuple)
        self.assertIsInstance(step_state, core_rnn_cell.LSTMStateTuple)
        self.assertIsInstance(
            step_outputs, sampling_decoder.SamplingDecoderOutput)

        # Static shapes of the per-step outputs and of both state tuples.
        state_shape = (batch_size, cell_depth)
        self.assertEqual(state_shape, step_outputs[0].get_shape())
        self.assertEqual((batch_size,), step_outputs[1].get_shape())
        for state in (first_state, step_state):
            self.assertEqual(state_shape, state[0].get_shape())
            self.assertEqual(state_shape, state[1].get_shape())

        sess.run(variables.global_variables_initializer())
        sess_results = sess.run({
            "batch_size": batch_size_t,
            "first_finished": first_finished,
            "first_inputs": first_inputs,
            "first_state": first_state,
            "step_outputs": step_outputs,
            "step_state": step_state,
            "step_next_inputs": step_next_inputs,
            "step_finished": step_finished,
        })

        # With sequence_length = [3, 4, 3, 1, 0], only the zero-length entry
        # is expected to be finished after initialize(); after the step at
        # time 0 the length-1 entry is expected to be finished as well.
        self.assertAllEqual(
            [False, False, False, False, True],
            sess_results["first_finished"])
        self.assertAllEqual(
            [False, False, False, True, True],
            sess_results["step_finished"])

        # The sample ids must equal the argmax of the RNN output over its
        # last axis.
        self.assertAllEqual(
            np.argmax(sess_results["step_outputs"].rnn_output, -1),
            sess_results["step_outputs"].sample_id)