Exemplo n.º 1
0
    def inference(self, input_ids, input_lengths, speaker_ids):
        """Synthesize mel frames autoregressively, without teacher forcing.

        Args:
            input_ids: input token ids, fed to ``self.encoder``.
            input_lengths: valid length of each input sequence. It builds the
                encoder mask and is also passed as ``memory_sequence_length``
                so attention ignores padded encoder steps.
            speaker_ids: speaker ids, fed to ``self.encoder``.

        Returns:
            Tuple ``(decoder_output, mel_outputs, stop_token_prediction,
            alignment_history)``:
            - the decoded frames reshaped to ``[batch, -1, self.config.n_mels]``;
            - those frames plus the postnet residual projection;
            - the stop-token predictions reshaped to ``[batch, -1]``;
            - the stacked attention history transposed with perm
              ``[1, 2, 0]`` (presumably ``[batch, encoder_steps,
              decoder_steps]`` -- TODO confirm).
        """
        longest_input = tf.reduce_max(input_lengths)
        text_mask = tf.sequence_mask(
            input_lengths, maxlen=longest_input, name="input_sequence_masks"
        )

        memory = self.encoder([input_ids, speaker_ids, text_mask], training=False)

        memory_shape = tf.shape(memory)
        batch_size = memory_shape[0]
        alignment_size = memory_shape[1]

        # Prime the stateful decoder before running it. The order is
        # significant: the attention window may only be configured once the
        # attention memory has been set up.
        decoder = self.decoder
        decoder.sampler.set_batch_size(batch_size)
        decoder.cell.set_alignment_size(alignment_size)
        initial_cell_state = decoder.cell.get_initial_state(batch_size)
        decoder.setup_decoder_init_state(initial_cell_state)
        decoder.cell.attention_layer.setup_memory(
            memory=memory,
            memory_sequence_length=input_lengths,  # masks padded encoder steps
        )
        if self.use_window_mask:
            decoder.cell.attention_layer.setup_window(
                win_front=self.win_front, win_back=self.win_back
            )

        decoded, final_state, _ = dynamic_decode(
            decoder, maximum_iterations=self.maximum_iterations
        )
        frames, stop_tokens, _ = decoded

        decoder_output = tf.reshape(frames, [batch_size, -1, self.config.n_mels])
        stop_token_prediction = tf.reshape(stop_tokens, [batch_size, -1])

        residual_projection = self.post_projection(
            self.postnet(decoder_output, training=False)
        )
        mel_outputs = decoder_output + residual_projection

        alignment_history = tf.transpose(
            final_state.alignment_history.stack(), [1, 2, 0]
        )
        return decoder_output, mel_outputs, stop_token_prediction, alignment_history
Exemplo n.º 2
0
    def call(self,
             input_ids,
             input_lengths,
             speaker_ids,
             mel_outputs,
             mel_lengths,
             maximum_iterations=tf.constant(2000, tf.int32),
             use_window_mask=False,
             win_front=2,
             win_back=3,
             training=False):
        """Run encoder, decoder and postnet in teacher-forcing mode.

        Args:
            input_ids: input token ids, fed to ``self.encoder``.
            input_lengths: valid length of each input sequence. It builds the
                encoder mask and is passed as ``memory_sequence_length`` to
                mask attention over the encoder memory.
            speaker_ids: speaker ids, fed to ``self.encoder``.
            mel_outputs: target mel frames handed to the decoder sampler via
                ``setup_target`` (teacher forcing). Not the predicted mels
                returned below.
            mel_lengths: lengths of the target mel sequences, handed to
                ``setup_target`` together with ``mel_outputs``.
            maximum_iterations: step limit passed to ``dynamic_decode``.
                NOTE: the default is a tensor created once, when the method
                is defined, not once per call.
            use_window_mask: when True, configure the attention window with
                ``win_front`` / ``win_back`` after the memory is set up.
            win_front: forwarded to ``setup_window``.
            win_back: forwarded to ``setup_window``.
            training: forwarded to the encoder and the postnet.

        Returns:
            Tuple ``(decoder_output, mel_outputs, stop_token_prediction,
            alignment_history)``:
            - ``decoder_output`` is reshaped to ``[batch, -1, n_mels]``;
            - ``mel_outputs`` is the decoder output plus the postnet
              residual projection;
            - ``stop_token_prediction`` is reshaped to ``[batch, -1]``;
            - ``alignment_history`` is the stacked attention history
              transposed with perm ``[1, 2, 0]``.
        """
        longest_input = tf.reduce_max(input_lengths)
        text_mask = tf.sequence_mask(input_lengths,
                                     maxlen=longest_input,
                                     name='input_sequence_masks')

        memory = self.encoder([input_ids, speaker_ids, text_mask], training=training)

        memory_shape = tf.shape(memory)
        batch_size, alignment_size = memory_shape[0], memory_shape[1]

        # Prime the stateful decoder; keep this order (the window setup must
        # come after the attention memory is configured).
        decoder = self.decoder
        decoder.sampler.setup_target(targets=mel_outputs, mel_lengths=mel_lengths)
        decoder.cell.set_alignment_size(alignment_size)
        decoder.setup_decoder_init_state(decoder.cell.get_initial_state(batch_size))
        decoder.cell.attention_layer.setup_memory(
            memory=memory,
            memory_sequence_length=input_lengths  # masks padded encoder steps
        )
        if use_window_mask:
            decoder.cell.attention_layer.setup_window(win_front=win_front, win_back=win_back)

        decoded, final_state, _ = dynamic_decode(decoder, maximum_iterations=maximum_iterations)
        frames, stop_tokens, _ = decoded

        decoder_output = tf.reshape(frames, [batch_size, -1, self.config.n_mels])
        stop_token_prediction = tf.reshape(stop_tokens, [batch_size, -1])

        residual_projection = self.post_projection(
            self.postnet(decoder_output, training=training))
        # Separate name so the ground-truth ``mel_outputs`` argument is not shadowed.
        mel_after_postnet = decoder_output + residual_projection

        alignment_history = tf.transpose(final_state.alignment_history.stack(), [1, 2, 0])
        return decoder_output, mel_after_postnet, stop_token_prediction, alignment_history
Exemplo n.º 3
0
  def call(self, inputs, training=None, mask=None):
    """Decode with an attention-wrapped RNN cell.

    Inference (``self.is_infer``): ``inputs`` is
    ``(enc_outputs, enc_state, enc_seq_len)``. Decoding runs for at most
    ``self.max_dec_len`` steps. It uses beam search when
    ``self.beam_size > 1`` and greedy embedding decoding otherwise.

    Training: ``inputs`` is
    ``(dec_inputs, dec_seq_len, enc_outputs, enc_state, enc_seq_len)``.
    ``dec_inputs`` are embedded and fed to the decoder through a
    ``TrainingHelper`` for ``max(dec_seq_len)`` steps.

    When ``self.initial_decode_state`` is set, the attention wrapper's cell
    state is initialised from ``enc_state`` (tiled per beam for beam search).

    Args:
      inputs: tuple of tensors whose layout depends on the mode (see above).
      training: unused here; kept for the Keras ``call`` signature.
      mask: unused here; kept for the Keras ``call`` signature.

    Returns:
      Inference: token ids. For beam search this is beam 0 of
      ``predicted_ids``; for greedy decoding it is ``sample_id``.
      Training: the decoder ``rnn_output`` after the ``self.vocab_size``
      output projection.
      In both cases the batch/time order follows ``self.time_major``.
    """
    dec_emb_fn = lambda ids: self.embed(ids)
    if self.is_infer:
      enc_outputs, enc_state, enc_seq_len = inputs
      batch_size = tf.shape(enc_outputs)[0]
      helper = seq2seq.GreedyEmbeddingHelper(
          embedding=dec_emb_fn,
          start_tokens=tf.fill([batch_size], self.dec_start_id),
          end_token=self.dec_end_id)
    else:
      (dec_inputs, dec_seq_len, enc_outputs, enc_state,
       enc_seq_len) = inputs
      batch_size = tf.shape(enc_outputs)[0]
      dec_inputs = self.embed(dec_inputs)
      helper = seq2seq.TrainingHelper(
          inputs=dec_inputs, sequence_length=dec_seq_len)

    use_beam_search = self.is_infer and self.beam_size > 1

    # NOTE(review): the Dense output projections below are constructed on
    # every call rather than once per layer instance; confirm this is
    # intended (variable tracking / reuse across calls).
    if use_beam_search:
      # Beam search runs batch_size * beam_size hypotheses, so every
      # encoder-side tensor fed to attention must be tiled accordingly.
      tiled_enc_outputs = seq2seq.tile_batch(
          enc_outputs, multiplier=self.beam_size)
      tiled_seq_len = seq2seq.tile_batch(enc_seq_len, multiplier=self.beam_size)
      attn_mech = self._build_attention(
          enc_outputs=tiled_enc_outputs, enc_seq_len=tiled_seq_len)
      dec_cell = seq2seq.AttentionWrapper(self.cell, attn_mech)
      tiled_dec_init_state = dec_cell.zero_state(
          batch_size=batch_size * self.beam_size, dtype=tf.float32)
      if self.initial_decode_state:
        # Only tile the encoder state when it is actually used.
        tiled_enc_last_state = seq2seq.tile_batch(
            enc_state, multiplier=self.beam_size)
        tiled_dec_init_state = tiled_dec_init_state.clone(
            cell_state=tiled_enc_last_state)

      dec = seq2seq.BeamSearchDecoder(
          cell=dec_cell,
          embedding=dec_emb_fn,
          start_tokens=tf.tile([self.dec_start_id], [batch_size]),
          end_token=self.dec_end_id,
          initial_state=tiled_dec_init_state,
          beam_width=self.beam_size,
          output_layer=tf.layers.Dense(self.vocab_size),
          length_penalty_weight=self.length_penalty)
    else:
      attn_mech = self._build_attention(
          enc_outputs=enc_outputs, enc_seq_len=enc_seq_len)
      dec_cell = seq2seq.AttentionWrapper(
          cell=self.cell, attention_mechanism=attn_mech)
      dec_init_state = dec_cell.zero_state(
          batch_size=batch_size, dtype=tf.float32)
      if self.initial_decode_state:
        dec_init_state = dec_init_state.clone(cell_state=enc_state)
      dec = seq2seq.BasicDecoder(
          cell=dec_cell,
          helper=helper,
          initial_state=dec_init_state,
          output_layer=tf.layers.Dense(self.vocab_size))

    if self.is_infer:
      dec_outputs, _, _ = seq2seq.dynamic_decode(
          decoder=dec,
          maximum_iterations=self.max_dec_len,
          swap_memory=self.swap_memory,
          output_time_major=self.time_major)
      if use_beam_search:
        # Beam is the last axis in either time-major layout; keep index 0,
        # the top-ranked hypothesis.
        return dec_outputs.predicted_ids[:, :, 0]
      # Greedy decoding goes through BasicDecoder, whose output tuple is
      # (rnn_output, sample_id) and has no `predicted_ids` field.
      return dec_outputs.sample_id

    dec_outputs, _, _ = seq2seq.dynamic_decode(
        decoder=dec,
        maximum_iterations=tf.reduce_max(dec_seq_len),
        swap_memory=self.swap_memory,
        output_time_major=self.time_major)
    return dec_outputs.rnn_output