Example No. 1
0
        def inference(inputs, input_lengths):
            """Greedy inference: CTC phone decoding plus attention decoding.

            Args:
                inputs: audio input (run through ``self.mel_layer`` first
                    when one is configured).
                input_lengths: per-example frame lengths with a trailing
                    singleton axis; squeezed to rank 1 below.

            Returns:
                ``[final_decoded, phone_decode]`` — argmax token ids from
                the attention decoder, and CTC-decoded phone ids.
            """
            # Encoder step.
            input_lengths = tf.squeeze(input_lengths, -1)

            if self.mel_layer is not None:
                inputs = self.mel_layer(inputs)
            # The encoder yields three intermediate hidden states; the
            # attention decoder attends over their channel-wise concat.
            encoder_hidden_states1, encoder_hidden_states2, encoder_hidden_states3 = self.encoder(
                inputs, training=False)
            encoder_hidden_states = tf.concat([
                encoder_hidden_states1, encoder_hidden_states2,
                encoder_hidden_states3
            ], -1)
            batch_size = tf.shape(encoder_hidden_states)[0]
            alignment_size = tf.shape(encoder_hidden_states)[1]

            # CTC branch: greedy-decode phones from the last encoder state.
            ctc3_output = self.fc3(encoder_hidden_states3)
            phone_decode = tf.keras.backend.ctc_decode(
                tf.nn.softmax(ctc3_output, -1),
                input_length=input_lengths)[0][0]

            # Decoder setup, in order:
            # 1. batch_size for inference.
            # 2. alignment_size for attention size.
            # 3. initial state for the decoder cell.
            # 4. memory (encoder hidden states) for the attention mechanism.
            # 5. window front/back for long-sentence synthesis (must be
            #    called after setup_memory).
            self.decoder.sampler.set_batch_size(batch_size)
            self.decoder.cell.set_alignment_size(alignment_size)
            # self.setup_maximum_iterations(alignment_size)
            self.decoder.setup_decoder_init_state(
                self.decoder.cell.get_initial_state(batch_size))
            self.decoder.cell.attention_layer.setup_memory(
                memory=encoder_hidden_states,
                memory_sequence_length=input_lengths,  # use for mask attention.
            )
            if self.use_window_mask:
                self.decoder.cell.attention_layer.setup_window(
                    win_front=self.win_front, win_back=self.win_back)

            # Only the class predictions are consumed; the stop-token
            # logits, final state and alignment history were previously
            # computed and then discarded, so they are not materialized.
            (classes_prediction, _, _), _, _ = dynamic_decode(
                self.decoder, maximum_iterations=self.maximum_iterations)

            # 768 is presumably the BERT hidden size — TODO confirm
            # against the decoder config.
            bert_output = tf.reshape(classes_prediction, [batch_size, -1, 768])
            decoder_output = self.decoder_project(bert_output)
            decoder_output = self.token_project(decoder_output)
            final_decoded = self.fc_final(decoder_output)
            final_decoded = tf.argmax(final_decoded, -1)
            return [final_decoded, phone_decode]
Example No. 2
0
        def inference(inputs, input_lengths):
            """Greedy attention-decoder inference returning argmax class ids.

            Args:
                inputs: audio input (run through ``self.mel_layer`` first
                    when one is configured). When ``self.wav_info`` is set,
                    the raw input is also fed to the encoder alongside the
                    mel features.
                input_lengths: per-example frame lengths with a trailing
                    singleton axis; squeezed to rank 1 below.

            Returns:
                ``[decoder_output]`` — argmax class ids, shape
                ``[batch, time]``.
            """
            # Encoder step.
            input_lengths = tf.squeeze(input_lengths, -1)

            if self.wav_info:
                wav = inputs  # keep the raw waveform before mel transform.
            if self.mel_layer is not None:
                inputs = self.mel_layer(inputs)
            # Only the encoder's last output is used as attention memory.
            if self.wav_info:
                encoder_hidden_states = self.encoder([inputs, wav],
                                                     training=False)[-1]
            else:
                encoder_hidden_states = self.encoder(inputs,
                                                     training=False)[-1]
            batch_size = tf.shape(encoder_hidden_states)[0]
            alignment_size = tf.shape(encoder_hidden_states)[1]

            # Decoder setup, in order:
            # 1. batch_size for inference.
            # 2. alignment_size for attention size.
            # 3. initial state for the decoder cell.
            # 4. memory (encoder hidden states) for the attention mechanism.
            # 5. window front/back for long-sentence synthesis (must be
            #    called after setup_memory).
            self.decoder.sampler.set_batch_size(batch_size)
            self.decoder.cell.set_alignment_size(alignment_size)
            # self.setup_maximum_iterations(alignment_size)
            self.decoder.setup_decoder_init_state(
                self.decoder.cell.get_initial_state(batch_size))
            self.decoder.cell.attention_layer.setup_memory(
                memory=encoder_hidden_states,
                memory_sequence_length=input_lengths,  # use for mask attention.
            )
            if self.use_window_mask:
                self.decoder.cell.attention_layer.setup_window(
                    win_front=self.win_front, win_back=self.win_back)

            # Only the class predictions are consumed; the stop-token
            # logits, final state and alignment history were dead code in
            # the original and are dropped here.
            (classes_prediction, _, _), _, _ = dynamic_decode(
                self.decoder, maximum_iterations=self.maximum_iterations)

            decoder_output = tf.reshape(
                classes_prediction, [batch_size, -1, self.config.n_classes])
            decoder_output = tf.argmax(decoder_output, -1)
            return [decoder_output]
Example No. 3
0
    def call(
        self,
        inputs,
        targets=None,
        targets_lengths=None,
        use_window_mask=False,
        win_front=2,
        win_back=3,
        training=False,
    ):
        """Run the encoder and attention decoder over one batch.

        When ``targets`` is given the sampler is put into teacher-forcing
        mode. Returns ``(decoder_output, stop_token_prediction,
        alignment_history)``; in tflite-convertible mode the alignment
        history is an empty tuple and padding frames are stripped from
        the decoder output.
        """
        # Encoder step.
        # input_lengths=tf.squeeze(input_lengths,-1)
        features, feature_lengths = inputs
        if self.mel_layer is not None:
            features = self.mel_layer(features)

        memory = self.encoder(features, training=training)
        batch_size = tf.shape(memory)[0]
        alignment_size = tf.shape(memory)[1]

        # Prime the sampler and cell before decoding. The attention memory
        # must be installed before any window mask is applied.
        if targets is not None:
            self.decoder.sampler.setup_target(targets=targets,
                                              targets_lengths=targets_lengths)
        self.decoder.sampler.set_batch_size(batch_size)
        self.decoder.cell.set_alignment_size(alignment_size)

        self.decoder.setup_decoder_init_state(
            self.decoder.cell.get_initial_state(batch_size))
        self.decoder.cell.attention_layer.setup_memory(
            memory=memory,
            memory_sequence_length=feature_lengths,  # use for mask attention.
        )
        if use_window_mask:
            self.decoder.cell.attention_layer.setup_window(win_front=win_front,
                                                           win_back=win_back)

        # Run the decode loop.
        decode_result = dynamic_decode(
            self.decoder,
            maximum_iterations=self.maximum_iterations,
            enable_tflite_convertible=self.enable_tflite_convertible)
        classes_prediction, stop_token_prediction, _ = decode_result[0]
        final_decoder_state = decode_result[1]

        decoder_output = tf.reshape(classes_prediction,
                                    [batch_size, -1, self.config.n_classes])
        stop_token_prediction = tf.reshape(stop_token_prediction,
                                           [batch_size, -1])

        if not self.enable_tflite_convertible:
            alignment_history = tf.transpose(
                final_decoder_state.alignment_history.stack(), [1, 2, 0])
        else:
            # tflite cannot export the TensorArray-backed history, so it is
            # omitted; all-zero (padding) frames are dropped instead.
            nonzero = tf.math.not_equal(
                tf.cast(tf.reduce_sum(tf.abs(decoder_output), axis=-1),
                        dtype=tf.int32), 0)
            decoder_output = tf.expand_dims(
                tf.boolean_mask(decoder_output, nonzero), axis=0)
            alignment_history = ()

        return decoder_output, stop_token_prediction, alignment_history
Example No. 4
0
    def call(
        self,
        inputs,
        input_lengths,
        targets=None,
        targets_lengths=None,
        use_window_mask=False,
        win_front=2,
        win_back=3,
        training=False,
    ):
        """Run the three-tap encoder, CTC heads, and attention decoder.

        Args:
            inputs: encoder input features.
            input_lengths: per-example lengths used to mask attention.
            targets: optional decoder targets; when given, the sampler is
                put into teacher-forcing mode.
            targets_lengths: lengths matching ``targets``.
            use_window_mask: if True, restrict attention to a moving window.
            win_front: window size in front of the attention position.
            win_back: window size behind the attention position.
            training: forwarded to the encoder and token projection.

        Returns:
            Tuple of ``(ctc1_output, ctc2_output, ctc3_output,
            final_decoded, bert_output, stop_token_prediction,
            alignment_history)``.
        """
        # Encoder Step.
        # input_lengths=tf.squeeze(input_lengths,-1)
        # The encoder yields three intermediate hidden states; each feeds
        # its own CTC projection head.
        encoder_hidden_states1, encoder_hidden_states2, encoder_hidden_states3 = self.encoder(
            inputs, training=training)
        ctc1_output = self.fc1(encoder_hidden_states1)
        ctc2_output = self.fc2(encoder_hidden_states2)
        ctc3_output = self.fc3(encoder_hidden_states3)
        # The attention decoder attends over the channel-wise concatenation
        # of all three encoder states.
        encoder_hidden_states = tf.concat([
            encoder_hidden_states1, encoder_hidden_states2,
            encoder_hidden_states3
        ], -1)
        batch_size = tf.shape(encoder_hidden_states)[0]
        alignment_size = tf.shape(encoder_hidden_states)[1]

        # Setup some initial placeholders for decoder step. Include:
        # 1. mel_outputs, mel_lengths for teacher forcing mode.
        # 2. alignment_size for attention size.
        # 3. initial state for decoder cell.
        # 4. memory (encoder hidden state) for attention mechanism.
        # NOTE: the window mask must be set up after setup_memory.
        if targets is not None:
            self.decoder.sampler.setup_target(targets=targets,
                                              targets_lengths=targets_lengths)
        self.decoder.sampler.set_batch_size(batch_size)
        self.decoder.cell.set_alignment_size(alignment_size)

        self.decoder.setup_decoder_init_state(
            self.decoder.cell.get_initial_state(batch_size))
        self.decoder.cell.attention_layer.setup_memory(
            memory=encoder_hidden_states,
            memory_sequence_length=input_lengths,  # use for mask attention.
        )
        if use_window_mask:
            self.decoder.cell.attention_layer.setup_window(win_front=win_front,
                                                           win_back=win_back)

        # run decode step.
        (
            (classes_prediction, stop_token_prediction, _),
            final_decoder_state,
            _,
        ) = dynamic_decode(
            self.decoder,
            maximum_iterations=self.maximum_iterations,
            enable_tflite_convertible=self.enable_tflite_convertible)

        # 768 is presumably the BERT hidden size — TODO confirm against
        # the decoder config.
        bert_output = tf.reshape(classes_prediction, [batch_size, -1, 768])
        stop_token_prediction = tf.reshape(stop_token_prediction,
                                           [batch_size, -1])
        decoder_output = self.decoder_project(bert_output)
        decoder_output = self.token_project(decoder_output, training=training)
        final_decoded = self.fc_final(decoder_output)

        # Stacked per-step alignments, transposed to [batch, memory, time].
        alignment_history = tf.transpose(
            final_decoder_state.alignment_history.stack(), [1, 2, 0])

        return ctc1_output, ctc2_output, ctc3_output, final_decoded, bert_output, stop_token_prediction, alignment_history