# Assumed TF 1.x module aliases for this snippet (not shown in the original):
# wrapper = tensorflow.contrib.seq2seq.python.ops.attention_wrapper,
# helper and basic_decoder from the same package; GRUCell, MultiRNNCell,
# ResidualWrapper, InputProjectionWrapper, and OutputProjectionWrapper from
# tf.contrib.rnn. `ops.InferenceHelper` and `self.pre_net` are
# project-specific helpers.
def create_decoder(self, encoded, inputs, speaker_embed, train=True):
        config = self.config
        # Additive (Bahdanau) attention over the encoded text.
        attention_mech = wrapper.BahdanauAttention(
            config.attention_units,
            encoded,
            memory_sequence_length=inputs['text_length'])

        # Three-layer GRU stack with residual connections; project inputs to
        # decoder_units and project outputs to r mel frames per step.
        inner_cell = [GRUCell(config.decoder_units) for _ in range(3)]

        decoder_cell = OutputProjectionWrapper(
            InputProjectionWrapper(ResidualWrapper(MultiRNNCell(inner_cell)),
                                   config.decoder_units),
            config.mel_features * config.r)

        # Of the r frames predicted at the previous step, feed only the last
        # one through the pre-net, then concatenate it with the attention
        # context to form the cell input.
        def decoder_frame_input(inputs, attention):
            last_frame = tf.slice(
                inputs, [0, (config.r - 1) * config.mel_features], [-1, -1])
            prenet_out = self.pre_net(last_frame,
                                      dropout=config.audio_dropout_prob,
                                      train=train)
            return tf.concat([prenet_out, attention], -1)

        # Wrap the decoder cell with attention; keep the alignment history
        # (useful for attention plots) and emit the cell output rather than
        # the attention vector.
        cell = wrapper.AttentionWrapper(
            decoder_cell,
            attention_mech,
            attention_layer_size=config.attention_units,
            cell_input_fn=decoder_frame_input,
            alignment_history=True,
            output_attention=False)

        if train:
            if config.scheduled_sample:
                # Scheduled sampling: with probability config.scheduled_sample,
                # feed the decoder's own previous output back in place of the
                # ground-truth frame.
                decoder_helper = helper.ScheduledOutputTrainingHelper(
                    inputs['mel'], inputs['speech_length'],
                    config.scheduled_sample)
            else:
                decoder_helper = helper.TrainingHelper(inputs['mel'],
                                                       inputs['speech_length'])
        else:
            decoder_helper = ops.InferenceHelper(
                tf.shape(inputs['text'])[0], config.mel_features * config.r)

        initial_state = cell.zero_state(dtype=tf.float32,
                                        batch_size=tf.shape(inputs['text'])[0])

        # Optionally seed the attention state from the speaker embedding.
        # AttentionWrapperState is a namedtuple, so use clone() rather than
        # attribute assignment:
        # if speaker_embed is not None:
        #     initial_state = initial_state.clone(attention=tf.layers.dense(
        #         speaker_embed, config.attention_units))

        dec = basic_decoder.BasicDecoder(cell, decoder_helper, initial_state)

        return dec
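
# A hypothetical usage outline (names like `model` and
# `config.max_decoder_steps` are assumptions, not from the original code):
# the returned BasicDecoder would typically be driven with
# tf.contrib.seq2seq.dynamic_decode, e.g.
#
#   dec = model.create_decoder(encoded, inputs, speaker_embed=None, train=True)
#   outputs, final_state, lengths = tf.contrib.seq2seq.dynamic_decode(
#       dec, maximum_iterations=config.max_decoder_steps)
#   mel_out = outputs.rnn_output  # [batch, time, mel_features * r]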
Example #2
  # Assumes TensorFlow's internal test-style imports: helper_py =
  # tensorflow.contrib.seq2seq.python.ops.helper, basic_decoder from the same
  # package, and rnn_cell, constant_op, dtypes, tensor_shape, math_ops,
  # array_ops, variables from tensorflow.python; np = numpy.
  def _testStepWithScheduledOutputTrainingHelper(
      self, sampling_probability, use_next_inputs_fn, use_auxiliary_inputs):
    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    cell_depth = input_depth
    if use_auxiliary_inputs:
      auxiliary_input_depth = 4
      auxiliary_inputs = np.random.randn(
          batch_size, max_time, auxiliary_input_depth).astype(np.float32)
    else:
      auxiliary_inputs = None

    with self.session(use_gpu=True) as sess:
      inputs = np.random.randn(batch_size, max_time,
                               input_depth).astype(np.float32)
      cell = rnn_cell.LSTMCell(cell_depth)
      sampling_probability = constant_op.constant(sampling_probability)

      if use_next_inputs_fn:
        def next_inputs_fn(outputs):
          # Use deterministic function for test.
          samples = math_ops.argmax(outputs, axis=1)
          return array_ops.one_hot(samples, cell_depth, dtype=dtypes.float32)
      else:
        next_inputs_fn = None

      helper = helper_py.ScheduledOutputTrainingHelper(
          inputs=inputs,
          sequence_length=sequence_length,
          sampling_probability=sampling_probability,
          time_major=False,
          next_inputs_fn=next_inputs_fn,
          auxiliary_inputs=auxiliary_inputs)

      my_decoder = basic_decoder.BasicDecoder(
          cell=cell,
          helper=helper,
          initial_state=cell.zero_state(
              dtype=dtypes.float32, batch_size=batch_size))

      output_size = my_decoder.output_size
      output_dtype = my_decoder.output_dtype
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(cell_depth,
                                           tensor_shape.TensorShape([])),
          output_size)
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
          output_dtype)

      (first_finished, first_inputs, first_state) = my_decoder.initialize()
      (step_outputs, step_state, step_next_inputs,
       step_finished) = my_decoder.step(
           constant_op.constant(0), first_inputs, first_state)

      if use_next_inputs_fn:
        output_after_next_inputs_fn = next_inputs_fn(step_outputs.rnn_output)

      batch_size_t = my_decoder.batch_size

      self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
      self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
      self.assertTrue(
          isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
      self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
      self.assertEqual((batch_size,), step_outputs[1].get_shape())
      self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
      self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
      self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
      self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())

      sess.run(variables.global_variables_initializer())

      fetches = {
          "batch_size": batch_size_t,
          "first_finished": first_finished,
          "first_inputs": first_inputs,
          "first_state": first_state,
          "step_outputs": step_outputs,
          "step_state": step_state,
          "step_next_inputs": step_next_inputs,
          "step_finished": step_finished
      }
      if use_next_inputs_fn:
        fetches["output_after_next_inputs_fn"] = output_after_next_inputs_fn

      sess_results = sess.run(fetches)

      self.assertAllEqual([False, False, False, False, True],
                          sess_results["first_finished"])
      self.assertAllEqual([False, False, False, True, True],
                          sess_results["step_finished"])

      sample_ids = sess_results["step_outputs"].sample_id
      self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
      batch_where_not_sampling = np.where(np.logical_not(sample_ids))
      batch_where_sampling = np.where(sample_ids)

      auxiliary_inputs_to_concat = (
          auxiliary_inputs[:, 1] if use_auxiliary_inputs else
          np.array([]).reshape(batch_size, 0).astype(np.float32))

      expected_next_sampling_inputs = np.concatenate(
          (sess_results["output_after_next_inputs_fn"][batch_where_sampling]
           if use_next_inputs_fn else
           sess_results["step_outputs"].rnn_output[batch_where_sampling],
           auxiliary_inputs_to_concat[batch_where_sampling]),
          axis=-1)
      self.assertAllClose(
          sess_results["step_next_inputs"][batch_where_sampling],
          expected_next_sampling_inputs)

      self.assertAllClose(
          sess_results["step_next_inputs"][batch_where_not_sampling],
          np.concatenate(
              (np.squeeze(inputs[batch_where_not_sampling, 1], axis=0),
               auxiliary_inputs_to_concat[batch_where_not_sampling]),
              axis=-1))
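
# Note on the feedback rule these assertions verify: for batch rows where
# sample_id is 1, the next input is the cell output (passed through
# next_inputs_fn when provided) concatenated with the next step's auxiliary
# inputs; for rows where sample_id is 0, it is the ground-truth input at the
# next time step with the same auxiliary inputs appended.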
Example #3
    # An older variant of the same test, written against the earlier contrib
    # API in which ScheduledOutputTrainingHelper took a `next_input_layer`
    # argument (later replaced by `next_inputs_fn`/`auxiliary_inputs`); it
    # also uses core_rnn_cell and layers_core.
    def _testStepWithScheduledOutputTrainingHelper(self, use_next_input_layer):
        sequence_length = [3, 4, 3, 1, 0]
        batch_size = 5
        max_time = 8
        input_depth = 7
        cell_depth = input_depth
        if use_next_input_layer:
            cell_depth = 6

        with self.test_session() as sess:
            inputs = np.random.randn(batch_size, max_time,
                                     input_depth).astype(np.float32)
            cell = core_rnn_cell.LSTMCell(cell_depth)
            half = constant_op.constant(0.5)

            next_input_layer = None
            if use_next_input_layer:
                next_input_layer = layers_core.Dense(input_depth,
                                                     use_bias=False)

            helper = helper_py.ScheduledOutputTrainingHelper(
                inputs=inputs,
                sequence_length=sequence_length,
                sampling_probability=half,
                time_major=False,
                next_input_layer=next_input_layer)

            my_decoder = basic_decoder.BasicDecoder(
                cell=cell,
                helper=helper,
                initial_state=cell.zero_state(dtype=dtypes.float32,
                                              batch_size=batch_size))

            output_size = my_decoder.output_size
            output_dtype = my_decoder.output_dtype
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(cell_depth,
                                                 tensor_shape.TensorShape([])),
                output_size)
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
                output_dtype)

            (first_finished, first_inputs,
             first_state) = my_decoder.initialize()
            (step_outputs, step_state, step_next_inputs,
             step_finished) = my_decoder.step(constant_op.constant(0),
                                              first_inputs, first_state)

            if use_next_input_layer:
                output_after_next_input_layer = next_input_layer(
                    step_outputs.rnn_output)

            batch_size_t = my_decoder.batch_size

            self.assertTrue(
                isinstance(first_state, core_rnn_cell.LSTMStateTuple))
            self.assertTrue(
                isinstance(step_state, core_rnn_cell.LSTMStateTuple))
            self.assertTrue(
                isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
            self.assertEqual((batch_size, cell_depth),
                             step_outputs[0].get_shape())
            self.assertEqual((batch_size, ), step_outputs[1].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             first_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             first_state[1].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             step_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             step_state[1].get_shape())

            sess.run(variables.global_variables_initializer())

            fetches = {
                "batch_size": batch_size_t,
                "first_finished": first_finished,
                "first_inputs": first_inputs,
                "first_state": first_state,
                "step_outputs": step_outputs,
                "step_state": step_state,
                "step_next_inputs": step_next_inputs,
                "step_finished": step_finished
            }
            if use_next_input_layer:
                fetches["output_after_next_input_layer"] = (
                    output_after_next_input_layer)

            sess_results = sess.run(fetches)

            self.assertAllEqual([False, False, False, False, True],
                                sess_results["first_finished"])
            self.assertAllEqual([False, False, False, True, True],
                                sess_results["step_finished"])

            sample_ids = sess_results["step_outputs"].sample_id
            batch_where_not_sampling = np.where(np.logical_not(sample_ids))
            batch_where_sampling = np.where(sample_ids)
            if use_next_input_layer:
                self.assertAllClose(
                    sess_results["step_next_inputs"][batch_where_sampling],
                    sess_results["output_after_next_input_layer"]
                    [batch_where_sampling])
            else:
                self.assertAllClose(
                    sess_results["step_next_inputs"][batch_where_sampling],
                    sess_results["step_outputs"].
                    rnn_output[batch_where_sampling])
            # The np.where tuple index adds a leading singleton axis, so
            # squeeze axis 0 to match step_next_inputs' shape.
            self.assertAllClose(
                sess_results["step_next_inputs"][batch_where_not_sampling],
                np.squeeze(inputs[batch_where_not_sampling, 1], axis=0))
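
# A self-contained usage sketch (illustrative assumption, not part of the
# test suite): ScheduledOutputTrainingHelper with the TF 1.x contrib API.
# With sampling_probability=0.0 this reduces to plain teacher forcing; with
# 1.0 the previous cell output is always fed back, so when no next_inputs_fn
# is given the cell's output size must match the input depth.
import numpy as np
import tensorflow as tf
from tensorflow.contrib.seq2seq import (BasicDecoder, dynamic_decode,
                                        ScheduledOutputTrainingHelper)

batch_size, max_time, depth = 4, 10, 8
teacher_inputs = tf.constant(
    np.random.randn(batch_size, max_time, depth).astype(np.float32))
sched_helper = ScheduledOutputTrainingHelper(
    inputs=teacher_inputs,
    sequence_length=[max_time] * batch_size,
    sampling_probability=0.25)  # feed outputs back ~25% of the time
lstm = tf.nn.rnn_cell.LSTMCell(depth)  # output size == input depth
decoder = BasicDecoder(lstm, sched_helper,
                       lstm.zero_state(batch_size, tf.float32))
outputs, final_state, lengths = dynamic_decode(decoder)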