Example #1
0
File: model.py — Project: yyht/TensorflowASR
 def conformer_model(self, training):
     """Build a Conformer acoustic model (Transducer / CTC / LAS variant).

     Args:
         training: bool flag forwarded to the LAS decoder builder.

     Raises:
         ValueError: if ``model_config['name']`` is not a supported
             Conformer variant.
     """
     from AMmodel.conformer import ConformerTransducer, ConformerCTC, ConformerLAS
     # Output vocabulary size tracks the text featurizer.
     self.model_config.update(
         {'vocabulary_size': self.text_feature.num_classes})
     if self.model_config['name'] == 'ConformerTransducer':
         # The Transducer variant takes no LAS decoder; drop LAS-only keys
         # before expanding the config as keyword arguments.
         self.model_config.pop('LAS_decoder')
         self.model_config.pop('enable_tflite_convertible')
         self.model_config.update({'speech_config': self.speech_config})
         self.model = ConformerTransducer(**self.model_config)
     elif self.model_config['name'] == 'ConformerCTC':
         self.model_config.update({'speech_config': self.speech_config})
         self.model = ConformerCTC(**self.model_config)
     elif self.model_config['name'] == 'ConformerLAS':
         self.config['model_config']['LAS_decoder'].update(
             {'n_classes': self.text_feature.num_classes})
         self.config['model_config']['LAS_decoder'].update(
             {'startid': self.text_feature.start})
         self.model = ConformerLAS(
             self.config['model_config'],
             training=training,
             enable_tflite_convertible=self.config['model_config']
             ['enable_tflite_convertible'],
             speech_config=self.speech_config)
     else:
         # Was `raise ('...')`: raising a plain string is itself a
         # TypeError. Raise a proper exception instead.
         raise ValueError('not in supported model list')
Example #2
0
 def espnet_model(self, training):
     """Build an ESPNet acoustic model (Transducer / CTC / LAS variant).

     Args:
         training: bool flag forwarded to the LAS model builder.

     Raises:
         ValueError: if ``model_config['name']`` is not a supported
             ESPNet variant.
     """
     from AMmodel.espnet import ESPNetCTC, ESPNetLAS, ESPNetTransducer
     # The transducer decoder's vocabulary tracks the text featurizer.
     self.config['Transducer_decoder'].update(
         {'vocabulary_size': self.text_feature.num_classes})
     if self.model_config['name'] == 'ESPNetTransducer':
         self.model = ESPNetTransducer(self.config)
     elif self.model_config['name'] == 'ESPNetCTC':
         self.model = ESPNetCTC(self.model_config,
                                self.text_feature.num_classes)
     elif self.model_config['name'] == 'ESPNetLAS':
         self.config['LAS_decoder'].update(
             {'n_classes': self.text_feature.num_classes})
         self.config['LAS_decoder'].update(
             {'startid': self.text_feature.start})
         self.model = ESPNetLAS(self.config,
                                training=training,
                                enable_tflite_convertible=self.
                                config['enable_tflite_convertible'])
     else:
         # Was `raise ('...')`: raising a plain string is itself a
         # TypeError. Raise a proper exception instead.
         raise ValueError('not in supported model list')
Example #3
0
 def tdnn_model(self, training):
     """Instantiate the TDNN-Transducer acoustic model from the config.

     Raises:
         ValueError: if the configured model name is not
             'TimeDelayNNTransducer'.
     """
     from AMmodel.tdnn import TimeDelayNNTransducer
     cfg = self.model_config
     cfg.update({'vocabulary_size': self.text_feature.num_classes})
     # Guard clause: only the transducer variant is supported here.
     if cfg['name'] != 'TimeDelayNNTransducer':
         raise ValueError('not in supported model list')
     # Transducer models take no LAS decoder; strip LAS-only keys before
     # expanding the config as keyword arguments.
     cfg.pop('LAS_decoder')
     cfg.pop('enable_tflite_convertible')
     cfg.update({'speech_config': self.speech_config})
     self.model = TimeDelayNNTransducer(**cfg)
Example #4
0
    def multi_task_model(self, training):
        """Build the multi-task Conformer CTC model with three decoder heads.

        Args:
            training: bool flag forwarded to the model constructor.
        """
        from AMmodel.MultiConformer import ConformerMultiTaskCTC

        # One text featurizer per decoder branch; each branch's head size
        # is that featurizer's vocabulary size.
        head_sizes = {}
        for idx in (1, 2, 3):
            featurizer = TextFeaturizer(self.config['decoder%d_config' % idx])
            head_sizes['classes%d' % idx] = featurizer.num_classes
        self.model_config.update(head_sizes)

        self.model = ConformerMultiTaskCTC(self.model_config,
                                           training=training,
                                           speech_config=self.speech_config)
Example #5
0
    def multi_task_model(self, training):
        """Build the multi-task Conformer LAS model (3 CTC heads + one LAS head).

        Args:
            training: bool flag forwarded to the model constructor.
        """
        from AMmodel.MultiConformer import ConformerMultiTaskLAS

        # One featurizer per decoder branch, in branch order 1..4.
        featurizers = [
            TextFeaturizer(self.config['decoder%d_config' % idx])
            for idx in (1, 2, 3, 4)
        ]

        # Branches 1-3 are CTC heads sized by their vocabularies.
        self.model_config.update({
            'classes1': featurizers[0].num_classes,
            'classes2': featurizers[1].num_classes,
            'classes3': featurizers[2].num_classes,
        })
        # Branch 4 drives the LAS decoder.
        self.model_config['LAS_decoder'].update({
            'n_classes': featurizers[3].num_classes,
            'startid': featurizers[3].start,
        })
        self.model = ConformerMultiTaskLAS(
            self.model_config,
            training=training,
            enable_tflite_convertible=self.model_config[
                'enable_tflite_convertible'])
Example #6
0
class AM():
    """Factory and runtime wrapper for the acoustic model (AM).

    Reads the user config, decides the model family (CTC / LAS /
    Transducer / multi-task), builds the matching model, restores the
    newest checkpoint, and exposes helpers for SavedModel export
    (``convert_to_pb``) and inference (``predict`` / ``decode_result``).
    """
    def __init__(self, config):
        self.config = config
        self.update_model_type()
        self.speech_config = self.config['speech_config']
        # Single-decoder setups define 'decoder_config'; multi-task setups
        # fall back to decoder4 (presumably the LAS text head —
        # TODO(review): confirm against the config schema).
        try:
            self.text_config = self.config['decoder_config']
        except KeyError:  # was a bare except; only a missing key is expected
            self.text_config = self.config['decoder4_config']
        self.model_config = self.config['model_config']
        self.text_feature = TextFeaturizer(self.text_config)
        self.speech_feature = SpeechFeaturizer(self.speech_config)

        # Training step recovered from the checkpoint filename; stays
        # None until load_checkpoint() runs.
        self.init_steps = None

    def update_model_type(self):
        """Set ``self.model_type`` from the model name and stamp the
        matching 'model_type' into the decoder config(s) in place."""
        name = self.config['model_config']['name']
        if 'CTC' in name:
            self.config['decoder_config'].update({'model_type': 'CTC'})
            self.model_type = 'CTC'
        elif 'Multi' in name:
            self.config['decoder1_config'].update({'model_type': 'CTC'})
            self.config['decoder2_config'].update({'model_type': 'CTC'})
            self.config['decoder3_config'].update({'model_type': 'CTC'})
            self.config['decoder4_config'].update({'model_type': 'LAS'})
            self.config['decoder_config'].update({'model_type': 'LAS'})
            self.model_type = 'MultiTask'
        elif 'LAS' in name:
            self.config['decoder_config'].update({'model_type': 'LAS'})
            self.model_type = 'LAS'
        else:
            self.config['decoder_config'].update({'model_type': 'Transducer'})
            self.model_type = 'Transducer'

    def conformer_model(self, training):
        """Build a Conformer acoustic model (Transducer / CTC / LAS).

        Raises:
            ValueError: if the configured name is not a Conformer variant.
        """
        from AMmodel.conformer import ConformerTransducer, ConformerCTC, ConformerLAS
        self.model_config.update(
            {'vocabulary_size': self.text_feature.num_classes})
        if self.model_config['name'] == 'ConformerTransducer':
            # Transducer takes no LAS decoder; drop LAS-only keys before
            # expanding the config as keyword arguments.
            self.model_config.pop('LAS_decoder')
            self.model_config.pop('enable_tflite_convertible')
            self.model = ConformerTransducer(**self.model_config)
        elif self.model_config['name'] == 'ConformerCTC':
            self.model = ConformerCTC(**self.model_config)
        elif self.model_config['name'] == 'ConformerLAS':
            self.config['model_config']['LAS_decoder'].update(
                {'n_classes': self.text_feature.num_classes})
            self.config['model_config']['LAS_decoder'].update(
                {'startid': self.text_feature.start})
            self.model = ConformerLAS(
                self.config['model_config'],
                training=training,
                enable_tflite_convertible=self.config['model_config']
                ['enable_tflite_convertible'])
        else:
            # Was `raise ('...')`: raising a plain string is itself a
            # TypeError. Raise a proper exception instead.
            raise ValueError('not in supported model list')

    def ds2_model(self, training):
        """Build a DeepSpeech2 acoustic model (Transducer / CTC / LAS).

        Raises:
            ValueError: if the configured name is not a DeepSpeech2 variant.
        """
        from AMmodel.deepspeech2 import DeepSpeech2CTC, DeepSpeech2LAS, DeepSpeech2Transducer

        # DS2 needs an explicit input shape: (time, features, channels).
        f, c = self.speech_feature.compute_feature_dim()
        input_shape = [None, f, c]
        self.model_config.update({'input_shape': input_shape})
        if self.model_config['name'] == 'DeepSpeech2Transducer':
            # Transducer takes no LAS decoder; drop LAS-only keys.
            self.model_config.pop('LAS_decoder')
            self.model_config.pop('enable_tflite_convertible')
            self.model = DeepSpeech2Transducer(input_shape, self.model_config)
        elif self.model_config['name'] == 'DeepSpeech2CTC':
            self.model = DeepSpeech2CTC(input_shape, self.model_config,
                                        self.text_feature.num_classes)
        elif self.model_config['name'] == 'DeepSpeech2LAS':
            self.model_config['LAS_decoder'].update(
                {'n_classes': self.text_feature.num_classes})
            self.model_config['LAS_decoder'].update(
                {'startid': self.text_feature.start})
            self.model = DeepSpeech2LAS(
                self.model_config,
                input_shape,
                training=training,
                enable_tflite_convertible=self.
                model_config['enable_tflite_convertible'])
        else:
            # Was `raise ('...')`: raising a plain string is itself a
            # TypeError. Raise a proper exception instead.
            raise ValueError('not in supported model list')

    def multi_task_model(self, training):
        """Build the multi-task Conformer LAS model (3 CTC heads + LAS)."""
        from AMmodel.MultiConformer import ConformerMultiTaskLAS
        token1_feature = TextFeaturizer(self.config['decoder1_config'])
        token2_feature = TextFeaturizer(self.config['decoder2_config'])
        token3_feature = TextFeaturizer(self.config['decoder3_config'])
        token4_feature = TextFeaturizer(self.config['decoder4_config'])

        # Branches 1-3 are CTC heads sized by their vocabularies.
        self.model_config.update({
            'classes1': token1_feature.num_classes,
            'classes2': token2_feature.num_classes,
            'classes3': token3_feature.num_classes,
        })
        # Branch 4 drives the LAS decoder.
        self.model_config['LAS_decoder'].update(
            {'n_classes': token4_feature.num_classes})
        self.model_config['LAS_decoder'].update(
            {'startid': token4_feature.start})
        self.model = ConformerMultiTaskLAS(
            self.model_config,
            training=training,
            enable_tflite_convertible=self.
            model_config['enable_tflite_convertible'])

    def espnet_model(self, training):
        """Build an ESPNet acoustic model (Transducer / CTC / LAS).

        Raises:
            ValueError: if the configured name is not an ESPNet variant.
        """
        from AMmodel.espnet import ESPNetCTC, ESPNetLAS, ESPNetTransducer
        self.config['Transducer_decoder'].update(
            {'vocabulary_size': self.text_feature.num_classes})
        if self.model_config['name'] == 'ESPNetTransducer':
            self.model = ESPNetTransducer(self.config)
        elif self.model_config['name'] == 'ESPNetCTC':
            self.model = ESPNetCTC(self.model_config,
                                   self.text_feature.num_classes)
        elif self.model_config['name'] == 'ESPNetLAS':
            self.config['LAS_decoder'].update(
                {'n_classes': self.text_feature.num_classes})
            self.config['LAS_decoder'].update(
                {'startid': self.text_feature.start})
            self.model = ESPNetLAS(self.config,
                                   training=training,
                                   enable_tflite_convertible=self.
                                   config['enable_tflite_convertible'])
        else:
            # Was `raise ('...')`: raising a plain string is itself a
            # TypeError. Raise a proper exception instead.
            raise ValueError('not in supported model list')

    def load_model(self, training=True):
        """Build the configured model; for inference, also trace the
        graphs for a few batch sizes and restore the newest checkpoint."""
        if 'ESPNet' in self.model_config['name']:
            self.espnet_model(training)
        elif 'Multi' in self.model_config['name']:
            self.multi_task_model(training)
        elif 'Conformer' in self.model_config['name']:
            self.conformer_model(training)
        else:
            self.ds2_model(training)
        self.model.add_featurizers(self.text_feature)
        f, c = self.speech_feature.compute_feature_dim()

        try:
            if not training:
                # Build with several batch sizes so inference graphs are
                # traced before weights are restored.
                if self.text_config['model_type'] != 'LAS':
                    self.model._build([3, 80, f, c])
                    self.model._build([2, 80, f, c])
                    self.model._build([1, 80, f, c])
                    self.model.return_pb_function(f, c)
                else:
                    self.model._build([3, 80, f, c], training)
                    self.model._build([1, 80, f, c], training)
                    self.model._build([2, 80, f, c], training)
                    self.model.return_pb_function(f, c)
                self.load_checkpoint(self.config)
        except Exception as e:
            # Best-effort restore (was a bare except): report the cause
            # instead of swallowing it silently.
            print('am loading model failed.', e)

    def convert_to_pb(self, export_path):
        """Export the model as a TensorFlow SavedModel at *export_path*."""
        import tensorflow as tf
        f, c = self.speech_feature.compute_feature_dim()
        self.model.return_pb_function(f, c)

        concrete_func = self.model.recognize_pb.get_concrete_function()
        tf.saved_model.save(self.model, export_path, signatures=concrete_func)

    def decode_result(self, word):
        """Map predicted token ids to tokens, stopping at the stop id.

        Args:
            word: iterable of (possibly float) token ids.

        Returns:
            list of token strings up to (excluding) the stop token.
        """
        de = []
        for i in word:
            if i != self.text_feature.stop:
                de.append(self.text_feature.index_to_token[int(i)])
            else:
                break
        return de

    def predict(self, fp):
        """Run recognition on an audio file.

        Args:
            fp: path to a '.pcm' (raw int16) or wav file.

        Returns:
            The model's first recognition result tensor.
        """
        if '.pcm' in fp:
            # Raw 16-bit PCM: normalize to [-1, 1) floats.
            data = np.fromfile(fp, 'int16')
            data = np.array(data, 'float32')
            data /= 32768
        else:
            data = self.speech_feature.load_wav(fp)

        mel = self.speech_feature.extract(data)
        mel = np.expand_dims(mel, 0)
        # Frame count after the model's internal time reduction.
        input_length = np.array(
            [[mel.shape[1] // self.model.time_reduction_factor]], 'int32')
        result = self.model.recognize_pb(mel, input_length)[0]
        return result

    def load_checkpoint(self, config):
        """Load the newest checkpoint and record its training step."""
        self.checkpoint_dir = os.path.join(
            config['learning_config']['running_config']["outdir"],
            "checkpoints")
        files = os.listdir(self.checkpoint_dir)
        # Filenames end in '_<step>.h5'; sort numerically by step and
        # restore the latest one.
        files.sort(key=lambda x: int(x.split('_')[-1].replace('.h5', '')))
        self.model.load_weights(os.path.join(self.checkpoint_dir, files[-1]))
        self.init_steps = int(files[-1].split('_')[-1].replace('.h5', ''))