Example #1
    def __init__(self,
                 encoder: tf.keras.Model,
                 num_classes: int,
                 speech_config,
                 name="ctc_model",

                 **kwargs):
        super(CtcModel, self).__init__(name=name, **kwargs)
        self.encoder = encoder
        # Fully connected layer
        self.speech_config = speech_config
        self.mel_layer = None
        if speech_config['use_mel_layer']:
            if speech_config['mel_layer_type'] == 'Melspectrogram':
                self.mel_layer = Melspectrogram(
                    sr=speech_config['sample_rate'],
                    n_mels=speech_config['num_feature_bins'],
                    n_hop=int(speech_config['stride_ms'] *
                              speech_config['sample_rate'] // 1000),
                    n_dft=1024,
                    trainable_fb=speech_config['trainable_kernel'])
            else:
                self.mel_layer = Spectrogram(
                    n_hop=int(speech_config['stride_ms'] *
                              speech_config['sample_rate'] // 1000),
                    n_dft=1024,
                    trainable_kernel=speech_config['trainable_kernel'])
            self.mel_layer.trainable = speech_config['trainable_kernel']
        self.wav_info = speech_config['add_wav_info']
        if self.wav_info:
            assert speech_config['use_mel_layer'], 'should set use_mel_layer to True'

        self.fc = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Dense(units=num_classes, activation="linear",
                                  use_bias=True), name="fully_connected")
        self.recognize_pb = None
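A minimal usage sketch for the constructor above; the speech_config keys mirror exactly those read in __init__, while the values and the encoder name are illustrative assumptions, not project defaults:

# Hypothetical config covering every speech_config key read by CtcModel.__init__;
# all values here are illustrative assumptions.
speech_config = {
    'use_mel_layer': True,
    'mel_layer_type': 'Melspectrogram',
    'sample_rate': 16000,
    'num_feature_bins': 80,
    'stride_ms': 10,
    'trainable_kernel': False,
    'add_wav_info': False,
}
# model = CtcModel(encoder=my_encoder, num_classes=30, speech_config=speech_config)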
Example #2
    def __init__(self,
                 encoder1,
                 encoder2,
                 encoder3,
                 classes1,
                 classes2,
                 classes3,
                 dmodel,
                 speech_config: dict = None,
                 **kwargs):
        super().__init__(**kwargs)
        self.encoder1 = encoder1
        self.encoder2 = encoder2
        self.encoder3 = encoder3
        self.speech_config = speech_config
        self.mel_layer = None
        if speech_config['use_mel_layer']:
            if speech_config['mel_layer_type'] == 'Melspectrogram':
                self.mel_layer = Melspectrogram(
                    sr=speech_config['sample_rate'],
                    n_mels=speech_config['num_feature_bins'],
                    n_hop=int(speech_config['stride_ms'] *
                              speech_config['sample_rate'] // 1000),
                    n_dft=1024,
                    trainable_fb=speech_config['trainable_kernel'])
            else:
                self.mel_layer = Spectrogram(
                    n_hop=int(speech_config['stride_ms'] *
                              speech_config['sample_rate'] // 1000),
                    n_dft=1024,
                    trainable_kernel=speech_config['trainable_kernel'])
            self.mel_layer.trainable = speech_config['trainable_kernel']
        self.wav_info = speech_config['add_wav_info']
        if self.wav_info:
            assert speech_config['use_mel_layer'], 'should set use_mel_layer to True'
        self.fc1 = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(
            units=classes1, activation="linear", use_bias=True),
                                                   name="fully_connected1")

        self.fc2 = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(
            units=classes2, activation="linear", use_bias=True),
                                                   name="fully_connected2")

        self.fc3 = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(
            units=classes3, activation="linear", use_bias=True),
                                                   name="fully_connected3")

        self.fc_to_project_1 = tf.keras.layers.Dense(
            dmodel, name='word_prob_projector')
        self.fc_to_project_2 = tf.keras.layers.Dense(
            dmodel, name='phone_prob_projector')
        self.fc_to_project_3 = tf.keras.layers.Dense(dmodel,
                                                     name='py_prob_projector')
        self.fc_final_class = tf.keras.layers.Conv1D(classes3,
                                                     32,
                                                     padding='same',
                                                     name="cnn_final_class")
Example #3
    def __init__(self,
                 encoder: tf.keras.Model,
                 vocabulary_size: int,
                 embed_dim: int = 512,
                 embed_dropout: float = 0,
                 num_lstms: int = 1,
                 lstm_units: int = 320,
                 joint_dim: int = 1024,
                 name="transducer",
                 speech_config: dict = None,
                 **kwargs):
        super(Transducer, self).__init__(name=name, **kwargs)
        self.encoder = encoder
        self.predict_net = TransducerPrediction(
            vocabulary_size=vocabulary_size,
            embed_dim=embed_dim,
            embed_dropout=embed_dropout,
            num_lstms=num_lstms,
            lstm_units=lstm_units,
            name=f"{name}_prediction")
        self.joint_net = TransducerJoint(vocabulary_size=vocabulary_size,
                                         joint_dim=joint_dim,
                                         name=f"{name}_joint")
        self.speech_config = speech_config
        self.mel_layer = None
        if speech_config['use_mel_layer']:
            if speech_config['mel_layer_type'] == 'Melspectrogram':
                self.mel_layer = Melspectrogram(
                    sr=speech_config['sample_rate'],
                    n_mels=speech_config['num_feature_bins'],
                    n_hop=int(speech_config['stride_ms'] *
                              speech_config['sample_rate'] // 1000),
                    n_dft=1024,
                    trainable_fb=speech_config['trainable_kernel'])
            else:
                self.mel_layer = Spectrogram(
                    n_hop=int(speech_config['stride_ms'] *
                              speech_config['sample_rate'] // 1000),
                    n_dft=1024,
                    trainable_kernel=speech_config['trainable_kernel'])
            self.mel_layer.trainable = speech_config['trainable_kernel']

        self.ctc_classes = tf.keras.layers.Dense(vocabulary_size,
                                                 name='ctc_classes')
        self.wav_info = speech_config['add_wav_info']
        if self.wav_info:
            assert speech_config['use_mel_layer'], 'should set use_mel_layer to True'
        self.kept_decode = None
        self.startid = 0
        self.endid = 1
        self.max_iter = 10
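Every example computes the STFT hop the same way: int(stride_ms * sample_rate // 1000). The arithmetic in isolation, with 16 kHz and 10 ms as assumed example values:

# Hop size shared by the Melspectrogram/Spectrogram layers above.
sample_rate = 16000  # assumed
stride_ms = 10       # assumed
n_hop = int(stride_ms * sample_rate // 1000)
print(n_hop)  # 160 samples, i.e. one feature frame every 10 ms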
Example #4
    def __init__(self, encoder, config, training,
                 enable_tflite_convertible=False, speech_config: dict = None,
                 **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder_cell = DecoderCell(
            config, training=training, name="decoder_cell",
            enable_tflite_convertible=enable_tflite_convertible
        )
        self.decoder = LASDecoder(
            self.decoder_cell,
            TrainingSampler(config) if training is True else TestingSampler(config),
            enable_tflite_convertible=enable_tflite_convertible
        )
        self.config = config
        self.speech_config = speech_config
        self.mel_layer = None
        if speech_config['use_mel_layer']:
            if speech_config['mel_layer_type'] == 'Melspectrogram':
                self.mel_layer = Melspectrogram(sr=speech_config['sample_rate'],
                                                n_mels=speech_config['num_feature_bins'],
                                                n_hop=int(speech_config['stride_ms'] * speech_config['sample_rate']//1000),
                                                n_dft=1024,
                                                trainable_fb=speech_config['trainable_kernel']
                                                )
            else:
                self.mel_layer = Spectrogram(
                                             n_hop=int(speech_config['stride_ms'] * speech_config['sample_rate']//1000),
                                             n_dft=1024,
                                             trainable_kernel=speech_config['trainable_kernel']
                                             )

        self.use_window_mask = False
        self.maximum_iterations = 1000 if training else 50
        self.enable_tflite_convertible = enable_tflite_convertible
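Note that, unlike the other constructors, this one never assigns mel_layer.trainable after building the layer. Freezing a Keras layer is a one-line toggle; a minimal sketch:

# Freezing a Keras layer, as the other examples do with
# self.mel_layer.trainable = speech_config['trainable_kernel'].
import tensorflow as tf

layer = tf.keras.layers.Dense(4)
layer.trainable = False  # its weights are now excluded from gradient updates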
Example #5
class Transducer(tf.keras.Model):
    """ Transducer Model Wrapper """
    def __init__(self,
                 encoder: tf.keras.Model,
                 vocabulary_size: int,
                 embed_dim: int = 512,
                 embed_dropout: float = 0,
                 num_lstms: int = 1,
                 lstm_units: int = 320,
                 joint_dim: int = 1024,
                 name="transducer",
                 speech_config: dict = None,
                 **kwargs):
        super(Transducer, self).__init__(name=name, **kwargs)
        self.encoder = encoder
        self.predict_net = TransducerPrediction(
            vocabulary_size=vocabulary_size,
            embed_dim=embed_dim,
            embed_dropout=embed_dropout,
            num_lstms=num_lstms,
            lstm_units=lstm_units,
            name=f"{name}_prediction")
        self.joint_net = TransducerJoint(vocabulary_size=vocabulary_size,
                                         joint_dim=joint_dim,
                                         name=f"{name}_joint")
        self.speech_config = speech_config
        self.mel_layer = None
        if speech_config['use_mel_layer']:
            if speech_config['mel_layer_type'] == 'Melspectrogram':
                self.mel_layer = Melspectrogram(
                    sr=speech_config['sample_rate'],
                    n_mels=speech_config['num_feature_bins'],
                    n_hop=int(speech_config['stride_ms'] *
                              speech_config['sample_rate'] // 1000),
                    n_dft=1024,
                    trainable_fb=speech_config['trainable_kernel'])
            else:
                self.mel_layer = Spectrogram(
                    n_hop=int(speech_config['stride_ms'] *
                              speech_config['sample_rate'] // 1000),
                    n_dft=1024,
                    trainable_kernel=speech_config['trainable_kernel'])
        self.kept_decode = None
        self.startid = 0
        self.endid = 1
        self.max_iter = 10

    def _build(self, sample_shape):  # Call on real data for building model
        features = tf.random.normal(shape=sample_shape)
        predicted = tf.constant([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
        return self([features, predicted], training=True)

    def save_seperate(self, path_to_dir: str):
        self.encoder.save(os.path.join(path_to_dir, "encoder"))
        self.predict_net.save(os.path.join(path_to_dir, "prediction"))
        self.joint_net.save(os.path.join(path_to_dir, "joint"))

    def summary(self, line_length=None, **kwargs):
        self.encoder.summary(line_length=line_length, **kwargs)
        self.predict_net.summary(line_length=line_length, **kwargs)
        self.joint_net.summary(line_length=line_length, **kwargs)
        super(Transducer, self).summary(line_length=line_length, **kwargs)

    # @tf.function(experimental_relax_shapes=True)
    def call(self, inputs, training=False):
        features, predicted = inputs

        if self.mel_layer is not None:
            features = self.mel_layer(features)
        enc = self.encoder(features, training=training)
        pred = self.predict_net(predicted, training=training)
        outputs = self.joint_net([enc, pred], training=training)

        return outputs

    def add_featurizers(self, text_featurizer: TextFeaturizer):

        self.text_featurizer = text_featurizer

    def return_pb_function(self, shape):
        @tf.function(input_signature=[
            tf.TensorSpec(shape, dtype=tf.float32),  # features
            tf.TensorSpec([None, 1], dtype=tf.int32),  # lengths
        ])
        def recognize_pb(features, lengths):
            b_i = tf.constant(0, dtype=tf.int32)

            B = tf.shape(features)[0]

            decoded = tf.constant([], dtype=tf.int32)

            def _cond(b_i, B, features, decoded):
                return tf.less(b_i, B)

            def _body(b_i, B, features, decoded):
                yseq = self.perform_greedy(tf.expand_dims(features[b_i],
                                                          axis=0),
                                           streaming=False)

                yseq = tf.concat([
                    yseq,
                    tf.constant([[self.text_featurizer.stop]], tf.int32)
                ],
                                 axis=-1)
                decoded = tf.concat([decoded, yseq[0]], axis=0)
                return b_i + 1, B, features, decoded

            _, _, _, decoded = tf.while_loop(
                _cond,
                _body,
                loop_vars=(b_i, B, features, decoded),
                shape_invariants=(tf.TensorShape([]), tf.TensorShape([]),
                                  get_shape_invariants(features),
                                  tf.TensorShape([None])))

            return [decoded]

        self.recognize_pb = recognize_pb

    @tf.function(experimental_relax_shapes=True)
    def perform_greedy(self, features, streaming=False):
        if self.mel_layer is not None:
            features = self.mel_layer(features)

        decoded = tf.constant([self.text_featurizer.start])
        if self.kept_decode is not None:
            decoded = self.kept_decode

        enc = self.encoder(features, training=False)  # [1, T, E]
        enc = tf.squeeze(enc, axis=0)  # [T, E]

        T = tf.cast(tf.shape(enc)[0], dtype=tf.int32)

        i = tf.constant(0, dtype=tf.int32)

        def _cond(enc, i, decoded, T):
            return tf.less(i, T)

        def _body(enc, i, decoded, T):
            hi = tf.reshape(enc[i], [1, 1, -1])  # [1, 1, E]
            y = self.predict_net(
                inputs=tf.reshape(decoded, [1, -1]),  # [1, 1]
                p_memory_states=None,
                training=False)
            y = y[:, -1:]
            # [1, 1, P], [1, P], [1, P]
            # [1, 1, E] + [1, 1, P] => [1, 1, 1, V]
            ytu = tf.nn.log_softmax(self.joint_net([hi, y], training=False))
            ytu = tf.squeeze(ytu, axis=None)  # [1, 1, 1, V] => [V]
            n_predict = tf.argmax(ytu, axis=-1,
                                  output_type=tf.int32)  # => argmax []
            n_predict = tf.reshape(n_predict, [1])

            def return_no_blank():
                return tf.concat([decoded, n_predict], axis=0)

            decoded = tf.cond(
                tf.logical_and(n_predict[0] != self.text_featurizer.blank,
                               n_predict[0] != 0),
                true_fn=return_no_blank,
                false_fn=lambda: decoded)

            return enc, i + 1, decoded, T

        _, _, decoded, _ = tf.while_loop(
            _cond,
            _body,
            loop_vars=(enc, i, decoded, T),
            shape_invariants=(tf.TensorShape([None, None]), tf.TensorShape([]),
                              tf.TensorShape([None]), tf.TensorShape([])))

        return tf.expand_dims(decoded, axis=0)

    def recognize(self, features):
        decoded = self.perform_greedy(features)

        return decoded

    def get_config(self):
        if self.mel_layer is not None:
            conf = self.mel_layer.get_config()
            conf.update(self.encoder.get_config())
        else:
            conf = self.encoder.get_config()
        conf.update(self.predict_net.get_config())
        conf.update(self.joint_net.get_config())
        return conf
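perform_greedy walks the encoder output frame by frame: joint logits, then log-softmax, then argmax, appending the token only when it is neither blank nor zero. The core of one step in isolation, with a vocabulary size of 30 assumed and random logits standing in for the joint network output:

# One greedy step from perform_greedy, isolated.
import tensorflow as tf

ytu = tf.nn.log_softmax(tf.random.normal([1, 1, 1, 30]))   # [1, 1, 1, V]
ytu = tf.squeeze(ytu, axis=None)                           # -> [V]
n_predict = tf.argmax(ytu, axis=-1, output_type=tf.int32)  # scalar token id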
Example #6
class CtcModel(tf.keras.Model):
    def __init__(self,
                 encoder: tf.keras.Model,
                 num_classes: int,
                 speech_config,
                 name="ctc_model",
                 **kwargs):
        super(CtcModel, self).__init__(name=name, **kwargs)
        self.encoder = encoder
        # Fully connected layer
        self.dmodel = encoder.dmodel
        self.speech_config = speech_config
        self.mel_layer = None
        if speech_config['use_mel_layer']:
            if speech_config['mel_layer_type'] == 'Melspectrogram':
                self.mel_layer = Melspectrogram(
                    sr=speech_config['sample_rate'],
                    n_mels=speech_config['num_feature_bins'],
                    n_hop=int(speech_config['stride_ms'] *
                              speech_config['sample_rate'] // 1000),
                    n_dft=1024,
                    trainable_fb=speech_config['trainable_kernel'])
            else:
                self.mel_layer = Spectrogram(
                    n_hop=int(speech_config['stride_ms'] *
                              speech_config['sample_rate'] // 1000),
                    n_dft=1024,
                    trainable_kernel=speech_config['trainable_kernel'])
            self.mel_layer.trainable = speech_config['trainable_kernel']
        self.wav_info = speech_config['add_wav_info']
        self.chunk_size = int(self.speech_config['sample_rate'] *
                              self.speech_config['streaming_bucket'])
        self.streaming = self.speech_config['streaming']
        if self.wav_info:
            assert speech_config['use_mel_layer'], 'should set use_mel_layer to True'

        self.fc = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(
            units=num_classes, activation="linear", use_bias=True),
                                                  name="fully_connected")
        self.decode_layer = ConformerBlock(self.dmodel,
                                           self.encoder.dropout,
                                           self.encoder.fc_factor,
                                           self.encoder.head_size,
                                           self.encoder.num_heads,
                                           name='decode_conformer_block')
        self.recognize_pb = None
        self.encoder.add_chunk_size(
            self.chunk_size, speech_config['num_feature_bins'],
            int(speech_config['stride_ms'] * speech_config['sample_rate'] //
                1000))

    def _build(self, shape=None):
        shape = [1, self.chunk_size * 3, 1]
        inputs = np.random.normal(size=shape).astype(np.float32)
        self(inputs)

    def summary(self, line_length=None, **kwargs):
        self.encoder.summary(line_length=line_length, **kwargs)
        super(CtcModel, self).summary(line_length, **kwargs)

    def add_featurizers(
        self,
        text_featurizer: TextFeaturizer,
    ):

        self.text_featurizer = text_featurizer

    def call(self, x, training=False, **kwargs):
        features = x
        if self.mel_layer is not None:
            if self.wav_info:
                wav = features
            features = self.mel_layer(features)

        if self.wav_info:
            enc = self.encoder([features, wav], training=training)
        else:
            enc = self.encoder(features, training=training)

        chunk_outputs = self.fc(enc, training=training)
        decode_enc = self.decode_layer(enc, training=training)
        decode_outputs = self.fc(decode_enc, training=training)
        return chunk_outputs, decode_outputs

    @tf.function(
        experimental_relax_shapes=True,
        input_signature=[
            tf.TensorSpec([None, None, 1], dtype=tf.float32),
        ],
    )
    def extract_feature(self, inputs):

        features = inputs
        if self.mel_layer is not None:
            if self.wav_info:
                wav = features
            features = self.mel_layer(features)

        if self.wav_info:
            enc = self.encoder.inference([features, wav], training=False)
        else:
            enc = self.encoder.inference(features, training=False)

        return enc

    @tf.function(
        experimental_relax_shapes=True,
        input_signature=[
            tf.TensorSpec([None, None, 256], dtype=tf.float32),
            tf.TensorSpec([None, 1], dtype=tf.int32),
        ],
    )
    def ctc_decode(self, enc_outputs, length):
        enc = self.decode_layer(enc_outputs, training=False)
        ctc_outputs = self.fc(enc, training=False)

        probs = tf.nn.softmax(ctc_outputs)
        decoded = tf.keras.backend.ctc_decode(y_pred=probs,
                                              input_length=tf.squeeze(
                                                  length, -1),
                                              greedy=True)[0][0]
        return decoded

    def return_pb_function(self, shape):
        @tf.function(experimental_relax_shapes=True,
                     input_signature=[
                         tf.TensorSpec(shape, dtype=tf.float32),
                         tf.TensorSpec([None, 1], dtype=tf.int32),
                     ])
        def recognize_function(features, length):
            _, logits = self(features)
            probs = tf.nn.softmax(logits)
            decoded = tf.keras.backend.ctc_decode(y_pred=probs,
                                                  input_length=tf.squeeze(
                                                      length, -1),
                                                  greedy=True)[0][0]
            return decoded

        self.recognize_pb = recognize_function

    def get_config(self):
        if self.mel_layer is not None:
            config = self.mel_layer.get_config()
            config.update(self.encoder.get_config())
        else:
            config = self.encoder.get_config()
        config.update(self.decode_layer.get_config())
        config.update(self.fc.get_config())
        return config
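Both ctc_decode and recognize_function delegate to tf.keras.backend.ctc_decode, which returns a (decoded_list, log_probs) pair; these examples take [0][0] for the best path. The call in isolation, with batch, time, and vocabulary sizes assumed:

# tf.keras.backend.ctc_decode in isolation.
import tensorflow as tf

probs = tf.nn.softmax(tf.random.normal([2, 50, 30]))  # [B, T, V], sizes assumed
lengths = tf.fill([2], 50)                            # valid frames per utterance
decoded = tf.keras.backend.ctc_decode(y_pred=probs,
                                      input_length=lengths,
                                      greedy=True)[0][0]  # [B, max_len] token ids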
Example #7
class CtcModel(tf.keras.Model):
    def __init__(self,
                 encoder: tf.keras.Model,
                 num_classes: int,
                 speech_config,
                 name="ctc_model",
                 **kwargs):
        super(CtcModel, self).__init__(name=name, **kwargs)
        self.encoder = encoder
        # Fully connected layer
        self.speech_config = speech_config
        self.mel_layer = None
        if speech_config['use_mel_layer']:
            if speech_config['mel_layer_type'] == 'Melspectrogram':
                self.mel_layer = Melspectrogram(
                    sr=speech_config['sample_rate'],
                    n_mels=speech_config['num_feature_bins'],
                    n_hop=int(speech_config['stride_ms'] *
                              speech_config['sample_rate'] // 1000),
                    n_dft=1024,
                    trainable_fb=speech_config['trainable_kernel'])
            else:
                self.mel_layer = Spectrogram(
                    n_hop=int(speech_config['stride_ms'] *
                              speech_config['sample_rate'] // 1000),
                    n_dft=1024,
                    trainable_kernel=speech_config['trainable_kernel'])
            self.mel_layer.trainable = speech_config['trainable_kernel']
        self.wav_info = speech_config['add_wav_info']
        if self.wav_info:
            assert speech_config['use_mel_layer'], 'should set use_mel_layer to True'

        self.fc = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(
            units=num_classes, activation="linear", use_bias=True),
                                                  name="fully_connected")
        self.recognize_pb = None

    def _build(self, sample_shape):
        features = tf.random.normal(shape=sample_shape)
        self(features, training=False)

    def summary(self, line_length=None, **kwargs):
        self.encoder.summary(line_length=line_length, **kwargs)
        super(CtcModel, self).summary(line_length, **kwargs)

    def add_featurizers(
        self,
        text_featurizer: TextFeaturizer,
    ):

        self.text_featurizer = text_featurizer

    # @tf.function(experimental_relax_shapes=True)
    def call(self, inputs, training=False, **kwargs):
        if self.mel_layer is not None:
            if self.wav_info:
                wav = inputs
            inputs = self.mel_layer(inputs)
        if self.wav_info:
            outputs = self.encoder([inputs, wav], training=training)
        else:
            outputs = self.encoder(inputs, training=training)
        outputs = self.fc(outputs, training=training)
        return outputs

    def return_pb_function(self, shape, beam=False):
        @tf.function(experimental_relax_shapes=True,
                     input_signature=[
                         tf.TensorSpec(shape, dtype=tf.float32),
                         tf.TensorSpec([None, 1], dtype=tf.int32),
                     ])
        def recognize_tflite(features, length):

            logits = self.call(features, training=False)

            probs = tf.nn.softmax(logits)
            decoded = tf.keras.backend.ctc_decode(y_pred=probs,
                                                  input_length=tf.squeeze(
                                                      length, -1),
                                                  greedy=True)[0][0]
            return [decoded]

        @tf.function(experimental_relax_shapes=True,
                     input_signature=[
                         tf.TensorSpec(shape, dtype=tf.float32),
                         tf.TensorSpec([None, 1], dtype=tf.int32),
                     ])
        def recognize_beam_tflite(features, length):

            logits = self.call(features, training=False)

            probs = tf.nn.softmax(logits)
            decoded = tf.keras.backend.ctc_decode(
                y_pred=probs,
                input_length=tf.squeeze(length, -1),
                greedy=False,
                beam_width=self.text_featurizer.decoder_config["beam_width"]
            )[0][0]
            return [decoded]

        self.recognize_pb = recognize_tflite if not beam else recognize_beam_tflite

    def get_config(self):
        if self.mel_layer is not None:
            config = self.mel_layer.get_config()
            config.update(self.encoder.get_config())
        else:
            config = self.encoder.get_config()
        config.update(self.fc.get_config())
        return config
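The beam variant differs only in passing greedy=False and a beam_width; a sketch with beam_width=4 as an assumed stand-in for text_featurizer.decoder_config["beam_width"]:

# Beam-search CTC decoding as in recognize_beam_tflite.
import tensorflow as tf

probs = tf.nn.softmax(tf.random.normal([1, 50, 30]))  # sizes assumed
decoded = tf.keras.backend.ctc_decode(y_pred=probs,
                                      input_length=tf.fill([1], 50),
                                      greedy=False,
                                      beam_width=4)[0][0]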
Example #8
class Transducer(tf.keras.Model):
    """ Transducer Model Warper """
    def __init__(self,
                 encoder: tf.keras.Model,
                 vocabulary_size: int,
                 embed_dim: int = 512,
                 embed_dropout: float = 0,
                 num_lstms: int = 1,
                 lstm_units: int = 320,
                 joint_dim: int = 1024,
                 name="transducer",
                 speech_config: dict = None,
                 **kwargs):
        super(Transducer, self).__init__(name=name, **kwargs)
        self.encoder = encoder
        self.predict_net = TransducerPrediction(
            vocabulary_size=vocabulary_size,
            embed_dim=embed_dim,
            embed_dropout=embed_dropout,
            num_lstms=num_lstms,
            lstm_units=lstm_units,
            name=f"{name}_prediction")
        self.joint_net = TransducerJoint(vocabulary_size=vocabulary_size,
                                         joint_dim=joint_dim,
                                         name=f"{name}_joint")
        self.speech_config = speech_config
        self.mel_layer = None
        if speech_config['use_mel_layer']:
            if speech_config['mel_layer_type'] == 'Melspectrogram':
                self.mel_layer = Melspectrogram(
                    sr=speech_config['sample_rate'],
                    n_mels=speech_config['num_feature_bins'],
                    n_hop=int(speech_config['stride_ms'] *
                              speech_config['sample_rate'] // 1000),
                    n_dft=1024,
                    trainable_fb=speech_config['trainable_kernel'])
            else:
                self.mel_layer = Spectrogram(
                    n_hop=int(speech_config['stride_ms'] *
                              speech_config['sample_rate'] // 1000),
                    n_dft=1024,
                    trainable_kernel=speech_config['trainable_kernel'])
        self.kept_hyps = None
        self.startid = 0
        self.endid = 1
        self.max_iter = 10

    def _build(self, sample_shape):  # Call on real data for building model
        features = tf.random.normal(shape=sample_shape)
        predicted = tf.constant([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
        return self([features, predicted], training=True)

    def save_seperate(self, path_to_dir: str):
        self.encoder.save(os.path.join(path_to_dir, "encoder"))
        self.predict_net.save(os.path.join(path_to_dir, "prediction"))
        self.joint_net.save(os.path.join(path_to_dir, "joint"))

    def summary(self, line_length=None, **kwargs):
        self.encoder.summary(line_length=line_length, **kwargs)
        self.predict_net.summary(line_length=line_length, **kwargs)
        self.joint_net.summary(line_length=line_length, **kwargs)
        super(Transducer, self).summary(line_length=line_length, **kwargs)

    # @tf.function(experimental_relax_shapes=True)
    def call(self, inputs, training=False):

        features, predicted = inputs
        if self.mel_layer is not None:
            features = self.mel_layer(features)
        enc = self.encoder(features, training=training)
        pred, _ = self.predict_net(predicted, training=training)
        outputs = self.joint_net([enc, pred], training=training)

        return outputs

    def add_featurizers(self, text_featurizer: TextFeaturizer):

        self.text_featurizer = text_featurizer

    def return_pb_function(self, shape):
        @tf.function(
            experimental_relax_shapes=True,
            input_signature=[
                tf.TensorSpec(shape, dtype=tf.float32),  # features
                tf.TensorSpec([None, 1], dtype=tf.int32),  # lengths
            ])
        def recognize_pb(features, length, training=False):
            decoded = self.perform_greedy(features)
            return [decoded]

        self.recognize_pb = recognize_pb

    @tf.function(experimental_relax_shapes=True)
    def perform_greedy(self, features):
        batch = tf.shape(features)[0]
        new_hyps = Hypotheses(
            tf.zeros([batch], tf.float32),
            self.text_featurizer.start * tf.ones([batch, 1], dtype=tf.int32),
            self.predict_net.get_initial_state(features))
        if self.mel_layer is not None:
            features = self.mel_layer(features)
        enc = self.encoder(features, training=False)  # [B, T, E]
        # enc = tf.squeeze(enc, axis=0)  # [T, E]
        stop_flag = tf.zeros([batch, 1], tf.float32)
        T = tf.cast(shape_list(enc)[1], dtype=tf.int32)

        i = tf.constant(0, dtype=tf.int32)

        def _cond(enc, i, new_hyps, T, stop_flag):
            return tf.less(i, T)

        def _body(enc, i, new_hyps, T, stop_flag):
            hi = enc[:, i:i + 1]  # [B, 1, E]
            y, n_memory_states = self.predict_net(
                inputs=new_hyps[1][:, -1:],  # [1, 1]
                p_memory_states=new_hyps[2],
                training=False)  # [1, 1, P], [1, P], [1, P]
            # [1, 1, E] + [1, 1, P] => [1, 1, 1, V]
            ytu = tf.nn.log_softmax(self.joint_net([hi, y], training=False))
            ytu = tf.squeeze(ytu, axis=None)  # [B, 1, 1, V] => [B,V]
            n_predict = tf.expand_dims(
                tf.argmax(ytu, axis=-1, output_type=tf.int32),
                -1)  # => argmax []

            new_hyps = Hypotheses(
                new_hyps[0] + 1,
                tf.concat(
                    [new_hyps[1], tf.reshape(n_predict, [-1, 1])], -1),
                n_memory_states)

            stop_flag += tf.cast(
                tf.equal(tf.reshape(n_predict, [-1, 1]),
                         self.text_featurizer.stop), tf.float32)
            n_i = tf.cond(
                tf.reduce_all(tf.cast(stop_flag, tf.bool)),
                true_fn=lambda: T,
                false_fn=lambda: i + 1,
            )

            return enc, n_i, new_hyps, T, stop_flag

        _, _, new_hyps, _, stop_flag = tf.while_loop(
            _cond,
            _body,
            loop_vars=(enc, i, new_hyps, T, stop_flag),
            shape_invariants=(
                tf.TensorShape([None, None, None]),
                tf.TensorShape([]),
                Hypotheses(
                    tf.TensorShape([None]), tf.TensorShape([None, None]),
                    tf.nest.map_structure(get_shape_invariants, new_hyps[-1])),
                tf.TensorShape([]),
                tf.TensorShape([None, 1]),
            ))

        return new_hyps[1]

    def recognize(self, features):
        decoded = self.perform_greedy(features)

        return decoded

    def get_config(self):
        if self.mel_layer is not None:
            conf = self.mel_layer.get_config()
            conf.update(self.encoder.get_config())
        else:
            conf = self.encoder.get_config()
        conf.update(self.predict_net.get_config())
        conf.update(self.joint_net.get_config())
        return conf
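This batched greedy loop keeps a per-utterance stop_flag and jumps the loop index to T once every sequence has emitted the stop token. The bookkeeping in isolation, with a stop-token id of 2 assumed:

# The stop_flag update from _body, isolated.
import tensorflow as tf

stop_flag = tf.zeros([3, 1], tf.float32)
n_predict = tf.constant([[2], [7], [2]])  # one predicted id per utterance
stop_flag += tf.cast(tf.equal(n_predict, 2), tf.float32)
all_done = tf.reduce_all(tf.cast(stop_flag, tf.bool))  # False: one still decoding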
Example #9
class Transducer(tf.keras.Model):
    """ Transducer Model Warper """

    def __init__(self,
                 encoder: tf.keras.Model,
                 vocabulary_size: int,
                 embed_dim: int = 512,
                 embed_dropout: float = 0,
                 num_lstms: int = 1,
                 lstm_units: int = 320,
                 joint_dim: int = 1024,
                 name="transducer", speech_config=dict,
                 **kwargs):
        super(Transducer, self).__init__(name=name, **kwargs)
        self.encoder = encoder
        self.num_lstms = num_lstms
        self.predict_net = TransducerPrediction(
            vocabulary_size=vocabulary_size,
            embed_dim=embed_dim,
            embed_dropout=embed_dropout,
            num_lstms=num_lstms,
            lstm_units=lstm_units,
            name=f"{name}_prediction"
        )
        self.joint_net = TransducerJoint(
            vocabulary_size=vocabulary_size,
            joint_dim=joint_dim,
            name=f"{name}_joint"
        )

        self.speech_config = speech_config
        self.mel_layer = None
        if speech_config['use_mel_layer']:
            if speech_config['mel_layer_type'] == 'Melspectrogram':
                self.mel_layer = Melspectrogram(sr=speech_config['sample_rate'],
                                                n_mels=speech_config['num_feature_bins'],
                                                n_hop=int(
                                                    speech_config['stride_ms'] * speech_config['sample_rate'] // 1000),
                                                n_dft=1024,
                                                trainable_fb=speech_config['trainable_kernel']
                                                )
            else:
                self.mel_layer = Spectrogram(
                    n_hop=int(speech_config['stride_ms'] * speech_config['sample_rate'] // 1000),
                    n_dft=1024,
                    trainable_kernel=speech_config['trainable_kernel']
                )
            self.mel_layer.trainable = speech_config['trainable_kernel']

        self.ctc_classes = tf.keras.layers.Dense(vocabulary_size, name='ctc_classes')

        self.wav_info = speech_config['add_wav_info']
        if self.wav_info:
            assert speech_config['use_mel_layer'], 'should set use_mel_layer to True'

        self.dmodel = encoder.dmodel

        self.chunk_size = int(self.speech_config['sample_rate'] * self.speech_config['streaming_bucket'])
        self.decode_layer = ConformerBlock(self.dmodel, self.encoder.dropout, self.encoder.fc_factor,
                                           self.encoder.head_size,
                                           self.encoder.num_heads, name='decode_conformer_block')
        self.recognize_pb = None
        self.encoder.add_chunk_size(
            self.chunk_size, speech_config['num_feature_bins'],
            int(speech_config['stride_ms'] * speech_config['sample_rate'] // 1000))
        self.streaming = self.speech_config['streaming']

    def _build(self, shape):  # Call on real data for building model

        batch = shape[0]
        inputs = np.random.normal(size=shape).astype(np.float32)

        targets = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9]] * batch)

        self([inputs, targets], training=True)

    def save_seperate(self, path_to_dir: str):
        self.encoder.save(os.path.join(path_to_dir, "encoder"))
        self.predict_net.save(os.path.join(path_to_dir, "prediction"))
        self.joint_net.save(os.path.join(path_to_dir, "joint"))

    def summary(self, line_length=None, **kwargs):
        self.encoder.summary(line_length=line_length, **kwargs)
        self.predict_net.summary(line_length=line_length, **kwargs)
        self.joint_net.summary(line_length=line_length, **kwargs)
        super(Transducer, self).summary(line_length=line_length, **kwargs)

    @tf.function(
        experimental_relax_shapes=True,
        input_signature=[
            tf.TensorSpec([None, None, 1], dtype=tf.float32),
        ]
    )
    def extract_feature(self, inputs):
        features = inputs
        if self.mel_layer is not None:
            if self.wav_info:
                wav = features
            features = self.mel_layer(features)

        if self.wav_info:
            enc = self.encoder.inference([features, wav], training=False)
        else:
            enc = self.encoder.inference(features, training=False)

        return enc

    def initial_states(self, inputs):
        decoder_states = self.predict_net.get_initial_state(inputs)
        return decoder_states, tf.constant([self.text_featurizer.start])

    def call(self, inputs, training=False):
        features, predicted = inputs

        if self.mel_layer is not None:
            if self.wav_info:
                wav = features
            features = self.mel_layer(features)

        if self.wav_info:
            enc = self.encoder([features, wav], training=training)
        else:
            enc = self.encoder(features, training=training)
        enc = self.decode_layer(enc, training=training)

        pred, _ = self.predict_net(predicted, training=training)
        outputs = self.joint_net([enc, pred], training=training)
        ctc_outputs = self.ctc_classes(enc, training=training)

        return outputs, ctc_outputs

    def eval_inference(self, inputs, training=False):
        features, predicted = inputs

        if self.mel_layer is not None:
            if self.wav_info:
                wav = features
            features = self.mel_layer(features)

        if self.wav_info:
            enc = self.encoder([features, wav], training=training)
        else:
            enc = self.encoder(features, training=training)
        enc = self.decode_layer(enc, training=training)
        b_i = tf.constant(0, dtype=tf.int32)
        self.initial_states(enc)
        B = tf.shape(enc)[0]

        decoded = tf.constant([], dtype=tf.int32)

        def _cond(b_i, B, features, decoded):
            return tf.less(b_i, B)

        def _body(b_i, B, features, decoded):
            states = self.predict_net.get_initial_state(
                tf.expand_dims(enc[b_i], axis=0))
            decode_ = tf.constant([0], dtype=tf.int32)
            yseq, _, _ = self.perform_greedy(
                tf.expand_dims(enc[b_i], axis=0), states, decode_,
                tf.constant(0, dtype=tf.int32))
            yseq = tf.concat(
                [yseq, tf.constant([self.text_featurizer.stop], tf.int32)],
                axis=-1)
            decoded = tf.concat([decoded, yseq], axis=0)
            return b_i + 1, B, features, decoded

        _, _, _, decoded = tf.while_loop(
            _cond,
            _body,
            loop_vars=(b_i, B, features, decoded),
            shape_invariants=(
                tf.TensorShape([]),
                tf.TensorShape([]),
                get_shape_invariants(features),
                tf.TensorShape([None])
            )
        )
        return decoded

    def add_featurizers(self,
                        text_featurizer: TextFeaturizer):

        self.text_featurizer = text_featurizer

    def return_pb_function(self, shape):
        @tf.function(experimental_relax_shapes=True
                     # input_signature=[
                     #     tf.TensorSpec(shape, dtype=tf.float32),  # features
                     #     tf.TensorSpec([None, 1], dtype=tf.int32),  # lengths
                     # ]
                     )
        def recognize_pb(features, decoded, states):
            features = tf.cast(features, tf.float32)
            yseq, states, _ = self.perform_greedy(
                features, states, decoded, tf.constant(0, dtype=tf.int32))
            return yseq, states

        self.recognize_pb = recognize_pb

    @tf.function(experimental_relax_shapes=True,
                 input_signature=[
                     tf.TensorSpec([None, None, 256], dtype=tf.float32),  # features
                     [[tf.TensorSpec([None, None], tf.float32),
                       tf.TensorSpec([None, None], tf.float32)]],
                     tf.TensorSpec([None], dtype=tf.int32),
                     tf.TensorSpec((), dtype=tf.int32),
                 ])
    def perform_greedy(self, features, states, decoded, start_B):

        enc = self.decode_layer(features, training=False)
        h = states
        enc = tf.squeeze(enc, axis=0)  # [T, E]
        T = tf.cast(tf.shape(enc)[0], dtype=tf.int32)
        i = start_B

        def _cond(enc, i, decoded, h_, T):
            return tf.less(i, T)

        def _body(enc, i, decoded, h_, T):

            hi = tf.reshape(enc[i], [1, 1, -1])  # [1, 1, E]
            y, h_2 = self.predict_net(
                inputs=tf.reshape(decoded[-1], [1, 1]),  # [1, 1]
                p_memory_states=h_,
                training=False
            )
            # [1, 1, P], [1, P], [1, P]
            # [1, 1, E] + [1, 1, P] => [1, 1, 1, V]
            ytu = tf.nn.log_softmax(self.joint_net([hi, y], training=False))
            ytu = tf.squeeze(ytu, axis=None)  # [1, 1, 1, V] => [V]
            n_predict = tf.argmax(ytu, axis=-1, output_type=tf.int32)  # => argmax []
            n_predict = tf.reshape(n_predict, [1])

            def return_no_blank():
                return [tf.concat([decoded, n_predict], axis=0), h_2]

            decoded, h_ = tf.cond(
                tf.logical_and(n_predict[0] != self.text_featurizer.blank,
                               n_predict[0] != 0),
                true_fn=return_no_blank,
                false_fn=lambda: [decoded, h_])

            return enc, i + 1, decoded, h_, T

        _, _, decoded, h, _ = tf.while_loop(
            _cond,
            _body,
            loop_vars=(enc, i, decoded, h, T),
            shape_invariants=(
                tf.TensorShape([None, None]),
                tf.TensorShape([]),
                tf.TensorShape([None]),
                [[tf.TensorShape([None, None]),
                  tf.TensorShape([None, None])]] * self.num_lstms,
                tf.TensorShape([])
            )
        )

        return decoded, h, T


    def get_config(self):
        if self.mel_layer is not None:
            conf = self.mel_layer.get_config()
            conf.update(self.encoder.get_config())
        else:
            conf = self.encoder.get_config()
        conf.update(self.decode_layer.get_config())
        conf.update(self.predict_net.get_config())
        conf.update(self.joint_net.get_config())
        conf.update(self.ctc_classes.get_config())
        return conf
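All of these decode loops rely on tf.while_loop with shape_invariants so that the decoded sequence can grow along its only axis. The pattern in miniature:

# Minimal tf.while_loop with a growing loop variable, mirroring the decoders
# above: TensorShape([None]) lets `decoded` change length each iteration.
import tensorflow as tf

i = tf.constant(0)
decoded = tf.constant([], dtype=tf.int32)
_, decoded = tf.while_loop(
    lambda i, d: tf.less(i, 5),
    lambda i, d: (i + 1, tf.concat([d, [i]], axis=0)),
    loop_vars=(i, decoded),
    shape_invariants=(tf.TensorShape([]), tf.TensorShape([None])))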
Example #10
    def __init__(self,
                 encoder: tf.keras.Model,
                 vocabulary_size: int,
                 embed_dim: int = 512,
                 embed_dropout: float = 0,
                 num_lstms: int = 1,
                 lstm_units: int = 320,
                 joint_dim: int = 1024,
                 name="transducer", speech_config=dict,
                 **kwargs):
        super(Transducer, self).__init__(name=name, **kwargs)
        self.encoder = encoder
        self.num_lstms = num_lstms
        self.predict_net = TransducerPrediction(
            vocabulary_size=vocabulary_size,
            embed_dim=embed_dim,
            embed_dropout=embed_dropout,
            num_lstms=num_lstms,
            lstm_units=lstm_units,
            name=f"{name}_prediction"
        )
        self.joint_net = TransducerJoint(
            vocabulary_size=vocabulary_size,
            joint_dim=joint_dim,
            name=f"{name}_joint"
        )

        self.speech_config = speech_config
        self.mel_layer = None
        if speech_config['use_mel_layer']:
            if speech_config['mel_layer_type'] == 'Melspectrogram':
                self.mel_layer = Melspectrogram(sr=speech_config['sample_rate'],
                                                n_mels=speech_config['num_feature_bins'],
                                                n_hop=int(
                                                    speech_config['stride_ms'] * speech_config['sample_rate'] // 1000),
                                                n_dft=1024,
                                                trainable_fb=speech_config['trainable_kernel']
                                                )
            else:
                self.mel_layer = Spectrogram(
                    n_hop=int(speech_config['stride_ms'] * speech_config['sample_rate'] // 1000),
                    n_dft=1024,
                    trainable_kernel=speech_config['trainable_kernel']
                )
            self.mel_layer.trainable = speech_config['trainable_kernel']

        self.ctc_classes = tf.keras.layers.Dense(vocabulary_size, name='ctc_classes')

        self.wav_info = speech_config['add_wav_info']
        if self.wav_info:
            assert speech_config['use_mel_layer'], 'should set use_mel_layer to True'

        self.dmodel = encoder.dmodel

        self.chunk_size = int(self.speech_config['sample_rate'] * self.speech_config['streaming_bucket'])
        self.decode_layer = ConformerBlock(self.dmodel, self.encoder.dropout, self.encoder.fc_factor,
                                           self.encoder.head_size,
                                           self.encoder.num_heads, name='decode_conformer_block')
        self.recognize_pb = None
        self.encoder.add_chunk_size(
            self.chunk_size, speech_config['num_feature_bins'],
            int(speech_config['stride_ms'] * speech_config['sample_rate'] // 1000))
        self.streaming = self.speech_config['streaming']
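The streaming setup converts a bucket length in seconds to samples and a stride in milliseconds to a hop. The arithmetic in isolation, with all config values assumed:

# Streaming sizes as computed in __init__ above.
sample_rate = 16000     # assumed
streaming_bucket = 0.5  # seconds per streaming chunk, assumed
stride_ms = 10          # assumed
chunk_size = int(sample_rate * streaming_bucket)  # 8000 samples per chunk
hop = int(stride_ms * sample_rate // 1000)        # 160 samples per feature frame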
Example #11
class MultiTaskCTC(tf.keras.Model):
    def __init__(self,
                 encoder1,
                 encoder2,
                 encoder3,
                 classes1,
                 classes2,
                 classes3,
                 dmodel,
                 speech_config: dict = None,
                 **kwargs):
        super().__init__(**kwargs)
        self.encoder1 = encoder1
        self.encoder2 = encoder2
        self.encoder3 = encoder3
        self.speech_config = speech_config
        self.mel_layer = None
        if speech_config['use_mel_layer']:
            if speech_config['mel_layer_type'] == 'Melspectrogram':
                self.mel_layer = Melspectrogram(
                    sr=speech_config['sample_rate'],
                    n_mels=speech_config['num_feature_bins'],
                    n_hop=int(speech_config['stride_ms'] *
                              speech_config['sample_rate'] // 1000),
                    n_dft=1024,
                    trainable_fb=speech_config['trainable_kernel'])
            else:
                self.mel_layer = Spectrogram(
                    n_hop=int(speech_config['stride_ms'] *
                              speech_config['sample_rate'] // 1000),
                    n_dft=1024,
                    trainable_kernel=speech_config['trainable_kernel'])
            self.mel_layer.trainable = speech_config['trainable_kernel']
        self.wav_info = speech_config['add_wav_info']
        if self.wav_info:
            assert speech_config['use_mel_layer'], 'should set use_mel_layer to True'
        self.fc1 = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(
            units=classes1, activation="linear", use_bias=True),
                                                   name="fully_connected1")

        self.fc2 = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(
            units=classes2, activation="linear", use_bias=True),
                                                   name="fully_connected2")

        self.fc3 = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(
            units=classes3, activation="linear", use_bias=True),
                                                   name="fully_connected3")

        self.fc_to_project_1 = tf.keras.layers.Dense(
            dmodel, name='word_prob_projector')
        self.fc_to_project_2 = tf.keras.layers.Dense(
            dmodel, name='phone_prob_projector')
        self.fc_to_project_3 = tf.keras.layers.Dense(dmodel,
                                                     name='py_prob_projector')
        self.fc_final_class = tf.keras.layers.Conv1D(classes3,
                                                     32,
                                                     padding='same',
                                                     name="cnn_final_class")

    def setup_window(self, win_front, win_back):
        """Call only for inference."""
        self.use_window_mask = True
        self.win_front = win_front
        self.win_back = win_back

    def setup_maximum_iterations(self, maximum_iterations):
        """Call only for inference."""
        self.maximum_iterations = maximum_iterations

    def _build(self, sample_shape):
        features = tf.random.normal(shape=sample_shape)
        self(features, training=False)

    def summary(self, line_length=None, **kwargs):
        # self.encoder.summary(line_length=line_length, **kwargs)
        super(MultiTaskCTC, self).summary(line_length, **kwargs)

    def add_featurizers(
        self,
        text_featurizer,
    ):

        self.text_featurizer = text_featurizer

    # @tf.function(experimental_relax_shapes=True)
    def encoder1_enc(self, inputs, training=False):
        if self.mel_layer is not None:
            if self.wav_info:
                wav = inputs
            inputs = self.mel_layer(inputs)
        if self.wav_info:
            enc_outputs = self.encoder1([inputs, wav], training=training)
        else:
            enc_outputs = self.encoder1(inputs, training=training)
        outputs = self.fc1(enc_outputs[-1], training=training)
        for i in range(12, 15):
            outputs += self.fc1(enc_outputs[i], training=training)
        return enc_outputs[-1], outputs

    def encoder2_enc(self, inputs, enc1, training=False):
        if self.mel_layer is not None:
            if self.wav_info:
                wav = inputs
            inputs = self.mel_layer(inputs)
        if self.wav_info:
            enc_outputs = self.encoder2([inputs, wav, enc1], training=training)
        else:
            enc_outputs = self.encoder2([inputs, enc1], training=training)
        outputs = self.fc2(enc_outputs[-1], training=training)
        for i in range(12, 15):
            outputs += self.fc2(enc_outputs[i], training=training)
        return enc_outputs[-1], outputs

    def encoder3_enc(self, inputs, enc2, training=False):
        if self.mel_layer is not None:
            if self.wav_info:
                wav = inputs
            inputs = self.mel_layer(inputs)
        if self.wav_info:
            enc_outputs = self.encoder3([inputs, wav, enc2], training=training)
        else:
            enc_outputs = self.encoder3([inputs, enc2], training=training)
        outputs = self.fc3(enc_outputs[-1], training=training)
        for i in range(12, 15):
            outputs += self.fc3(enc_outputs[i], training=training)
        return enc_outputs[-1], outputs

    def call(self, inputs, training=False, **kwargs):
        enc1, outputs1 = self.encoder1_enc(inputs, training)
        enc2, outputs2 = self.encoder2_enc(inputs, enc1, training)
        enc3, outputs3 = self.encoder3_enc(inputs, enc2, training)
        outputs1_ = self.fc_to_project_1(outputs1)
        outputs2_ = self.fc_to_project_2(outputs2)
        outputs3_ = self.fc_to_project_3(outputs3)
        output = outputs1_ + outputs2_ + outputs3_ + enc1 + enc2 + enc3
        outputs = self.fc_final_class(output)
        outputs += outputs3
        return outputs1, outputs2, outputs3, outputs

    def return_pb_function(self, shape, beam=False):
        @tf.function(experimental_relax_shapes=True,
                     input_signature=[
                         tf.TensorSpec(shape, dtype=tf.float32),
                         tf.TensorSpec([None, 1], dtype=tf.int32),
                     ])
        def recognize_tflite(features, length):
            logits = self.call(features, training=False)[-1]

            probs = tf.nn.softmax(logits)
            decoded = tf.keras.backend.ctc_decode(y_pred=probs,
                                                  input_length=tf.squeeze(
                                                      length, -1),
                                                  greedy=True)[0][0]
            return [decoded]

        @tf.function(experimental_relax_shapes=True,
                     input_signature=[
                         tf.TensorSpec(shape, dtype=tf.float32),
                         tf.TensorSpec([None, 1], dtype=tf.int32),
                     ])
        def recognize_beam_tflite(features, length):
            logits = self.call(features, training=False)[-1]

            probs = tf.nn.softmax(logits)
            decoded = tf.keras.backend.ctc_decode(
                y_pred=probs,
                input_length=tf.squeeze(length, -1),
                greedy=False,
                beam_width=self.text_featurizer.decoder_config["beam_width"]
            )[0][0]
            return [decoded]

        self.recognize_pb = recognize_tflite if not beam else recognize_beam_tflite

    def get_config(self):
        if self.mel_layer is not None:
            config = self.mel_layer.get_config()
            config.update(self.encoder1.get_config())
        else:
            config = self.encoder1.get_config()
        config.update(self.encoder2.get_config())
        config.update(self.encoder3.get_config())
        config.update(self.fc1.get_config())
        config.update(self.fc2.get_config())
        config.update(self.fc3.get_config())
        return config
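Each CTC head here wraps a Dense projection in TimeDistributed so the same weights are applied at every frame. One head in isolation, with class count and dmodel assumed (for Dense, TimeDistributed is technically redundant since Dense already maps the last axis, but it makes the per-frame intent explicit):

# One fully_connected head, isolated.
import tensorflow as tf

head = tf.keras.layers.TimeDistributed(
    tf.keras.layers.Dense(units=30, activation="linear", use_bias=True),
    name="fully_connected1")
logits = head(tf.random.normal([2, 50, 144]))  # -> [batch, time, 30]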
Example #12
    def __init__(self,
                 encoder,
                 classes1,
                 classes2,
                 classes3,
                 config,
                 training,
                 enable_tflite_convertible=False,
                 speech_config: dict = None,
                 **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.speech_config = speech_config
        self.mel_layer = None
        if speech_config['use_mel_layer']:
            if speech_config['mel_layer_type'] == 'Melspectrogram':
                self.mel_layer = Melspectrogram(
                    sr=speech_config['sample_rate'],
                    n_mels=speech_config['num_feature_bins'],
                    n_hop=int(speech_config['stride_ms'] *
                              speech_config['sample_rate'] // 1000),
                    n_dft=1024,
                    trainable_fb=speech_config['trainable_kernel'])
            else:
                self.mel_layer = Spectrogram(
                    n_hop=int(speech_config['stride_ms'] *
                              speech_config['sample_rate'] // 1000),
                    n_dft=1024,
                    trainable_kernel=speech_config['trainable_kernel'])
        self.fc1 = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(
            units=classes1, activation="linear", use_bias=True),
                                                   name="fully_connected1")

        self.fc2 = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(
            units=classes2, activation="linear", use_bias=True),
                                                   name="fully_connected2")

        self.fc3 = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(
            units=classes3, activation="linear", use_bias=True),
                                                   name="fully_connected3")
        self.fc_final = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Dense(units=config.n_classes,
                                  activation="linear",
                                  use_bias=True),
            name="fully_connected4")
        self.decoder_cell = DecoderCell(
            config,
            training=training,
            name="decoder_cell",
            enable_tflite_convertible=enable_tflite_convertible)
        self.decoder = LASDecoder(
            self.decoder_cell,
            TrainingSampler(config)
            if training is True else TestingSampler(config),
            enable_tflite_convertible=enable_tflite_convertible)
        self.decoder_project = tf.keras.layers.Dense(config.decoder_lstm_units)
        self.token_project = tf.keras.Sequential([
            ConformerBlock(config.decoder_lstm_units,
                           dropout=config.dropout,
                           fc_factor=config.fc_factor,
                           head_size=config.head_size,
                           num_heads=config.num_heads,
                           kernel_size=config.kernel_size,
                           name='block%d' % i)
            for i in range(config.n_lstm_decoder + 1)
        ])
        self.config = config
        self.use_window_mask = False
        self.maximum_iterations = 1000 if training else 50
        self.enable_tflite_convertible = enable_tflite_convertible
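token_project builds its stack with a list comprehension inside tf.keras.Sequential, one block per i in range(config.n_lstm_decoder + 1). The same construction with stock Dense layers standing in for the project-specific ConformerBlock:

# Sequential built from a list comprehension, as token_project does.
import tensorflow as tf

n_blocks = 3  # assumed stand-in for config.n_lstm_decoder + 1
stack = tf.keras.Sequential(
    [tf.keras.layers.Dense(320, name='block%d' % i) for i in range(n_blocks)])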