Esempio n. 1
0
    def inference(self, input_ids, input_lengths, speaker_ids, **kwargs):
        """Run encoder, attention decoder and postnet with ``training=False``.

        Args:
            input_ids: Passed to the encoder as its first input.
            input_lengths: Per-example input lengths. Used to build the
                encoder mask and handed to the attention layer as
                ``memory_sequence_length``.
            speaker_ids: Passed to the encoder as its second input.
            **kwargs: Accepted but not used by this method.

        Returns:
            A tuple ``(decoder_outputs, mel_outputs, stop_token_predictions,
            alignment_historys)`` where

            * ``decoder_outputs`` is the decoder frame prediction reshaped to
              ``[batch_size, -1, n_mels]``,
            * ``mel_outputs`` is ``decoder_outputs`` plus the projected
              postnet residual,
            * ``stop_token_predictions`` is reshaped to ``[batch_size, -1]``,
            * ``alignment_historys`` is the stacked alignment history of the
              final decoder state, transposed with permutation ``[1, 2, 0]``.
        """
        # Mask for the encoder; its width is the largest value in input_lengths.
        longest_input = tf.reduce_max(input_lengths)
        input_mask = tf.sequence_mask(
            input_lengths, maxlen=longest_input, name="input_sequence_masks"
        )

        # Encoder pass.
        encoder_inputs = [input_ids, speaker_ids, input_mask]
        encoder_hidden_states = self.encoder(encoder_inputs, training=False)

        # Dynamic sizes taken from the first two axes of the encoder output.
        encoder_shape = tf.shape(encoder_hidden_states)
        batch_size = encoder_shape[0]
        alignment_size = encoder_shape[1]

        # Prime the decoder before decoding. The call order below is kept on
        # purpose:
        #   1. batch size for the sampler,
        #   2. alignment size for the attention,
        #   3. initial state of the decoder cell,
        #   4. attention memory (the encoder hidden states),
        #   5. optional attention window, which must come after the memory
        #      has been set up.
        decoder = self.decoder
        cell = decoder.cell
        decoder.sampler.set_batch_size(batch_size)
        cell.set_alignment_size(alignment_size)
        initial_state = cell.get_initial_state(batch_size)
        decoder.setup_decoder_init_state(initial_state)
        cell.attention_layer.setup_memory(
            memory=encoder_hidden_states,
            memory_sequence_length=input_lengths,  # lets attention mask padding.
        )
        if self.use_window_mask:
            cell.attention_layer.setup_window(
                win_front=self.win_front, win_back=self.win_back
            )

        # Decode, bounded by self.maximum_iterations.
        decode_result = dynamic_decode(
            self.decoder, maximum_iterations=self.maximum_iterations
        )
        (frames_prediction, stop_token_prediction, _), final_decoder_state, _ = (
            decode_result
        )

        frame_shape = [batch_size, -1, self.config.n_mels]
        decoder_outputs = tf.reshape(frames_prediction, frame_shape)
        stop_token_predictions = tf.reshape(stop_token_prediction, [batch_size, -1])

        # Postnet produces a residual that is projected and added to the frames.
        residual_projection = self.post_projection(
            self.postnet(decoder_outputs, training=False)
        )
        mel_outputs = decoder_outputs + residual_projection

        stacked_alignments = final_decoder_state.alignment_history.stack()
        alignment_historys = tf.transpose(stacked_alignments, perm=[1, 2, 0])

        return decoder_outputs, mel_outputs, stop_token_predictions, alignment_historys
Esempio n. 2
0
    def call(
        self,
        input_ids,
        input_lengths,
        speaker_ids,
        mel_gts,
        mel_lengths,
        maximum_iterations=2000,
        use_window_mask=False,
        win_front=2,
        win_back=3,
        training=False,
        **kwargs,
    ):
        """Run encoder, attention decoder and postnet using target mels.

        Args:
            input_ids: Passed to the encoder as its first input.
            input_lengths: Per-example input lengths. Used to build the
                encoder mask and handed to the attention layer as
                ``memory_sequence_length``.
            speaker_ids: Passed to the encoder as its second input.
            mel_gts: Target frames given to the decoder sampler.
            mel_lengths: Lengths of ``mel_gts``, given to the decoder sampler.
            maximum_iterations: Upper bound forwarded to ``dynamic_decode``.
            use_window_mask: When true, an attention window is configured
                after the attention memory has been set up.
            win_front: Forwarded to ``setup_window`` as ``win_front``.
            win_back: Forwarded to ``setup_window`` as ``win_back``.
            training: Forwarded to the encoder and the postnet.
            **kwargs: Accepted but not used by this method.

        Returns:
            A tuple ``(decoder_outputs, mel_outputs, stop_token_prediction,
            alignment_history)``. With ``self.enable_tflite_convertible`` set,
            the two mel tensors are filtered by a frame mask and given a
            leading axis of size 1, and ``alignment_history`` is an empty
            tuple. Otherwise ``alignment_history`` is the stacked alignment
            history of the final decoder state, transposed with permutation
            ``[1, 2, 0]``.
        """
        # Mask for the encoder; its width is the largest value in input_lengths.
        longest_input = tf.reduce_max(input_lengths)
        input_mask = tf.sequence_mask(
            input_lengths, maxlen=longest_input, name="input_sequence_masks"
        )

        # Encoder pass.
        encoder_inputs = [input_ids, speaker_ids, input_mask]
        encoder_hidden_states = self.encoder(encoder_inputs, training=training)

        # Dynamic sizes taken from the first two axes of the encoder output.
        encoder_shape = tf.shape(encoder_hidden_states)
        batch_size = encoder_shape[0]
        alignment_size = encoder_shape[1]

        # Prime the decoder before decoding. The call order below is kept on
        # purpose:
        #   1. targets and their lengths for the sampler (teacher forcing),
        #   2. alignment size for the attention,
        #   3. initial state of the decoder cell,
        #   4. attention memory (the encoder hidden states),
        #   5. optional attention window.
        decoder = self.decoder
        cell = decoder.cell
        decoder.sampler.setup_target(targets=mel_gts, mel_lengths=mel_lengths)
        cell.set_alignment_size(alignment_size)
        initial_state = cell.get_initial_state(batch_size)
        decoder.setup_decoder_init_state(initial_state)
        cell.attention_layer.setup_memory(
            memory=encoder_hidden_states,
            memory_sequence_length=input_lengths,  # lets attention mask padding.
        )
        if use_window_mask:
            cell.attention_layer.setup_window(
                win_front=win_front, win_back=win_back
            )

        # Decode.
        decode_result = dynamic_decode(
            self.decoder,
            maximum_iterations=maximum_iterations,
            enable_tflite_convertible=self.enable_tflite_convertible,
        )
        (frames_prediction, stop_token_prediction, _), final_decoder_state, _ = (
            decode_result
        )

        frame_shape = [batch_size, -1, self.config.n_mels]
        decoder_outputs = tf.reshape(frames_prediction, frame_shape)
        stop_token_prediction = tf.reshape(stop_token_prediction, [batch_size, -1])

        # Postnet produces a residual that is projected and added to the frames.
        residual_projection = self.post_projection(
            self.postnet(decoder_outputs, training=training)
        )
        mel_outputs = decoder_outputs + residual_projection

        if not self.enable_tflite_convertible:
            stacked_alignments = final_decoder_state.alignment_history.stack()
            alignment_history = tf.transpose(stacked_alignments, perm=[1, 2, 0])
            return (
                decoder_outputs,
                mel_outputs,
                stop_token_prediction,
                alignment_history,
            )

        # TFLite path: keep only frames whose summed absolute value is non-zero
        # and return no alignment history.
        # NOTE(review): the sum is cast to int32 before the comparison, so
        # (assuming float outputs) frames whose absolute sum is below 1 are
        # dropped as well - confirm this is intended.
        frame_magnitude = tf.reduce_sum(tf.abs(decoder_outputs), axis=-1)
        mask = tf.math.not_equal(tf.cast(frame_magnitude, dtype=tf.int32), 0)

        # The mask is computed once, from the unfiltered decoder outputs, and
        # applied to both tensors; the kept frames get a new leading axis.
        kept_decoder_frames = tf.boolean_mask(decoder_outputs, mask)
        kept_mel_frames = tf.boolean_mask(mel_outputs, mask)
        decoder_outputs = tf.expand_dims(kept_decoder_frames, axis=0)
        mel_outputs = tf.expand_dims(kept_mel_frames, axis=0)
        alignment_history = ()

        return decoder_outputs, mel_outputs, stop_token_prediction, alignment_history