def create_decoder(self, encoded, inputs, speaker_embed, train=True):
    """Assemble the attention decoder that runs over the encoder output.

    A three-layer GRU stack is wrapped with a residual connection and with
    input/output projections, then placed behind Bahdanau attention over
    ``encoded``. The helper that feeds the decoder depends on the mode:
    scheduled sampling or plain teacher forcing while training, and the
    project's ``ops.InferenceHelper`` otherwise.

    Args:
        encoded: Encoder output used as the attention memory.
        inputs: Mapping read for the keys ``'text'``, ``'text_length'``,
            ``'mel'`` and ``'speech_length'``.
        speaker_embed: Accepted but not used by this method at present.
        train: When true a training helper is built, otherwise the
            inference helper.

    Returns:
        A ``basic_decoder.BasicDecoder`` wired to the attention cell, the
        selected helper and a zero initial state.
    """
    cfg = self.config
    frame_group_size = cfg.mel_features * cfg.r

    attention = wrapper.BahdanauAttention(
        cfg.attention_units,
        encoded,
        memory_sequence_length=inputs['text_length'],
    )

    gru_stack = MultiRNNCell([GRUCell(cfg.decoder_units) for _ in range(3)])
    projected_cell = OutputProjectionWrapper(
        InputProjectionWrapper(ResidualWrapper(gru_stack), cfg.decoder_units),
        frame_group_size,
    )

    def decoder_frame_input(prev_output, attention_context):
        # feed in rth frame at each time step: keep only the columns from
        # offset (r - 1) * mel_features to the end of the previous output.
        last_frame = tf.slice(
            prev_output, [0, (cfg.r - 1) * cfg.mel_features], [-1, -1])
        processed = self.pre_net(
            last_frame, dropout=cfg.audio_dropout_prob, train=train)
        return tf.concat([processed, attention_context], -1)

    cell = wrapper.AttentionWrapper(
        projected_cell,
        attention,
        attention_layer_size=cfg.attention_units,
        cell_input_fn=decoder_frame_input,
        alignment_history=True,
        output_attention=False,
    )

    if not train:
        decoder_helper = ops.InferenceHelper(
            tf.shape(inputs['text'])[0], frame_group_size)
    elif cfg.scheduled_sample:
        print("if train if config.scheduled_sample: %s" % str(
            (inputs['mel'], inputs['speech_length'], cfg.scheduled_sample)))
        decoder_helper = helper.ScheduledOutputTrainingHelper(
            inputs['mel'], inputs['speech_length'], cfg.scheduled_sample)
    else:
        decoder_helper = helper.TrainingHelper(
            inputs['mel'], inputs['speech_length'])

    initial_state = cell.zero_state(
        dtype=tf.float32, batch_size=tf.shape(inputs['text'])[0])
    # NOTE: seeding the initial attention from a dense projection of
    # speaker_embed was left disabled here, so speaker_embed has no effect.

    return basic_decoder.BasicDecoder(cell, decoder_helper, initial_state)
def _testStepWithScheduledOutputTrainingHelper(
        self, sampling_probability, use_next_inputs_fn, use_auxiliary_inputs):
    """Run one BasicDecoder step driven by ScheduledOutputTrainingHelper.

    Builds an LSTM decoder over random inputs, performs `initialize()` and a
    single `step()`, and checks static shapes, dtypes, finished flags and
    the next inputs produced for both the sampled and non-sampled rows.

    Args:
        sampling_probability: Value wrapped in a constant and handed to the
            helper as its sampling probability.
        use_next_inputs_fn: If true, a deterministic argmax/one-hot function
            is passed as `next_inputs_fn`; otherwise None is passed.
        use_auxiliary_inputs: If true, random auxiliary inputs of depth 4
            are passed to the helper; otherwise None is passed.
    """
    sequence_length = [3, 4, 3, 1, 0]
    batch_size, max_time, input_depth = 5, 8, 7
    cell_depth = input_depth

    # Auxiliary inputs are drawn before the main inputs so the order of
    # random draws stays fixed.
    aux_inputs = None
    if use_auxiliary_inputs:
        aux_depth = 4
        aux_inputs = np.random.randn(
            batch_size, max_time, aux_depth).astype(np.float32)

    with self.session(use_gpu=True) as sess:
        inputs = np.random.randn(
            batch_size, max_time, input_depth).astype(np.float32)
        cell = rnn_cell.LSTMCell(cell_depth)
        sampling_probability_t = constant_op.constant(sampling_probability)

        def argmax_one_hot(outputs):
            # Use deterministic function for test.
            winners = math_ops.argmax(outputs, axis=1)
            return array_ops.one_hot(
                winners, cell_depth, dtype=dtypes.float32)

        next_fn = argmax_one_hot if use_next_inputs_fn else None

        scheduled_helper = helper_py.ScheduledOutputTrainingHelper(
            inputs=inputs,
            sequence_length=sequence_length,
            sampling_probability=sampling_probability_t,
            time_major=False,
            next_inputs_fn=next_fn,
            auxiliary_inputs=aux_inputs)

        zero_state = cell.zero_state(
            dtype=dtypes.float32, batch_size=batch_size)
        my_decoder = basic_decoder.BasicDecoder(
            cell=cell, helper=scheduled_helper, initial_state=zero_state)

        output_size = my_decoder.output_size
        output_dtype = my_decoder.output_dtype
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(
                cell_depth, tensor_shape.TensorShape([])),
            output_size)
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
            output_dtype)

        first_finished, first_inputs, first_state = my_decoder.initialize()
        step_result = my_decoder.step(
            constant_op.constant(0), first_inputs, first_state)
        step_outputs, step_state, step_next_inputs, step_finished = step_result

        if use_next_inputs_fn:
            transformed_output = next_fn(step_outputs.rnn_output)

        batch_size_t = my_decoder.batch_size

        # Static type and shape checks.
        self.assertIsInstance(first_state, rnn_cell.LSTMStateTuple)
        self.assertIsInstance(step_state, rnn_cell.LSTMStateTuple)
        self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
        state_shape = (batch_size, cell_depth)
        self.assertEqual(state_shape, step_outputs[0].get_shape())
        self.assertEqual((batch_size,), step_outputs[1].get_shape())
        self.assertEqual(state_shape, first_state[0].get_shape())
        self.assertEqual(state_shape, first_state[1].get_shape())
        self.assertEqual(state_shape, step_state[0].get_shape())
        self.assertEqual(state_shape, step_state[1].get_shape())

        sess.run(variables.global_variables_initializer())

        fetches = {
            "batch_size": batch_size_t,
            "first_finished": first_finished,
            "first_inputs": first_inputs,
            "first_state": first_state,
            "step_outputs": step_outputs,
            "step_state": step_state,
            "step_next_inputs": step_next_inputs,
            "step_finished": step_finished,
        }
        if use_next_inputs_fn:
            fetches["output_after_next_inputs_fn"] = transformed_output

        results = sess.run(fetches)

        self.assertAllEqual([False, False, False, False, True],
                            results["first_finished"])
        self.assertAllEqual([False, False, False, True, True],
                            results["step_finished"])

        sample_ids = results["step_outputs"].sample_id
        self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
        not_sampled_rows = np.where(np.logical_not(sample_ids))
        sampled_rows = np.where(sample_ids)

        # With no auxiliary inputs, concatenate a zero-width block so both
        # configurations go through the same code path.
        if use_auxiliary_inputs:
            aux_at_step_one = aux_inputs[:, 1]
        else:
            aux_at_step_one = np.zeros((batch_size, 0), dtype=np.float32)

        if use_next_inputs_fn:
            sampled_source = results["output_after_next_inputs_fn"]
        else:
            sampled_source = results["step_outputs"].rnn_output
        expected_sampled = np.concatenate(
            (sampled_source[sampled_rows], aux_at_step_one[sampled_rows]),
            axis=-1)
        self.assertAllClose(
            results["step_next_inputs"][sampled_rows], expected_sampled)

        # Indexing with the 1-tuple returned by np.where adds a leading
        # singleton axis, which is squeezed away again here.
        ground_truth = np.squeeze(inputs[not_sampled_rows, 1], axis=0)
        expected_not_sampled = np.concatenate(
            (ground_truth, aux_at_step_one[not_sampled_rows]), axis=-1)
        self.assertAllClose(
            results["step_next_inputs"][not_sampled_rows],
            expected_not_sampled)
def _testStepWithScheduledOutputTrainingHelper(self, use_next_input_layer):
    """Run one BasicDecoder step driven by ScheduledOutputTrainingHelper.

    Builds an LSTM decoder over random inputs with a sampling probability
    of 0.5, performs `initialize()` and a single `step()`, and checks
    static shapes, dtypes, finished flags and the next inputs produced for
    both the sampled and non-sampled rows.

    Args:
        use_next_input_layer: If true, the cell depth is set to 6 and a
            bias-free Dense layer mapping back to the input depth is passed
            to the helper as `next_input_layer`; otherwise no layer is used
            and the cell depth equals the input depth.
    """
    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    cell_depth = input_depth
    if use_next_input_layer:
        # Use a cell depth that differs from the input depth so the dense
        # layer has to project the output back to `input_depth`.
        cell_depth = 6

    with self.test_session() as sess:
        inputs = np.random.randn(
            batch_size, max_time, input_depth).astype(np.float32)
        cell = core_rnn_cell.LSTMCell(cell_depth)
        half = constant_op.constant(0.5)

        next_input_layer = None
        if use_next_input_layer:
            next_input_layer = layers_core.Dense(input_depth, use_bias=False)

        helper = helper_py.ScheduledOutputTrainingHelper(
            inputs=inputs,
            sequence_length=sequence_length,
            sampling_probability=half,
            time_major=False,
            next_input_layer=next_input_layer)

        my_decoder = basic_decoder.BasicDecoder(
            cell=cell,
            helper=helper,
            initial_state=cell.zero_state(
                dtype=dtypes.float32, batch_size=batch_size))

        output_size = my_decoder.output_size
        output_dtype = my_decoder.output_dtype
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(
                cell_depth, tensor_shape.TensorShape([])),
            output_size)
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
            output_dtype)

        (first_finished, first_inputs, first_state) = my_decoder.initialize()
        (step_outputs, step_state, step_next_inputs,
         step_finished) = my_decoder.step(
             constant_op.constant(0), first_inputs, first_state)

        if use_next_input_layer:
            output_after_next_input_layer = next_input_layer(
                step_outputs.rnn_output)

        batch_size_t = my_decoder.batch_size

        # Static type and shape checks.
        self.assertTrue(
            isinstance(first_state, core_rnn_cell.LSTMStateTuple))
        self.assertTrue(
            isinstance(step_state, core_rnn_cell.LSTMStateTuple))
        self.assertTrue(
            isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
        self.assertEqual((batch_size, cell_depth),
                         step_outputs[0].get_shape())
        self.assertEqual((batch_size,), step_outputs[1].get_shape())
        self.assertEqual((batch_size, cell_depth),
                         first_state[0].get_shape())
        self.assertEqual((batch_size, cell_depth),
                         first_state[1].get_shape())
        self.assertEqual((batch_size, cell_depth),
                         step_state[0].get_shape())
        self.assertEqual((batch_size, cell_depth),
                         step_state[1].get_shape())

        sess.run(variables.global_variables_initializer())

        fetches = {
            "batch_size": batch_size_t,
            "first_finished": first_finished,
            "first_inputs": first_inputs,
            "first_state": first_state,
            "step_outputs": step_outputs,
            "step_state": step_state,
            "step_next_inputs": step_next_inputs,
            "step_finished": step_finished
        }
        if use_next_input_layer:
            fetches["output_after_next_input_layer"] = (
                output_after_next_input_layer)

        sess_results = sess.run(fetches)

        self.assertAllEqual([False, False, False, False, True],
                            sess_results["first_finished"])
        self.assertAllEqual([False, False, False, True, True],
                            sess_results["step_finished"])

        sample_ids = sess_results["step_outputs"].sample_id
        batch_where_not_sampling = np.where(np.logical_not(sample_ids))
        batch_where_sampling = np.where(sample_ids)

        if use_next_input_layer:
            self.assertAllClose(
                sess_results["step_next_inputs"][batch_where_sampling],
                sess_results["output_after_next_input_layer"]
                [batch_where_sampling])
        else:
            self.assertAllClose(
                sess_results["step_next_inputs"][batch_where_sampling],
                sess_results["step_outputs"].rnn_output[batch_where_sampling])

        # `np.where` returns a 1-tuple, so indexing with it yields shape
        # (1, k, input_depth) where k is the number of non-sampled rows.
        # The singleton is axis 0; squeezing axis 1 raises ValueError
        # whenever k != 1, and k varies from run to run.
        self.assertAllClose(
            sess_results["step_next_inputs"][batch_where_not_sampling],
            np.squeeze(inputs[batch_where_not_sampling, 1], axis=0))