def test_step_with_scheduled_output_training_helper(
    sampling_probability, use_next_inputs_fn, use_auxiliary_inputs
):
    """Run one `BasicDecoder` step with a `ScheduledOutputTrainingSampler`.

    The test builds random inputs of shape (batch_size, max_time, input_depth),
    initializes the decoder, performs step 0 and then checks output
    sizes/dtypes, tensor shapes, the finished flags, and the next inputs
    produced by the step.

    Args:
        sampling_probability: Value wrapped in `tf.constant` and handed to the
            sampler as its `sampling_probability`.
        use_next_inputs_fn: When truthy, an argmax/one-hot function is passed
            to the sampler as `next_inputs_fn`; otherwise `None` is passed.
        use_auxiliary_inputs: When truthy, random auxiliary inputs of depth 4
            are passed to `initialize`; otherwise `None` is passed.
    """
    lengths = [3, 4, 3, 1, 0]
    batch_size, max_time, input_depth = 5, 8, 7
    cell_depth = input_depth

    auxiliary_inputs = None
    if use_auxiliary_inputs:
        aux_depth = 4
        auxiliary_inputs = np.random.randn(batch_size, max_time, aux_depth).astype(
            np.float32
        )

    features = np.random.randn(batch_size, max_time, input_depth).astype(np.float32)
    features_t = tf.constant(features)
    lstm_cell = tf.keras.layers.LSTMCell(cell_depth)
    probability_t = tf.constant(sampling_probability)

    def _argmax_one_hot(outputs):
        # Deterministic mapping so the expected next inputs can be recomputed.
        best = tf.argmax(outputs, axis=1)
        return tf.one_hot(best, cell_depth, dtype=tf.float32)

    next_inputs_fn = _argmax_one_hot if use_next_inputs_fn else None

    sampler = sampler_py.ScheduledOutputTrainingSampler(
        sampling_probability=probability_t,
        time_major=False,
        next_inputs_fn=next_inputs_fn,
    )
    zero_state = lstm_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
    decoder = basic_decoder.BasicDecoder(cell=lstm_cell, sampler=sampler)
    first_finished, first_inputs, first_state = decoder.initialize(
        features_t,
        sequence_length=lengths,
        initial_state=zero_state,
        auxiliary_inputs=auxiliary_inputs,
    )

    output_size = decoder.output_size
    output_dtype = decoder.output_dtype
    expected_size = basic_decoder.BasicDecoderOutput(cell_depth, tf.TensorShape([]))
    assert expected_size == output_size
    expected_dtype = basic_decoder.BasicDecoderOutput(tf.float32, tf.int32)
    assert expected_dtype == output_dtype

    step_result = decoder.step(tf.constant(0), first_inputs, first_state)
    step_outputs, step_state, step_next_inputs, step_finished = step_result

    if use_next_inputs_fn:
        transformed_output = next_inputs_fn(step_outputs.rnn_output)

    batch_size_t = decoder.batch_size

    assert len(first_state) == 2
    assert len(step_state) == 2
    assert isinstance(step_outputs, basic_decoder.BasicDecoderOutput)

    # Expected static shape paired with the tensor that must carry it.
    state_shape = (batch_size, cell_depth)
    shape_checks = (
        (state_shape, step_outputs[0]),
        ((batch_size,), step_outputs[1]),
        (state_shape, first_state[0]),
        (state_shape, first_state[1]),
        (state_shape, step_state[0]),
        (state_shape, step_state[1]),
    )
    for expected_shape, tensor in shape_checks:
        assert expected_shape == tensor.shape

    results = {
        "batch_size": batch_size_t.numpy(),
        "first_finished": first_finished.numpy(),
        "first_inputs": first_inputs.numpy(),
        "first_state": np.asanyarray(first_state),
        "step_outputs": step_outputs,
        "step_state": np.asanyarray(step_state),
        "step_next_inputs": step_next_inputs.numpy(),
        "step_finished": step_finished.numpy(),
    }
    if use_next_inputs_fn:
        results["output_after_next_inputs_fn"] = transformed_output

    # With lengths [3, 4, 3, 1, 0]: only the zero-length row is expected to be
    # finished after initialize; the length-1 row joins it after step 0.
    np.testing.assert_equal(
        np.asanyarray([False, False, False, False, True]),
        results["first_finished"],
    )
    np.testing.assert_equal(
        np.asanyarray([False, False, False, True, True]),
        results["step_finished"],
    )

    sample_ids = results["step_outputs"].sample_id.numpy()
    assert output_dtype.sample_id == sample_ids.dtype

    # Rows with a zero sample_id are checked against the ground-truth inputs;
    # rows with a non-zero sample_id are checked against the cell output.
    teacher_rows = np.where(np.logical_not(sample_ids))
    sampled_rows = np.where(sample_ids)

    # Auxiliary features for time index 1, or an empty (batch_size, 0) block so
    # the concatenations below work in both configurations.
    if use_auxiliary_inputs:
        aux_next = auxiliary_inputs[:, 1]
    else:
        aux_next = np.array([]).reshape(batch_size, 0).astype(np.float32)

    if use_next_inputs_fn:
        sampled_source = results["output_after_next_inputs_fn"].numpy()
    else:
        sampled_source = results["step_outputs"].rnn_output.numpy()

    expected_sampled = np.concatenate(
        (sampled_source[sampled_rows], aux_next[sampled_rows]),
        axis=-1,
    )
    np.testing.assert_equal(
        results["step_next_inputs"][sampled_rows],
        expected_sampled,
    )

    # Indexing with the np.where tuple adds a leading axis of size 1, which the
    # squeeze removes again.
    expected_teacher = np.concatenate(
        (
            np.squeeze(features[teacher_rows, 1], axis=0),
            aux_next[teacher_rows],
        ),
        axis=-1,
    )
    np.testing.assert_equal(
        results["step_next_inputs"][teacher_rows],
        expected_teacher,
    )
def _testStepWithScheduledOutputTrainingHelper(
    self, sampling_probability, use_next_inputs_fn, use_auxiliary_inputs
):
    """Run one `BasicDecoder` step with a `ScheduledOutputTrainingSampler`.

    Session-based variant: tensors are built inside `self.cached_session`,
    evaluated through `self.evaluate`, and the evaluated values are compared
    with expectations recomputed in NumPy.

    Args:
        sampling_probability: Value wrapped in `tf.constant` and handed to the
            sampler as its `sampling_probability`.
        use_next_inputs_fn: When truthy, an argmax/one-hot function is passed
            to the sampler as `next_inputs_fn`; otherwise `None` is passed.
        use_auxiliary_inputs: When truthy, random auxiliary inputs of depth 4
            are passed to `initialize`; otherwise `None` is passed.
    """
    lengths = [3, 4, 3, 1, 0]
    batch_size, max_time, input_depth = 5, 8, 7
    cell_depth = input_depth

    auxiliary_inputs = None
    if use_auxiliary_inputs:
        aux_depth = 4
        auxiliary_inputs = np.random.randn(batch_size, max_time, aux_depth).astype(
            np.float32
        )

    # NOTE(review): the source was whitespace-collapsed; the rest of the body
    # is assumed to sit inside this session block -- confirm.
    with self.cached_session(use_gpu=True):
        features = np.random.randn(batch_size, max_time, input_depth).astype(
            np.float32
        )
        features_t = tf.constant(features)
        lstm_cell = tf.keras.layers.LSTMCell(cell_depth)
        probability_t = tf.constant(sampling_probability)

        def _argmax_one_hot(outputs):
            # Deterministic mapping so the expected next inputs can be
            # recomputed from the evaluated outputs.
            best = tf.argmax(outputs, axis=1)
            return tf.one_hot(best, cell_depth, dtype=tf.float32)

        next_inputs_fn = _argmax_one_hot if use_next_inputs_fn else None

        sampler = sampler_py.ScheduledOutputTrainingSampler(
            sampling_probability=probability_t,
            time_major=False,
            next_inputs_fn=next_inputs_fn,
        )
        zero_state = lstm_cell.get_initial_state(
            batch_size=batch_size, dtype=tf.float32
        )
        decoder = basic_decoder.BasicDecoder(cell=lstm_cell, sampler=sampler)
        first_finished, first_inputs, first_state = decoder.initialize(
            features_t,
            sequence_length=lengths,
            initial_state=zero_state,
            auxiliary_inputs=auxiliary_inputs,
        )

        output_size = decoder.output_size
        output_dtype = decoder.output_dtype
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(cell_depth, tf.TensorShape([])),
            output_size,
        )
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(tf.float32, tf.int32),
            output_dtype,
        )

        step_result = decoder.step(tf.constant(0), first_inputs, first_state)
        step_outputs, step_state, step_next_inputs, step_finished = step_result

        if use_next_inputs_fn:
            transformed_output = next_inputs_fn(step_outputs.rnn_output)

        batch_size_t = decoder.batch_size

        self.assertLen(first_state, 2)
        self.assertLen(step_state, 2)
        self.assertTrue(isinstance(step_outputs, basic_decoder.BasicDecoderOutput))

        # Expected static shape paired with the tensor that must carry it.
        state_shape = (batch_size, cell_depth)
        shape_checks = (
            (state_shape, step_outputs[0]),
            ((batch_size,), step_outputs[1]),
            (state_shape, first_state[0]),
            (state_shape, first_state[1]),
            (state_shape, step_state[0]),
            (state_shape, step_state[1]),
        )
        for expected_shape, tensor in shape_checks:
            self.assertEqual(expected_shape, tensor.get_shape())

        self.evaluate(tf.compat.v1.global_variables_initializer())

        fetches = {
            "batch_size": batch_size_t,
            "first_finished": first_finished,
            "first_inputs": first_inputs,
            "first_state": first_state,
            "step_outputs": step_outputs,
            "step_state": step_state,
            "step_next_inputs": step_next_inputs,
            "step_finished": step_finished,
        }
        if use_next_inputs_fn:
            fetches["output_after_next_inputs_fn"] = transformed_output

        results = self.evaluate(fetches)

        # With lengths [3, 4, 3, 1, 0]: only the zero-length row is expected
        # to be finished after initialize; the length-1 row joins it after
        # step 0.
        self.assertAllEqual(
            [False, False, False, False, True], results["first_finished"]
        )
        self.assertAllEqual(
            [False, False, False, True, True], results["step_finished"]
        )

        sample_ids = results["step_outputs"].sample_id
        self.assertEqual(output_dtype.sample_id, sample_ids.dtype)

        # Rows with a zero sample_id are checked against the ground-truth
        # inputs; rows with a non-zero sample_id against the cell output.
        teacher_rows = np.where(np.logical_not(sample_ids))
        sampled_rows = np.where(sample_ids)

        # Auxiliary features for time index 1, or an empty (batch_size, 0)
        # block so the concatenations below work in both configurations.
        if use_auxiliary_inputs:
            aux_next = auxiliary_inputs[:, 1]
        else:
            aux_next = np.array([]).reshape(batch_size, 0).astype(np.float32)

        if use_next_inputs_fn:
            sampled_source = results["output_after_next_inputs_fn"]
        else:
            sampled_source = results["step_outputs"].rnn_output

        expected_sampled = np.concatenate(
            (sampled_source[sampled_rows], aux_next[sampled_rows]),
            axis=-1,
        )
        self.assertAllClose(
            results["step_next_inputs"][sampled_rows],
            expected_sampled,
        )

        # Indexing with the np.where tuple adds a leading axis of size 1,
        # which the squeeze removes again.
        expected_teacher = np.concatenate(
            (
                np.squeeze(features[teacher_rows, 1], axis=0),
                aux_next[teacher_rows],
            ),
            axis=-1,
        )
        self.assertAllClose(
            results["step_next_inputs"][teacher_rows],
            expected_teacher,
        )