Example #1
class AM_DataLoader():

    def __init__(self, config_dict,training=True):
        self.speech_config = config_dict['speech_config']


        self.text_config = config_dict['decoder_config']
        self.augment_config = config_dict['augments_config']

        self.batch = config_dict['learning_config']['running_config']['batch_size']
        self.speech_featurizer = SpeechFeaturizer(self.speech_config)
        self.text_featurizer = TextFeaturizer(self.text_config)
        self.make_file_list(self.speech_config['train_list'] if training else self.speech_config['eval_list'],training)
        self.augment = Augmentation(self.augment_config)
        self.init_text_to_vocab()
        self.epochs = 1
        self.LAS=False
        self.steps = 0
    def load_state(self,outdir):
        try:
            self.pick_index=np.load(os.path.join(outdir,'dg_state.npy')).flatten().tolist()
            self.epochs=1+int(np.mean(self.pick_index))
        except FileNotFoundError:
            print('state file not found')
        except Exception:
            print('failed to load state, using init state')
    def save_state(self,outdir):
        np.save(os.path.join(outdir,'dg_state.npy'),np.array(self.pick_index))

    def return_data_types(self):
        if self.LAS:
            return (tf.float32, tf.float32, tf.int32, tf.int32, tf.int32,tf.float32)
        else:
            return  (tf.float32, tf.int32, tf.int32, tf.int32)
    def return_data_shape(self):
        f,c=self.speech_featurizer.compute_feature_dim()
        if self.LAS:
            return (
                tf.TensorShape([None,None,1]) if self.speech_config['use_mel_layer'] else  tf.TensorShape([None,None,f,c]),

                tf.TensorShape([None,]),
                tf.TensorShape([None,None]),
                tf.TensorShape([None,]),
                tf.TensorShape([None,None,None])
            )
        else:
            return (
                tf.TensorShape([None, None, 1]) if self.speech_config['use_mel_layer'] else tf.TensorShape(
                    [None, None, f, c]),

                tf.TensorShape([None, ]),
                tf.TensorShape([None, None]),
                tf.TensorShape([None, ])
            )
    def get_per_epoch_steps(self):
        return len(self.train_list)//self.batch
    def eval_per_epoch_steps(self):
        return len(self.test_list)//self.batch
    def init_text_to_vocab(self):
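        # Registers phrase-level pinyin overrides for characters whose
        # default reading would be wrong in these contexts (e.g. 传 read
        # zhuàn rather than chuán), then maps text to a flat pinyin list.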
        pypinyin.load_phrases_dict({'调大': [['tiáo'], ['dà']],
                                    '调小': [['tiáo'], ['xiǎo']],
                                    '调亮': [['tiáo'], ['liàng']],
                                    '调暗': [['tiáo'], ['àn']],
                                    '肖': [['xiāo']],
                                    '英雄传': [['yīng'], ['xióng'], ['zhuàn']],
                                    '新传': [['xīn'], ['zhuàn']],
                                    '外传': [['wài'], ['zhuàn']],
                                    '正传': [['zhèng'], ['zhuàn']], '水浒传': [['shuǐ'], ['hǔ'], ['zhuàn']]
                                    })

        def text_to_vocab_func(txt):
            pins=pypinyin.pinyin(txt)
            pins=[i[0] for i in pins]
            return pins

        self.text_to_vocab = text_to_vocab_func

    def augment_data(self, wavs, label, label_length):
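        # Re-extracts features from augmented waveforms, dropping samples
        # whose reduced feature length is shorter than their label, then
        # pads mels to the batch max with each sample's own min value.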
        if not self.augment.available():
            return None
        mels = []
        input_length = []
        label_ = []
        label_length_ = []
        wavs_ = []
        max_input = 0
        max_wav = 0
        for idx, wav in enumerate(wavs):

            data = self.augment.process(wav.flatten())
            speech_feature = self.speech_featurizer.extract(data)
            if speech_feature.shape[0] // self.speech_config['reduction_factor'] < label_length[idx]:
                continue
            max_input = max(max_input, speech_feature.shape[0])

            max_wav = max(max_wav, len(data))

            wavs_.append(data)

            mels.append(speech_feature)
            input_length.append(speech_feature.shape[0] // self.speech_config['reduction_factor'])
            label_.append(label[idx])
            label_length_.append(label_length[idx])

        for i in range(len(mels)):
            if mels[i].shape[0] < max_input:
                pad = np.ones([max_input - mels[i].shape[0], mels[i].shape[1],mels[i].shape[2]]) * mels[i].min()
                mels[i] = np.vstack((mels[i], pad))

        wavs_ = self.speech_featurizer.pad_signal(wavs_, max_wav)

        x = np.array(mels, 'float32')
        label_ = np.array(label_, 'int32')

        input_length = np.array(input_length, 'int32')
        label_length_ = np.array(label_length_, 'int32')

        wavs_ = np.array(np.expand_dims(wavs_, -1), 'float32')

        return x, wavs_, input_length, label_, label_length_

    def make_file_list(self, wav_list,training=True):
        with open(wav_list, encoding='utf-8') as f:
            data = f.readlines()
        data=[i.strip() for i in data if i.strip()!='']
        num = len(data)
        if training:
            self.train_list = data[:int(num * 0.99)]
            self.test_list = data[int(num * 0.99):]
            np.random.shuffle(self.train_list)
            self.pick_index = [0.] * len(self.train_list)
        else:
            self.test_list=data
            self.offset=0
    def only_chinese(self, word):
        txt=''
        for ch in word:
            if '\u4e00' <= ch <= '\u9fff':
                txt+=ch
            else:
                continue

        return txt
    def eval_data_generator(self):
        sample=self.test_list[self.offset:self.offset+self.batch]
        self.offset+=self.batch
        speech_features = []
        input_length = []
        y1 = []
        label_length1 = []
        max_input = 0
        max_label1 = 0
        for i in sample:
            wp, txt = i.strip().split('\t')
            txt=txt.replace(' ','')
            try:
                data = self.speech_featurizer.load_wav(wp)
            except Exception:
                print('{} load data failed'.format(wp))
                continue
            if len(data) < 400:
                continue
            elif len(data) > self.speech_featurizer.sample_rate *  self.speech_config['wav_max_duration']:
                print('{} duration out of wav_max_duration({})'.format(wp,self.speech_config['wav_max_duration']))
                continue
            if self.speech_config['only_chinese']:
                txt= self.only_chinese(txt)
            if self.speech_config['use_mel_layer']:
                speech_feature = data / np.abs(data).max()
                speech_feature = np.expand_dims(speech_feature, -1)
                in_len = len(speech_feature) // (
                        self.speech_config['reduction_factor'] * (self.speech_featurizer.sample_rate / 1000) *
                        self.speech_config['stride_ms'])
            else:
                speech_feature = self.speech_featurizer.extract(data)
                in_len = int(speech_feature.shape[0] // self.speech_config['reduction_factor'])

            py = self.text_to_vocab(txt)
            if not self.check_valid(py, self.text_featurizer.vocab_array):
                print(' {} txt pinyin {} not all in tokens,continue'.format(txt,py))
                continue
            text_feature = self.text_featurizer.extract(py)

            if in_len < len(text_feature):
                print('{} feature length < pinyin length,continue'.format(wp))
                continue
            max_input = max(max_input, len(speech_feature))
            max_label1 = max(max_label1, len(text_feature))
            speech_features.append(speech_feature)
            input_length.append(in_len)
            y1.append(np.array(text_feature))
            label_length1.append(len(text_feature))

        if self.speech_config['use_mel_layer']:
            speech_features = self.speech_featurizer.pad_signal(speech_features, max_input)

        else:
            for i in range(len(speech_features)):

                if speech_features[i].shape[0] < max_input:
                    pad = np.ones([max_input - speech_features[i].shape[0], speech_features[i].shape[1],
                                   speech_features[i].shape[2]]) * speech_features[i].min()
                    speech_features[i] = np.vstack((speech_features[i], pad))

        for i in range(len(y1)):
            if y1[i].shape[0] < max_label1:
                pad = np.ones(max_label1 - y1[i].shape[0]) * self.text_featurizer.pad
                y1[i] = np.hstack((y1[i], pad))

        x = np.array(speech_features, 'float32')
        y1 = np.array(y1, 'int32')

        input_length = np.array(input_length, 'int32')
        label_length1 = np.array(label_length1, 'int32')

        return x, input_length, y1, label_length1
    def check_valid(self,txt,vocab_list):
        if len(txt)==0:
            return False
        for n in txt:
            if n in vocab_list:
                pass
            else:
                return False
        return True
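    # Guided-attention target in the spirit of Tachibana et al. (2017):
    # W[n, t] = 1 - exp(-((t/T - n/N)^2) / (2 g^2)) is near 0 on the
    # diagonal and grows towards 1 away from it, penalizing non-monotonic
    # attention; g sets the tolerated bandwidth around the diagonal.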
    def GuidedAttentionMatrix(self, N, T, g=0.2):
        W = np.zeros((N, T), dtype=np.float32)
        for n in range(N):
            for t in range(T):
                W[n, t] = 1 - np.exp(-(t / float(T) - n / float(N)) ** 2 / (2 * g * g))
        return W

    def guided_attention(self, input_length, targets_length, inputs_shape, mel_target_shape):
        att_targets = []
        for i, j in zip(input_length, targets_length):
            i = int(i)
            step = int(j)
            pad = np.ones([inputs_shape, mel_target_shape]) * -1.
            pad[i:, :step] = 1
            att_target = self.GuidedAttentionMatrix(i, step, 0.2)
            pad[:att_target.shape[0], :att_target.shape[1]] = att_target
            att_targets.append(pad)
        att_targets = np.array(att_targets)

        return att_targets.astype('float32')
    def generate(self, train=True):

        if train:
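            # pick_index counts how often each training utterance has been
            # drawn; argsort favors the least-used candidates, from which
            # half are randomly sampled. With augmentation on, the second
            # pass below adds an augmented copy of each sample, restoring
            # the full batch size.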
            batch=self.batch if self.augment.available() else self.batch*2
            indexs = np.argsort(self.pick_index)[:batch]
            indexs = random.sample(indexs.tolist(), batch//2)
            sample = [self.train_list[i] for i in indexs]
            for i in indexs:
                self.pick_index[int(i)] += 1
            self.epochs =1+ int(np.mean(self.pick_index))
        else:
            sample = random.sample(self.test_list, self.batch)
        speech_features = []
        input_length = []
        y1 = []
        label_length1 = []

        max_input = 0
        max_label1 = 0
        for i in sample:
            wp, txt = i.strip().split('\t')
            try:
                data = self.speech_featurizer.load_wav(wp)
            except Exception:
                print('{} load data failed'.format(wp))
                continue
            if len(data) < 400:
                continue
            elif len(data) > self.speech_featurizer.sample_rate * self.speech_config['wav_max_duration']:
                print('{} duration out of wav_max_duration({})'.format(wp, self.speech_config['wav_max_duration']))
                continue
            if self.speech_config['only_chinese']:
                txt= self.only_chinese(txt)
            if self.speech_config['use_mel_layer']:
                speech_feature = data / np.abs(data).max()
                speech_feature = np.expand_dims(speech_feature, -1)
                in_len = len(speech_feature) // (
                        self.speech_config['reduction_factor'] * (self.speech_featurizer.sample_rate / 1000) *
                        self.speech_config['stride_ms'])
            else:
                speech_feature = self.speech_featurizer.extract(data)
                in_len = int(speech_feature.shape[0] // self.speech_config['reduction_factor'])


            py = self.text_to_vocab(txt)
            if not self.check_valid(py,self.text_featurizer.vocab_array):
                print(' {} txt pinyin {} not all in tokens,continue'.format(txt, py))
                continue
            text_feature = self.text_featurizer.extract(py)

            if in_len < len(text_feature):
                print('{} feature length < pinyin length,continue'.format(wp))
                continue
            max_input = max(max_input,len(speech_feature))
            max_label1 = max(max_label1, len(text_feature))
            speech_features.append(speech_feature)
            input_length.append(in_len)
            y1.append(np.array(text_feature))
            label_length1.append(len(text_feature))
        if train and self.augment.available():
            for i in sample:
                wp, txt = i.strip().split('\t')
                try:
                    data = self.speech_featurizer.load_wav(wp)
                except Exception:
                    print('load data failed')
                    continue
                if len(data) < 400:
                    continue
                elif len(data) > self.speech_featurizer.sample_rate *  self.speech_config['wav_max_duration']:
                    continue
                data = self.augment.process(data)
                if self.speech_config['only_chinese']:
                    txt = self.only_chinese(txt)
                if self.speech_config['use_mel_layer']:
                    speech_feature = data / np.abs(data).max()
                    speech_feature = np.expand_dims(speech_feature, -1)
                    in_len = len(speech_feature) // (
                            self.speech_config['reduction_factor'] * (self.speech_featurizer.sample_rate / 1000) *
                            self.speech_config['stride_ms'])
                else:
                    speech_feature = self.speech_featurizer.extract(data)
                    in_len = int(speech_feature.shape[0] // self.speech_config['reduction_factor'])

                py = self.text_to_vocab(txt)
                if not self.check_valid(py, self.text_featurizer.vocab_array):
                    continue

                text_feature = self.text_featurizer.extract(py)


                if in_len < len(text_feature):
                    continue
                max_input = max(max_input, len(speech_feature))
                max_label1 = max(max_label1, len(text_feature))
                speech_features.append(speech_feature)

                input_length.append(in_len)
                y1.append(np.array(text_feature))
                label_length1.append(len(text_feature))

        if self.speech_config['use_mel_layer']:
            speech_features = self.speech_featurizer.pad_signal(speech_features, max_input)

        else:
            for i in range(len(speech_features)):

                if speech_features[i].shape[0] < max_input:
                    pad = np.ones([max_input - speech_features[i].shape[0], speech_features[i].shape[1],
                                   speech_features[i].shape[2]]) * speech_features[i].min()
                    speech_features[i] = np.vstack((speech_features[i], pad))

        for i in range(len(y1)):
            if y1[i].shape[0] < max_label1:
                pad = np.ones(max_label1 - y1[i].shape[0])*self.text_featurizer.pad
                y1[i] = np.hstack((y1[i], pad))

        x = np.array(speech_features, 'float32')
        y1 = np.array(y1, 'int32')

        input_length = np.array(input_length, 'int32')
        label_length1 = np.array(label_length1, 'int32')

        return x, input_length, y1, label_length1
    def generator(self,train=True):
        while True:
            x,  input_length, labels, label_length=self.generate(train)
            if x.shape[0]==0:
                print('load data length zero,continue')
                continue
            if self.LAS:
                guide_matrix = self.guided_attention(input_length, label_length, np.max(input_length),
                                                     label_length.max())
                yield x, input_length, labels, label_length,guide_matrix
            else:
                yield x,  input_length, labels, label_length
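
A minimal usage sketch for wiring the loader into tf.data (assumptions:
the config ships as YAML and 'am_data.yml' is a hypothetical path; the
parsed dict must contain the 'speech_config', 'decoder_config',
'augments_config' and 'learning_config' keys the loader reads):

import tensorflow as tf
import yaml

with open('am_data.yml', encoding='utf-8') as f:  # hypothetical path
    config_dict = yaml.safe_load(f)

dg = AM_DataLoader(config_dict, training=True)
# generator() yields whole batches already, matching the leading None in
# return_data_shape(), so no extra .batch() call is needed.
dataset = tf.data.Dataset.from_generator(
    lambda: dg.generator(train=True),
    output_types=dg.return_data_types(),
    output_shapes=dg.return_data_shape(),
)
for x, input_length, labels, label_length in dataset.take(1):
    print(x.shape, labels.shape)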
Example #2
class AM():
    def __init__(self, config):
        self.config = config
        self.update_model_type()
        self.speech_config = self.config['speech_config']
        try:
            self.text_config = self.config['decoder_config']
        except KeyError:
            self.text_config = self.config['decoder4_config']
        self.model_config = self.config['model_config']
        self.text_feature = TextFeaturizer(self.text_config)
        self.speech_feature = SpeechFeaturizer(self.speech_config)

        self.init_steps = None

    def update_model_type(self):
        if 'CTC' in self.config['model_config']['name']:
            self.config['decoder_config'].update({'model_type': 'CTC'})
            self.model_type = 'CTC'
        elif 'Multi' in self.config['model_config']['name']:
            self.config['decoder1_config'].update({'model_type': 'CTC'})
            self.config['decoder2_config'].update({'model_type': 'CTC'})
            self.config['decoder3_config'].update({'model_type': 'CTC'})
            self.config['decoder4_config'].update({'model_type': 'LAS'})
            self.config['decoder_config'].update({'model_type': 'LAS'})
            self.model_type = 'MultiTask'
        elif 'LAS' in self.config['model_config']['name']:
            self.config['decoder_config'].update({'model_type': 'LAS'})
            self.model_type = 'LAS'
        else:
            self.config['decoder_config'].update({'model_type': 'Transducer'})
            self.model_type = 'Transducer'

    def conformer_model(self, training):
        from AMmodel.conformer import ConformerTransducer, ConformerCTC, ConformerLAS
        self.model_config.update(
            {'vocabulary_size': self.text_feature.num_classes})
        if self.model_config['name'] == 'ConformerTransducer':
            self.model_config.pop('LAS_decoder')
            self.model_config.pop('enable_tflite_convertible')
            self.model = ConformerTransducer(**self.model_config)
        elif self.model_config['name'] == 'ConformerCTC':
            self.model = ConformerCTC(**self.model_config)
        elif self.model_config['name'] == 'ConformerLAS':
            self.config['model_config']['LAS_decoder'].update(
                {'n_classes': self.text_feature.num_classes})
            self.config['model_config']['LAS_decoder'].update(
                {'startid': self.text_feature.start})
            self.model = ConformerLAS(
                self.config['model_config'],
                training=training,
                enable_tflite_convertible=self.config['model_config']
                ['enable_tflite_convertible'])
        else:
            raise ValueError('not in supported model list')

    def ds2_model(self, training):
        from AMmodel.deepspeech2 import DeepSpeech2CTC, DeepSpeech2LAS, DeepSpeech2Transducer

        f, c = self.speech_feature.compute_feature_dim()
        input_shape = [None, f, c]
        self.model_config.update({'input_shape': input_shape})
        if self.model_config['name'] == 'DeepSpeech2Transducer':
            self.model_config.pop('LAS_decoder')
            self.model_config.pop('enable_tflite_convertible')
            self.model = DeepSpeech2Transducer(input_shape, self.model_config)
        elif self.model_config['name'] == 'DeepSpeech2CTC':
            self.model = DeepSpeech2CTC(input_shape, self.model_config,
                                        self.text_feature.num_classes)
        elif self.model_config['name'] == 'DeepSpeech2LAS':
            self.model_config['LAS_decoder'].update(
                {'n_classes': self.text_feature.num_classes})
            self.model_config['LAS_decoder'].update(
                {'startid': self.text_feature.start})
            self.model = DeepSpeech2LAS(
                self.model_config,
                input_shape,
                training=training,
                enable_tflite_convertible=self.
                model_config['enable_tflite_convertible'])
        else:
            raise ValueError('not in supported model list')

    def multi_task_model(self, training):
        from AMmodel.MultiConformer import ConformerMultiTaskLAS
        token1_feature = TextFeaturizer(self.config['decoder1_config'])
        token2_feature = TextFeaturizer(self.config['decoder2_config'])
        token3_feature = TextFeaturizer(self.config['decoder3_config'])
        token4_feature = TextFeaturizer(self.config['decoder4_config'])

        self.model_config.update({
            'classes1': token1_feature.num_classes,
            'classes2': token2_feature.num_classes,
            'classes3': token3_feature.num_classes,
        })
        self.model_config['LAS_decoder'].update(
            {'n_classes': token4_feature.num_classes})
        self.model_config['LAS_decoder'].update(
            {'startid': token4_feature.start})
        self.model = ConformerMultiTaskLAS(
            self.model_config,
            training=training,
            enable_tflite_convertible=self.
            model_config['enable_tflite_convertible'])

    def espnet_model(self, training):
        from AMmodel.espnet import ESPNetCTC, ESPNetLAS, ESPNetTransducer
        self.config['Transducer_decoder'].update(
            {'vocabulary_size': self.text_feature.num_classes})
        if self.model_config['name'] == 'ESPNetTransducer':
            self.model = ESPNetTransducer(self.config)
        elif self.model_config['name'] == 'ESPNetCTC':
            self.model = ESPNetCTC(self.model_config,
                                   self.text_feature.num_classes)
        elif self.model_config['name'] == 'ESPNetLAS':
            self.config['LAS_decoder'].update(
                {'n_classes': self.text_feature.num_classes})
            self.config['LAS_decoder'].update(
                {'startid': self.text_feature.start})
            self.model = ESPNetLAS(self.config,
                                   training=training,
                                   enable_tflite_convertible=self.
                                   config['enable_tflite_convertible'])
        else:
            raise ValueError('not in supported model list')

    def load_model(self, training=True):

        if 'ESPNet' in self.model_config['name']:
            self.espnet_model(training)
        elif 'Multi' in self.model_config['name']:
            self.multi_task_model(training)

        elif 'Conformer' in self.model_config['name']:
            self.conformer_model(training)
        else:
            self.ds2_model(training)
        self.model.add_featurizers(self.text_feature)
        f, c = self.speech_feature.compute_feature_dim()

        try:
            if not training:
                if self.text_config['model_type'] != 'LAS':
                    self.model._build([3, 80, f, c])
                    self.model._build([2, 80, f, c])
                    self.model._build([1, 80, f, c])
                    self.model.return_pb_function(f, c)

                else:
                    self.model._build([3, 80, f, c], training)
                    self.model._build([1, 80, f, c], training)
                    self.model._build([2, 80, f, c], training)
                    self.model.return_pb_function(f, c)
                self.load_checkpoint(self.config)

        except Exception:
            print('am loading model failed.')

    def convert_to_pb(self, export_path):
        import tensorflow as tf
        f, c = self.speech_feature.compute_feature_dim()
        self.model.return_pb_function(f, c)

        concrete_func = self.model.recognize_pb.get_concrete_function()
        tf.saved_model.save(self.model, export_path, signatures=concrete_func)

    def decode_result(self, word):
        de = []
        for i in word:
            if i != self.text_feature.stop:
                de.append(self.text_feature.index_to_token[int(i)])
            else:
                break
        return de

    def predict(self, fp):
        if '.pcm' in fp:
            data = np.fromfile(fp, 'int16')
            data = np.array(data, 'float32')
            data /= 32768
        else:
            data = self.speech_feature.load_wav(fp)

        mel = self.speech_feature.extract(data)
        mel = np.expand_dims(mel, 0)
        input_length = np.array(
            [[mel.shape[1] // self.model.time_reduction_factor]], 'int32')
        result = self.model.recognize_pb(mel, input_length)[0]
        return result

    def load_checkpoint(self, config):
        """Load checkpoint."""

        self.checkpoint_dir = os.path.join(
            config['learning_config']['running_config']["outdir"],
            "checkpoints")
        files = os.listdir(self.checkpoint_dir)
        files.sort(key=lambda x: int(x.split('_')[-1].replace('.h5', '')))
        self.model.load_weights(os.path.join(self.checkpoint_dir, files[-1]))
        self.init_steps = int(files[-1].split('_')[-1].replace('.h5', ''))
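
A minimal inference sketch for AM (paths are hypothetical;
load_model(training=False) builds the network and loads the newest *.h5
checkpoint from <outdir>/checkpoints):

import numpy as np
import yaml

with open('am_model.yml', encoding='utf-8') as f:  # hypothetical path
    config = yaml.safe_load(f)

am = AM(config)
am.load_model(training=False)
indices = am.predict('sample.wav')  # .pcm input is also handled
print(''.join(am.decode_result(np.array(indices).flatten())))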
Example #3
class MultiTask_DataLoader():

    def __init__(self, config_dict,training=True):
        self.speech_config = config_dict['speech_config']
        self.text1_config = config_dict['decoder1_config']
        self.text2_config = config_dict['decoder2_config']
        self.text3_config = config_dict['decoder3_config']
        self.text4_config = config_dict['decoder4_config']
        self.augment_config = config_dict['augments_config']
        self.batch = config_dict['learning_config']['running_config']['batch_size']
        self.speech_featurizer = SpeechFeaturizer(self.speech_config)
        self.token1_featurizer = TextFeaturizer(self.text1_config)
        self.token2_featurizer = TextFeaturizer(self.text2_config)
        self.token3_featurizer = TextFeaturizer(self.text3_config)
        self.token4_featurizer = TextFeaturizer(self.text4_config)
        self.make_file_list(self.speech_config['train_list'] if training else self.speech_config['eval_list'],training)
        self.make_maps(config_dict)
        self.augment = Augmentation(self.augment_config)
        self.epochs = 1
        self.LAS=True
        self.steps = 0

        self.init_bert(config_dict)

    def load_state(self,outdir):
        try:
            self.pick_index=np.load(os.path.join(outdir,'dg_state.npy')).flatten().tolist()
            self.epochs=1+int(np.mean(self.pick_index))
        except FileNotFoundError:
            print('not found state file')
        except Exception:
            print('failed to load state, using init state')
    def save_state(self,outdir):
        np.save(os.path.join(outdir,'dg_state.npy'),np.array(self.pick_index))
    def load_bert(self, config, checkpoint):
        model = load_trained_model_from_checkpoint(config, checkpoint, trainable=False, seq_len=None)
        return model

    def init_bert(self,config):
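        # Expects a config['bert'] section pointing at a pretrained BERT;
        # load_trained_model_from_checkpoint / load_vocabulary / Tokenizer
        # are presumably the keras_bert API, and the model is loaded frozen
        # (trainable=False) purely as a feature extractor.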
        bert_config = config['bert']['config_json']
        bert_checkpoint = config['bert']['bert_ckpt']
        bert_vocab = config['bert']['bert_vocab']
        bert_vocabs = load_vocabulary(bert_vocab)
        self.bert_token = Tokenizer(bert_vocabs)
        self.bert = self.load_bert(bert_config, bert_checkpoint)

    def bert_decode(self, x):
        tokens, segs = [], []

        for i in x:
            t, s = self.bert_token.encode(''.join(i))
            tokens.append(t)
            segs.append(s)
        return tokens, segs
    def get_bert_feature(self, bert_t, bert_s):
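        # Runs BERT on each (token, segment) pair; note that only the first
        # sample's features are returned, with the leading [CLS] vector
        # stripped (f[0][1:]).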
        f = []
        for t, s in zip(bert_t, bert_s):
            t = np.expand_dims(np.array(t), 0)
            s = np.expand_dims(np.array(s), 0)
            feature = self.bert.predict([t, s])
            f.append(feature[0])
        return f[0][1:]
    def return_data_types(self):

        return (tf.float32, tf.float32, tf.float32,tf.int32, tf.int32,tf.int32,tf.int32,tf.int32,tf.int32,tf.int32,tf.int32, tf.int32,tf.float32)

    def return_data_shape(self):
        f,c=self.speech_featurizer.compute_feature_dim()

        return (
            tf.TensorShape([None,None,f,c]),
            tf.TensorShape([None,None,1]),
            tf.TensorShape([None, None, 768]),
            tf.TensorShape([None,]),
            tf.TensorShape([None,None]),
            tf.TensorShape([None,]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, ]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, ]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, ]),
            tf.TensorShape([None,None,None])
        )

    def get_per_epoch_steps(self):
        return len(self.train_list)//self.batch
    def eval_per_epoch_steps(self):
        return len(self.test_list)//self.batch
    def make_maps(self,config):
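        # Builds two lookup tables from tab-separated map files:
        #   py_map:    word or char -> pinyin string (multi-char entries are
        #              also expanded to per-char mappings)
        #   phone_map: word or char -> phone string, composed per pinyin
        #              syllable from the phone map file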
        with open(config['map_path']['pinyin'],encoding='utf-8') as f:
            data=f.readlines()
        data=[i.strip() for i in data if i.strip()!='']
        self.py_map={}
        for line in data:
            key,py=line.strip().split('\t')
            self.py_map[key]=py
            if len(py.split(' '))>1:
                for i,j in zip(list(key),py.split(' ')):
                    self.py_map[i]=j
        with open(config['map_path']['phone'],encoding='utf-8') as f:
            data=f.readlines()
        data=[i.strip() for i in data if i.strip()!='']
        self.phone_map={}
        phone_map={}
        for line in data:
            key,py=line.strip().split('\t')
            phone_map[key]=py
        for key in self.py_map.keys():
            key_py=self.py_map[key]
            if len(key)>1:
                phone=[]
                for n in key_py.split(' '):
                    phone+=[phone_map[n]]
                self.phone_map[key]=' '.join(phone)
            else:
                self.phone_map[key]=phone_map[self.py_map[key]]
    def map(self,txt):
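        # Segments txt into words (lcut, presumably jieba.posseg) and looks
        # each word up in py_map/phone_map; out-of-map words fall back to
        # per-character lookups.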
        cut=lcut(txt)
        pys=[]
        phones=[]
        words=[]
        for i in cut:
            word=i.word
            if word in self.py_map.keys():
                py=self.py_map[word]
                phone=self.phone_map[word]
                pys+=py.split(' ')
                phones+=phone.split(' ')
                words+=list(''.join(py.split(' ')))
            else:
                for j in word:
                    pys+=[self.py_map[j]]
                    phones+=self.phone_map[j].split(' ')
                    words+=list(''.join(self.py_map[j]))
        return pys,phones,words

    def augment_data(self, wavs, label, label_length):
        if not self.augment.available():
            return None
        mels = []
        input_length = []
        label_ = []
        label_length_ = []
        wavs_ = []
        max_input = 0
        max_wav = 0
        for idx, wav in enumerate(wavs):

            data = self.augment.process(wav.flatten())
            speech_feature = self.speech_featurizer.extract(data)
            if speech_feature.shape[0] // self.speech_config['reduction_factor'] < label_length[idx]:
                continue
            max_input = max(max_input, speech_feature.shape[0])

            max_wav = max(max_wav, len(data))

            wavs_.append(data)

            mels.append(speech_feature)
            input_length.append(speech_feature.shape[0] // self.speech_config['reduction_factor'])
            label_.append(label[idx])
            label_length_.append(label_length[idx])

        for i in range(len(mels)):
            if mels[i].shape[0] < max_input:
                pad = np.ones([max_input - mels[i].shape[0], mels[i].shape[1],mels[i].shape[2]]) * mels[i].min()
                mels[i] = np.vstack((mels[i], pad))

        wavs_ = self.speech_featurizer.pad_signal(wavs_, max_wav)

        x = np.array(mels, 'float32')
        label_ = np.array(label_, 'int32')

        input_length = np.array(input_length, 'int32')
        label_length_ = np.array(label_length_, 'int32')

        wavs_ = np.array(np.expand_dims(wavs_, -1), 'float32')

        return x, wavs_, input_length, label_, label_length_

    def make_file_list(self, wav_list,training=True):
        with open(wav_list, encoding='utf-8') as f:
            data = f.readlines()
        data=[i.strip() for i in data if i.strip()!='']
        num = len(data)
        if training:
            self.train_list = data[:int(num * 0.99)]
            self.test_list = data[int(num * 0.99):]
            np.random.shuffle(self.train_list)
            self.pick_index = [0.] * len(self.train_list)
        else:
            self.test_list=data
            self.offset=0
    def only_chinese(self, word):

        for ch in word:
            if '\u4e00' <= ch <= '\u9fff':
                pass
            else:
                return False

        return True
    def eval_data_generator(self):
        sample=self.test_list[self.offset:self.offset+self.batch]
        self.offset+=self.batch
        mels = []
        input_length = []

        words_label = []
        words_label_length = []

        phone_label = []
        phone_label_length = []

        py_label = []
        py_label_length = []

        txt_label = []
        txt_label_length = []
        
        bert_features=[]
        wavs = []

        max_wav = 0
        max_input = 0
        max_label_words = 0
        max_label_phone = 0
        max_label_py = 0
        max_label_txt = 0
        for i in sample:
            wp, txt = i.strip().split('\t')
            try:
                data = self.speech_featurizer.load_wav(wp)
            except Exception:
                print('load data failed')
                continue
            if len(data) < 400:
                continue
            elif len(data) > self.speech_featurizer.sample_rate * 7:
                continue

            if not self.only_chinese(txt):
                continue

            speech_feature = self.speech_featurizer.extract(data)

            py,phone,word = self.map(txt)
            if len(py) == 0:
                continue
            e_bert_t, e_bert_s = self.bert_decode([txt])
            bert_feature = self.get_bert_feature(e_bert_t, e_bert_s)

            word_text_feature = self.token1_featurizer.extract(word)
            phone_text_feature = self.token2_featurizer.extract(phone)
            py_text_feature = self.token3_featurizer.extract(py)
            txt_text_feature = self.token4_featurizer.extract(list(txt))
            max_wav = max(max_wav, len(data))
            if speech_feature.shape[0] / self.speech_config['reduction_factor'] < len(py_text_feature):
                continue
            max_input = max(max_input, speech_feature.shape[0])
            max_label_words = max(max_label_words, len(word_text_feature))
            max_label_phone = max(max_label_phone, len(phone_text_feature))
            max_label_py = max(max_label_py, len(py_text_feature))
            max_label_txt = max(max_label_txt, len(txt_text_feature))
            mels.append(speech_feature)
            wavs.append(data)
            input_length.append(speech_feature.shape[0] // self.speech_config['reduction_factor'])
            words_label.append(np.array(word_text_feature))
            words_label_length.append(len(word_text_feature))

            phone_label.append(np.array(phone_text_feature))
            phone_label_length.append(len(phone_text_feature))

            py_label.append(np.array(py_text_feature))
            py_label_length.append(len(py_text_feature))

            txt_label.append(np.array(txt_text_feature))
            txt_label_length.append(len(txt_text_feature))
            bert_features.append(bert_feature)

        for i in range(len(mels)):
            if mels[i].shape[0] < max_input:
                pad = np.ones([max_input - mels[i].shape[0], mels[i].shape[1], mels[i].shape[2]]) * mels[i].min()
                mels[i] = np.vstack((mels[i], pad))

        for i in range(len(bert_features)):

            if bert_features[i].shape[0] < max_label_txt:
                pading = np.ones([max_label_txt - len(bert_features[i]), 768]) * -10.
                bert_features[i] = np.vstack((bert_features[i], pading))


        wavs = self.speech_featurizer.pad_signal(wavs, max_wav)
        words_label = self.pad(words_label, max_label_words)
        phone_label = self.pad(phone_label, max_label_phone)
        py_label = self.pad(py_label, max_label_py)
        txt_label = self.pad(txt_label, max_label_txt)

        x = np.array(mels, 'float32')
        bert_features = np.array(bert_features, 'float32')
        words_label = np.array(words_label, 'int32')
        phone_label = np.array(phone_label, 'int32')
        py_label = np.array(py_label, 'int32')
        txt_label = np.array(txt_label, 'int32')

        input_length = np.array(input_length, 'int32')
        words_label_length = np.array(words_label_length, 'int32')
        phone_label_length = np.array(phone_label_length, 'int32')
        py_label_length = np.array(py_label_length, 'int32')
        txt_label_length = np.array(txt_label_length, 'int32')

        wavs = np.array(np.expand_dims(wavs, -1), 'float32')

        return x, wavs, bert_features,input_length, words_label, words_label_length, phone_label, phone_label_length, py_label, py_label_length, txt_label, txt_label_length
    def pad(self,words_label,max_label_words):
        for i in range(len(words_label)):
            if words_label[i].shape[0] < max_label_words:
                pad = np.ones(max_label_words - words_label[i].shape[0]) * self.token1_featurizer.pad
                words_label[i] = np.hstack((words_label[i], pad))
        return words_label
    def GuidedAttention(self, N, T, g=0.2):
        W = np.zeros((N, T), dtype=np.float32)
        for n in range(N):
            for t in range(T):
                W[n, t] = 1 - np.exp(-(t / float(T) - n / float(N)) ** 2 / (2 * g * g))
        return W

    def guided_attention(self, input_length, targets_length, inputs_shape, mel_target_shape):
        att_targets = []
        for i, j in zip(input_length, targets_length):
            i = int(i)
            step = int(j)
            pad = np.ones([inputs_shape, mel_target_shape]) * -1.
            pad[i:, :step] = 1
            att_target = self.GuidedAttention(i, step, 0.2)
            pad[:att_target.shape[0], :att_target.shape[1]] = att_target
            att_targets.append(pad)
        att_targets = np.array(att_targets)

        return att_targets.astype('float32')
    def generate(self, train=True):

        if train:
            batch=self.batch if self.augment.available() else self.batch*2
            indexs = np.argsort(self.pick_index)[:batch]
            indexs = random.sample(indexs.tolist(), batch//2)
            sample = [self.train_list[i] for i in indexs]
            for i in indexs:
                self.pick_index[int(i)] += 1
            self.epochs = 1+int(np.mean(self.pick_index))
        else:
            sample = random.sample(self.test_list, self.batch)

        mels = []
        input_length = []

        words_label = []
        words_label_length = []

        phone_label = []
        phone_label_length = []

        py_label = []
        py_label_length = []

        txt_label = []
        txt_label_length = []

        bert_features = []
        wavs = []

        max_wav = 0
        max_input = 0
        max_label_words = 0
        max_label_phone = 0
        max_label_py = 0
        max_label_txt = 0
        for i in sample:
            wp, txt = i.strip().split('\t')
            try:
                data = self.speech_featurizer.load_wav(wp)
            except Exception:
                print('load data failed')
                continue
            if len(data) < 400:
                continue
            elif len(data) > self.speech_featurizer.sample_rate * 7:
                continue

            if not self.only_chinese(txt):
                continue

            speech_feature = self.speech_featurizer.extract(data)


            py, phone, word = self.map(txt)
            if len(py) == 0 or len(phone)==0 or len(word)==0:
                continue
            e_bert_t, e_bert_s = self.bert_decode([txt])
            bert_feature = self.get_bert_feature(e_bert_t, e_bert_s)

            word_text_feature = self.token1_featurizer.extract(word)
            phone_text_feature = self.token2_featurizer.extract(phone)
            py_text_feature = self.token3_featurizer.extract(py)
            txt_text_feature = self.token4_featurizer.extract(list(txt))

            if speech_feature.shape[0] / self.speech_config['reduction_factor'] < len(py_text_feature) or \
                    speech_feature.shape[0] / self.speech_config['reduction_factor'] < len(word_text_feature) or \
                    speech_feature.shape[0] / self.speech_config['reduction_factor'] < len(phone_text_feature):
                continue
            max_input = max(max_input, speech_feature.shape[0])
            max_label_words = max(max_label_words, len(word_text_feature))
            max_label_phone = max(max_label_phone, len(phone_text_feature))
            max_label_py = max(max_label_py, len(py_text_feature))
            max_label_txt = max(max_label_txt, len(txt_text_feature))

            max_wav = max(max_wav, len(data))
            mels.append(speech_feature)
            wavs.append(data)
            input_length.append(speech_feature.shape[0] // self.speech_config['reduction_factor'])
            words_label.append(np.array(word_text_feature))
            words_label_length.append(len(word_text_feature))

            phone_label.append(np.array(phone_text_feature))
            phone_label_length.append(len(phone_text_feature))

            py_label.append(np.array(py_text_feature))
            py_label_length.append(len(py_text_feature))

            txt_label.append(np.array(txt_text_feature))
            txt_label_length.append(len(txt_text_feature))
            bert_features.append(bert_feature)


        if train and self.augment.available():
            for i in sample:
                wp, txt = i.strip().split('\t')
                try:
                    data = self.speech_featurizer.load_wav(wp)
                except Exception:
                    print('load data failed')
                    continue
                if len(data) < 400:
                    continue
                elif len(data) > self.speech_featurizer.sample_rate * 7:
                    continue

                if not self.only_chinese(txt):
                    continue
                data=self.augment.process(data)
                speech_feature = self.speech_featurizer.extract(data)


                py, phone, word = self.map(txt)
                if len(py) == 0 or len(phone) == 0 or len(word) == 0:
                    continue
                e_bert_t, e_bert_s = self.bert_decode([txt])
                bert_feature = self.get_bert_feature(e_bert_t, e_bert_s)

                word_text_feature = self.token1_featurizer.extract(word)
                phone_text_feature = self.token2_featurizer.extract(phone)
                py_text_feature = self.token3_featurizer.extract(py)
                txt_text_feature = self.token4_featurizer.extract(list(txt))



                if speech_feature.shape[0] / self.speech_config['reduction_factor'] < len(py_text_feature) or \
                        speech_feature.shape[0] / self.speech_config['reduction_factor'] < len(word_text_feature) or \
                        speech_feature.shape[0] / self.speech_config['reduction_factor'] < len(phone_text_feature):
                    continue
                max_input = max(max_input, speech_feature.shape[0])
                max_wav = max(max_wav, len(data))
                max_label_words = max(max_label_words, len(word_text_feature))
                max_label_phone = max(max_label_phone, len(phone_text_feature))
                max_label_py = max(max_label_py, len(py_text_feature))
                max_label_txt = max(max_label_txt, len(txt_text_feature))
                mels.append(speech_feature)
                wavs.append(data)
                input_length.append(speech_feature.shape[0] // self.speech_config['reduction_factor'])
                words_label.append(np.array(word_text_feature))
                words_label_length.append(len(word_text_feature))

                phone_label.append(np.array(phone_text_feature))
                phone_label_length.append(len(phone_text_feature))

                py_label.append(np.array(py_text_feature))
                py_label_length.append(len(py_text_feature))

                txt_label.append(np.array(txt_text_feature))
                txt_label_length.append(len(txt_text_feature))
                bert_features.append(bert_feature)

        for i in range(len(mels)):
            if mels[i].shape[0] < max_input:
                pad = np.ones([max_input - mels[i].shape[0], mels[i].shape[1], mels[i].shape[2]]) * mels[i].min()
                mels[i] = np.vstack((mels[i], pad))
        for i in range(len(bert_features)):
            if bert_features[i].shape[0]<max_label_txt:
                pading = np.ones([max_label_txt - len(bert_features[i]), 768]) * -10.
                bert_features[i] = np.vstack((bert_features[i], pading))

        wavs = self.speech_featurizer.pad_signal(wavs, max_wav)
        words_label = self.pad(words_label, max_label_words)
        phone_label = self.pad(phone_label, max_label_phone)
        py_label = self.pad(py_label, max_label_py)
        txt_label = self.pad(txt_label, max_label_txt)

        x = np.array(mels, 'float32')
        bert_features = np.array(bert_features, 'float32')
        words_label = np.array(words_label, 'int32')
        phone_label = np.array(phone_label, 'int32')
        py_label = np.array(py_label, 'int32')
        txt_label = np.array(txt_label, 'int32')

        input_length = np.array(input_length, 'int32')
        words_label_length = np.array(words_label_length, 'int32')
        phone_label_length = np.array(phone_label_length, 'int32')
        py_label_length = np.array(py_label_length, 'int32')
        txt_label_length = np.array(txt_label_length, 'int32')

        wavs = np.array(np.expand_dims(wavs, -1), 'float32')

        return x, wavs, bert_features,input_length, words_label, words_label_length, phone_label, phone_label_length, py_label, py_label_length, txt_label, txt_label_length
    def generator(self,train=True):
        while True:
            x, wavs,bert_feature, input_length, words_label, words_label_length, phone_label, phone_label_length, py_label, py_label_length, txt_label, txt_label_length=self.generate(train)

            guide_matrix = self.guided_attention(input_length, txt_label_length, np.max(input_length),
                                                 txt_label_length.max())
            yield x, wavs, bert_feature,input_length, words_label, words_label_length, phone_label, phone_label_length, py_label, py_label_length, txt_label, txt_label_length,guide_matrix
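
The tf.data wiring is the same as in Example #1; the only difference is the
13-element tuple declared by return_data_types()/return_data_shape(): mel
features, raw wavs, BERT features, four label/length pairs, and the
guided-attention matrix.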
Example #4
class MultiTask_DataLoader():
    def __init__(self, config_dict, training=True):
        self.speech_config = config_dict['speech_config']
        self.text1_config = config_dict['decoder1_config']
        self.text2_config = config_dict['decoder2_config']
        self.text3_config = config_dict['decoder3_config']
        self.augment_config = config_dict['augments_config']
        self.batch = config_dict['learning_config']['running_config'][
            'batch_size']
        self.speech_featurizer = SpeechFeaturizer(self.speech_config)
        self.token1_featurizer = TextFeaturizer(self.text1_config)
        self.token2_featurizer = TextFeaturizer(self.text2_config)
        self.token3_featurizer = TextFeaturizer(self.text3_config)
        self.make_file_list(
            self.speech_config['train_list']
            if training else self.speech_config['eval_list'], training)
        self.make_maps(config_dict)
        self.augment = Augmentation(self.augment_config)
        self.epochs = 1
        self.steps = 0

    def load_state(self, outdir):
        try:

            dg_state = np.load(os.path.join(outdir, 'dg_state.npz'))

            self.epochs = int(dg_state['epoch'])
            self.train_offset = int(dg_state['train_offset'])
            train_list = dg_state['train_list'].tolist()
            if len(train_list) != len(self.train_list):
                logging.info(
                    'saved train list does not match current train list, data loader uses init state'
                )
                self.epochs = 0
                self.train_offset = 0
        except FileNotFoundError:
            logging.info('state file not found, using init state')
        except Exception:
            logging.info('failed to load state, using init state')

    def save_state(self, outdir):
        np.savez(os.path.join(outdir, 'dg_state.npz'),
                 epoch=self.epochs,
                 train_offset=self.train_offset,
                 train_list=self.train_list)

    def return_data_types(self):

        return (tf.float32, tf.int32, tf.int32, tf.int32, tf.int32, tf.int32,
                tf.int32, tf.int32)

    def return_data_shape(self):
        f, c = self.speech_featurizer.compute_feature_dim()

        return (
            tf.TensorShape([None, None, 1])
            if self.speech_config['use_mel_layer'] else tf.TensorShape(
                [None, None, f, c]),
            tf.TensorShape([
                None,
            ]),
            tf.TensorShape([None, None]),
            tf.TensorShape([
                None,
            ]),
            tf.TensorShape([None, None]),
            tf.TensorShape([
                None,
            ]),
            tf.TensorShape([None, None]),
            tf.TensorShape([
                None,
            ]),
        )

    def get_per_epoch_steps(self):
        return len(self.train_list) // self.batch

    def eval_per_epoch_steps(self):
        return len(self.test_list) // self.batch

    def make_maps(self, config):
        with open(config['map_path']['phone'], encoding='utf-8') as f:
            data = f.readlines()
        data = [i.strip() for i in data if i.strip() != '']
        self.phone_map = {}
        phone_map = {}
        for line in data:
            try:
                key, phone = line.strip().split('\t')
            except ValueError:
                continue
            phone_map[key] = phone.split(' ')
        self.phone_map = phone_map

    def map(self, txt):
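        # Style 8 is pypinyin.Style.TONE3 (tone digit appended, e.g.
        # 'zhong1'); neutral_tone_with_five renders the neutral tone as '5'.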
        pys = pypinyin.pinyin(txt, 8, neutral_tone_with_five=True)

        pys = [i[0] for i in pys]
        phones = []

        for i in pys:
            phones += self.phone_map[i]
        words = ''.join(pys)
        words = list(words)
        return pys, phones, words

    def make_file_list(self, wav_list, training=True):
        with open(wav_list, encoding='utf-8') as f:
            data = f.readlines()
        data = [i.strip() for i in data if i.strip() != '']
        num = len(data)
        if training:
            self.train_list = data[:int(num * 0.99)]
            self.test_list = data[int(num * 0.99):]
            np.random.shuffle(self.train_list)
            self.train_offset = 0
            self.test_offset = 0
            logging.info('train list : {} test list:{}'.format(
                len(self.train_list), len(self.test_list)))
        else:
            self.test_list = data
            self.offset = 0
            logging.info('eval list: {}'.format(len(self.test_list)))

    def only_chinese(self, word):
        txt = ''
        for ch in word:
            if '\u4e00' <= ch <= '\u9fff':
                txt += ch
            else:
                continue

        return txt

    def check_valid(self, txt, vocab_list):
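        # Returns True when every token is in vocab_list, False for empty
        # input, and the first out-of-vocabulary token otherwise; callers
        # must therefore test `is not True` rather than truthiness.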
        if len(txt) == 0:
            return False
        for n in txt:
            if n in vocab_list:
                pass
            else:
                return n
        return True

    def eval_data_generator(self):
        sample = self.test_list[self.offset:self.offset + self.batch]
        self.offset += self.batch
        speech_features = []
        input_length = []

        words_label = []
        words_label_length = []

        phone_label = []
        phone_label_length = []

        py_label = []
        py_label_length = []

        max_input = 0
        max_label_words = 0
        max_label_phone = 0
        max_label_py = 0

        for i in sample:
            wp, txt = i.strip().split('\t')
            try:
                data = self.speech_featurizer.load_wav(wp)
            except Exception:
                logging.info('{} load data failed,skip'.format(wp))
                continue
            if len(data) < 400:
                continue
            elif len(
                    data
            ) > self.speech_featurizer.sample_rate * self.speech_config[
                    'wav_max_duration']:
                logging.info(
                    '{} duration out of wav_max_duration({}),skip'.format(
                        wp, self.speech_config['wav_max_duration']))
                continue
            if self.speech_config['only_chinese']:
                txt = self.only_chinese(txt)
            if self.speech_config['use_mel_layer']:
                speech_feature = data / np.abs(data).max()
                speech_feature = np.expand_dims(speech_feature, -1)
                in_len = len(speech_feature) // (
                    self.speech_config['reduction_factor'] *
                    (self.speech_featurizer.sample_rate / 1000) *
                    self.speech_config['stride_ms'])
            else:
                speech_feature = self.speech_featurizer.extract(data)
                in_len = int(speech_feature.shape[0] //
                             self.speech_config['reduction_factor'])

            py, phone, word = self.map(txt)
            if len(py) == 0:
                continue

            if self.check_valid(word, self.token1_featurizer.vocab_array) is not True:
                logging.info(
                    ' {} txt word {} not all in tokens,continue'.format(
                        txt, word))
                continue

            if self.check_valid(phone, self.token2_featurizer.vocab_array) is not True:
                logging.info(
                    ' {} txt phone {} not all in tokens,continue'.format(
                        txt, phone))
                continue

            if self.check_valid(py, self.token3_featurizer.vocab_array) is not True:
                logging.info(
                    ' {} txt pinyin {} not all in tokens,continue'.format(
                        txt, py))
                continue
            word_text_feature = self.token1_featurizer.extract(word)
            phone_text_feature = self.token2_featurizer.extract(phone)
            py_text_feature = self.token3_featurizer.extract(py)

            if in_len < len(word_text_feature):
                continue

            max_label_words = max(max_label_words, len(word_text_feature))
            max_label_phone = max(max_label_phone, len(phone_text_feature))
            max_label_py = max(max_label_py, len(py_text_feature))
            max_input = max(max_input, len(speech_feature))

            speech_features.append(speech_feature)
            input_length.append(in_len)
            words_label.append(np.array(word_text_feature))
            words_label_length.append(len(word_text_feature))

            phone_label.append(np.array(phone_text_feature))
            phone_label_length.append(len(phone_text_feature))

            py_label.append(np.array(py_text_feature))
            py_label_length.append(len(py_text_feature))

        if self.speech_config['use_mel_layer']:
            speech_features = self.speech_featurizer.pad_signal(
                speech_features, max_input)

        else:
            for i in range(len(speech_features)):

                if speech_features[i].shape[0] < max_input:
                    pad = np.ones([
                        max_input - speech_features[i].shape[0],
                        speech_features[i].shape[1],
                        speech_features[i].shape[2]
                    ]) * speech_features[i].min()
                    speech_features[i] = np.vstack((speech_features[i], pad))

        words_label = self.pad(words_label, max_label_words)
        phone_label = self.pad(phone_label, max_label_phone)
        py_label = self.pad(py_label, max_label_py)
        speech_features = np.array(speech_features, 'float32')
        words_label = np.array(words_label, 'int32')
        phone_label = np.array(phone_label, 'int32')
        py_label = np.array(py_label, 'int32')
        input_length = np.array(input_length, 'int32')
        words_label_length = np.array(words_label_length, 'int32')
        phone_label_length = np.array(phone_label_length, 'int32')
        py_label_length = np.array(py_label_length, 'int32')

        return speech_features, input_length, words_label, words_label_length, phone_label, phone_label_length, py_label, py_label_length

    def pad(self, words_label, max_label_words):
        for i in range(len(words_label)):
            if words_label[i].shape[0] < max_label_words:
                pad = np.ones(max_label_words - words_label[i].shape[0]
                              ) * self.token1_featurizer.pad
                words_label[i] = np.hstack((words_label[i], pad))
        return words_label

    def GuidedAttention(self, N, T, g=0.2):
        W = np.zeros((N, T), dtype=np.float32)
        for n in range(N):
            for t in range(T):
                W[n, t] = 1 - np.exp(-(t / float(T) - n / float(N))**2 /
                                     (2 * g * g))
        return W
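    # Guided-attention penalty in the style of Tachibana et al. (2017):
    # W[n, t] = 1 - exp(-((t/T - n/N)^2) / (2 g^2)); it is near 0 on the
    # diagonal and approaches 1 as attention drifts off-diagonal.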

    def guided_attention(self, input_length, targets_length, inputs_shape,
                         mel_target_shape):
        att_targets = []
        for i, j in zip(input_length, targets_length):
            i = int(i)
            step = int(j)
            pad = np.ones([inputs_shape, mel_target_shape]) * -1.
            pad[i:, :step] = 1
            att_target = self.GuidedAttention(i, step, 0.2)
            pad[:att_target.shape[0], :att_target.shape[1]] = att_target
            att_targets.append(pad)
        att_targets = np.array(att_targets)

        return att_targets.astype('float32')

    def generate(self, train=True):
        sample = []
        speech_features = []
        input_length = []

        words_label = []
        words_label_length = []

        phone_label = []
        phone_label_length = []

        py_label = []
        py_label_length = []

        max_input = 0
        max_label_words = 0
        max_label_phone = 0
        max_label_py = 0
        if train:
            batch = self.batch // 2 if self.augment.available() else self.batch
        else:
            batch = self.batch

        for i in range(batch * 10):
            if train:
                line = self.train_list[self.train_offset]
                self.train_offset += 1
                if self.train_offset > len(self.train_list) - 1:
                    self.train_offset = 0
                    np.random.shuffle(self.train_list)
                    self.epochs += 1
            else:
                line = self.test_list[self.test_offset]
                self.test_offset += 1
                if self.test_offset > len(self.test_list) - 1:
                    self.test_offset = 0

            wp, txt = line.strip().split('\t')
            try:
                data = self.speech_featurizer.load_wav(wp)
            except Exception:
                logging.info('{} load data failed, skip'.format(wp))
                continue
            if len(data) < 400:
                continue
            elif len(
                    data
            ) > self.speech_featurizer.sample_rate * self.speech_config[
                    'wav_max_duration']:
                logging.info(
                    '{} duration out of wav_max_duration({}), skip'.format(
                        wp, self.speech_config['wav_max_duration']))
                continue
            if self.speech_config['only_chinese']:
                txt = self.only_chinese(txt)
            if self.speech_config['use_mel_layer']:
                speech_feature = data / np.abs(data).max()
                speech_feature = np.expand_dims(speech_feature, -1)
                in_len = len(speech_feature) // (
                    self.speech_config['reduction_factor'] *
                    (self.speech_featurizer.sample_rate / 1000) *
                    self.speech_config['stride_ms'])
            else:
                speech_feature = self.speech_featurizer.extract(data)
                in_len = int(speech_feature.shape[0] //
                             self.speech_config['reduction_factor'])

            py, phone, word = self.map(txt)
            if len(py) == 0:
                logging.info('py length is 0, skip')
                continue

            if self.check_valid(
                    word, self.token1_featurizer.vocab_array) is not True:
                logging.info(
                    ' {} txt word {} not all in tokens, skip'.format(
                        txt,
                        self.check_valid(word,
                                         self.token1_featurizer.vocab_array)))
                continue
            if self.check_valid(
                    phone, self.token2_featurizer.vocab_array) is not True:
                logging.info(
                    ' {} txt phone {} not all in tokens, skip'.format(
                        txt,
                        self.check_valid(phone,
                                         self.token2_featurizer.vocab_array)))
                continue
            if self.check_valid(
                    py, self.token3_featurizer.vocab_array) is not True:
                logging.info(' {} txt py {} not all in tokens, skip'.format(
                    txt,
                    self.check_valid(py, self.token3_featurizer.vocab_array)))
                continue
            word_text_feature = self.token1_featurizer.extract(word)
            phone_text_feature = self.token2_featurizer.extract(phone)
            py_text_feature = self.token3_featurizer.extract(py)

            if in_len < len(word_text_feature):
                continue

            max_label_words = max(max_label_words, len(word_text_feature))
            max_label_phone = max(max_label_phone, len(phone_text_feature))
            max_label_py = max(max_label_py, len(py_text_feature))
            max_input = max(max_input, len(speech_feature))

            speech_features.append(speech_feature)
            input_length.append(in_len)
            words_label.append(np.array(word_text_feature))
            words_label_length.append(len(word_text_feature))

            phone_label.append(np.array(phone_text_feature))
            phone_label_length.append(len(phone_text_feature))

            py_label.append(np.array(py_text_feature))
            py_label_length.append(len(py_text_feature))
            sample.append(line)
            if len(sample) == batch:
                break
        if train and self.augment.available():
            for i in sample:
                wp, txt = i.strip().split('\t')
                try:
                    data = self.speech_featurizer.load_wav(wp)
                except Exception:
                    logging.info('{} load data failed, skip'.format(wp))
                    continue
                if len(data) < 400:
                    continue
                elif len(
                        data
                ) > self.speech_featurizer.sample_rate * self.speech_config[
                        'wav_max_duration']:

                    continue
                data = self.augment.process(data)
                if self.speech_config['only_chinese']:
                    txt = self.only_chinese(txt)
                if self.speech_config['use_mel_layer']:
                    speech_feature = data / np.abs(data).max()
                    speech_feature = np.expand_dims(speech_feature, -1)
                    in_len = len(speech_feature) // (
                        self.speech_config['reduction_factor'] *
                        (self.speech_featurizer.sample_rate / 1000) *
                        self.speech_config['stride_ms'])
                else:
                    speech_feature = self.speech_featurizer.extract(data)
                    in_len = int(speech_feature.shape[0] //
                                 self.speech_config['reduction_factor'])

                py, phone, word = self.map(txt)
                if len(py) == 0:
                    continue

                word_text_feature = self.token1_featurizer.extract(word)
                phone_text_feature = self.token2_featurizer.extract(phone)
                py_text_feature = self.token3_featurizer.extract(py)

                if in_len < len(word_text_feature):
                    continue

                max_label_words = max(max_label_words, len(word_text_feature))
                max_label_phone = max(max_label_phone, len(phone_text_feature))
                max_label_py = max(max_label_py, len(py_text_feature))
                max_input = max(max_input, len(speech_feature))

                speech_features.append(speech_feature)
                input_length.append(in_len)
                words_label.append(np.array(word_text_feature))
                words_label_length.append(len(word_text_feature))

                phone_label.append(np.array(phone_text_feature))
                phone_label_length.append(len(phone_text_feature))

                py_label.append(np.array(py_text_feature))
                py_label_length.append(len(py_text_feature))

        if self.speech_config['use_mel_layer']:
            speech_features = self.speech_featurizer.pad_signal(
                speech_features, max_input)

        else:
            for i in range(len(speech_features)):

                if speech_features[i].shape[0] < max_input:
                    pad = np.ones([
                        max_input - speech_features[i].shape[0],
                        speech_features[i].shape[1],
                        speech_features[i].shape[2]
                    ]) * speech_features[i].min()
                    speech_features[i] = np.vstack((speech_features[i], pad))

        words_label = self.pad(words_label, max_label_words)
        phone_label = self.pad(phone_label, max_label_phone)
        py_label = self.pad(py_label, max_label_py)
        speech_features = np.array(speech_features, 'float32')
        words_label = np.array(words_label, 'int32')
        phone_label = np.array(phone_label, 'int32')
        py_label = np.array(py_label, 'int32')
        input_length = np.array(input_length, 'int32')
        words_label_length = np.array(words_label_length, 'int32')
        phone_label_length = np.array(phone_label_length, 'int32')
        py_label_length = np.array(py_label_length, 'int32')

        return speech_features, input_length, words_label, words_label_length, phone_label, phone_label_length, py_label, py_label_length

    def generator(self, train=True):
        while 1:
            speech_features, input_length, words_label, words_label_length, phone_label, phone_label_length, py_label, py_label_length = self.generate(
                train)

            yield speech_features, input_length, words_label, words_label_length, phone_label, phone_label_length, py_label, py_label_length
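
# A minimal usage sketch (not from the original source): feeding the
# multi-task loader above into tf.data. `MultiTask_DataLoader` and
# `config_dict` are placeholder names; the 8-tuple ordering and dtypes are
# taken from the arrays generate() builds (float32 features, int32 labels
# and lengths).
if __name__ == '__main__':
    import tensorflow as tf
    dg = MultiTask_DataLoader(config_dict)  # placeholder name for the class above
    dataset = tf.data.Dataset.from_generator(
        lambda: dg.generator(train=True),
        output_types=(tf.float32, tf.int32, tf.int32, tf.int32,
                      tf.int32, tf.int32, tf.int32, tf.int32))
    for batch in dataset.take(1):
        speech, in_len, words, words_len, phone, phone_len, py, py_len = batch
        print(speech.shape, words.shape, phone.shape, py.shape)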
Example No. 5
        outputs += self.out_cnn(block_outputs, training=training)
        return outputs

if __name__ == '__main__':
    from utils.user_config import UserConfig
    from utils.text_featurizers import TextFeaturizer
    from utils.speech_featurizers import SpeechFeaturizer
    import os
    import time
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    config = UserConfig(r'D:\TF2-ASR\configs\am_data.yml',
                        r'D:\TF2-ASR\configs\conformer.yml')
    config['decoder_config'].update({'model_type': 'LAS'})

    Tfer = TextFeaturizer(config['decoder_config'])
    SFer = SpeechFeaturizer(config['speech_config'])
    f, c = SFer.compute_feature_dim()
    config['model_config']['LAS_decoder'].update({'n_classes': Tfer.num_classes})
    config['model_config']['LAS_decoder'].update({'startid': Tfer.start})

    ct = ConformerLAS(config['model_config'], training=False)
    # ct.add_featurizers(Tfer)
    x = tf.ones([1, 300, f, c])
    length = tf.constant([300])
    out = ct._build([1, 300, f, c], training=True)
    ct.inference(x, length // 4)
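    # NOTE: the call above is a warm-up; with tf.function-wrapped models the
    # first call typically triggers graph tracing, so only the second call
    # below is timed for steady-state latency.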
    s = time.time()
    a = ct.inference(x, length // 4)
    e = time.time()
    print(e - s, a)
    # ct.summary()
    # print(out)
Example No. 6
class AM_DataLoader():
    def __init__(self, config_dict, training=True):
        self.speech_config = config_dict['speech_config']

        self.text_config = config_dict['decoder_config']
        self.augment_config = config_dict['augments_config']
        self.streaming = self.speech_config['streaming']
        self.chunk = self.speech_config['sample_rate'] * self.speech_config[
            'streaming_bucket']
        self.batch = config_dict['learning_config']['running_config'][
            'batch_size']
        self.speech_featurizer = SpeechFeaturizer(self.speech_config)
        self.text_featurizer = TextFeaturizer(self.text_config)
        self.make_file_list(
            self.speech_config['train_list']
            if training else self.speech_config['eval_list'], training)
        self.augment = Augmentation(self.augment_config)
        self.init_text_to_vocab()
        self.epochs = 1
        self.LAS = False
        self.steps = 0

    def load_state(self, outdir):
        try:

            dg_state = np.load(os.path.join(outdir, 'dg_state.npz'))

            self.epochs = int(dg_state['epoch'])
            self.train_offset = int(dg_state['train_offset'])
            train_list = dg_state['train_list'].tolist()
            if len(train_list) != len(self.train_list):
                logging.info(
                    'history train list does not match the newly loaded train list, data loader uses initial state'
                )
                self.epochs = 0
                self.train_offset = 0
        except FileNotFoundError:
            logging.info('state file not found, using initial state')
        except Exception:
            logging.info('load state failed, using initial state')

    def save_state(self, outdir):

        np.savez(os.path.join(outdir, 'dg_state.npz'),
                 epoch=self.epochs,
                 train_offset=self.train_offset,
                 train_list=self.train_list)

    def return_data_types(self):
        if self.LAS:
            return (tf.float32, tf.int32, tf.int32, tf.int32, tf.float32)
        else:
            return (tf.float32, tf.int32, tf.int32, tf.int32)

    def return_data_shape(self):
        f, c = self.speech_featurizer.compute_feature_dim()
        if self.LAS:
            return (tf.TensorShape([None, None, 1])
                    if self.speech_config['use_mel_layer'] else
                    tf.TensorShape([None, None, f, c]), tf.TensorShape([
                        None,
                    ]), tf.TensorShape([None, None]), tf.TensorShape([
                        None,
                    ]), tf.TensorShape([None, None, None]))
        else:
            return (tf.TensorShape([None, None, 1])
                    if self.speech_config['use_mel_layer'] else
                    tf.TensorShape([None, None, f, c]), tf.TensorShape([
                        None,
                    ]), tf.TensorShape([None, None]), tf.TensorShape([
                        None,
                    ]))

    def get_per_epoch_steps(self):
        return len(self.train_list) // self.batch

    def eval_per_epoch_steps(self):
        return len(self.test_list) // self.batch

    def init_text_to_vocab(self):
        pypinyin.load_phrases_dict({
            '调大': [['tiáo'], ['dà']],
            '调小': [['tiáo'], ['xiǎo']],
            '调亮': [['tiáo'], ['liàng']],
            '调暗': [['tiáo'], ['àn']],
            '肖': [['xiāo']],
            '英雄传': [['yīng'], ['xióng'], ['zhuàn']],
            '新传': [['xīn'], ['zhuàn']],
            '外传': [['wài'], ['zhuàn']],
            '正传': [['zhèng'], ['zhuàn']],
            '水浒传': [['shuǐ'], ['hǔ'], ['zhuàn']]
        })

        def text_to_vocab_func(txt):
            pins = pypinyin.pinyin(txt)
            pins = [i[0] for i in pins]
            return pins

        self.text_to_vocab = text_to_vocab_func
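        # e.g. text_to_vocab_func('调大音量') is expected to yield
        # ['tiáo', 'dà', 'yīn', 'liàng'] given the phrase dict loaded above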

    def make_file_list(self, wav_list, training=True):
        with open(wav_list, encoding='utf-8') as f:
            data = f.readlines()
        data = [i.strip() for i in data if i != '']
        num = len(data)
        if training:
            self.train_list = data[:int(num * 0.99)]
            self.test_list = data[int(num * 0.99):]
            np.random.shuffle(self.train_list)
            self.train_offset = 0
            self.test_offset = 0
            logging.info('loaded train list {} test list {}'.format(
                len(self.train_list), len(self.test_list)))
        else:
            self.test_list = data
            self.offset = 0

    def only_chinese(self, word):
        txt = ''
        for ch in word:
            if '\u4e00' <= ch <= '\u9fff':
                txt += ch
            else:
                continue

        return txt

    def eval_data_generator(self):
        sample = self.test_list[self.offset:self.offset + self.batch]
        self.offset += self.batch
        speech_features = []
        input_length = []
        y1 = []
        label_length1 = []
        max_input = 0
        max_label1 = 0
        for i in sample:
            wp, txt = i.strip().split('\t')
            txt = txt.replace(' ', '')
            try:
                data = self.speech_featurizer.load_wav(wp)
            except Exception:
                logging.info('{} load data failed, skip'.format(wp))
                continue
            if len(data) < 400:
                logging.info('{} wav too short (< 25 ms), skip'.format(wp))
                continue
            elif len(
                    data
            ) > self.speech_featurizer.sample_rate * self.speech_config[
                    'wav_max_duration']:
                logging.info(
                    '{} duration out of wav_max_duration({}), skip'.format(
                        wp, self.speech_config['wav_max_duration']))
                continue
            if self.speech_config['only_chinese']:
                txt = self.only_chinese(txt)
            if self.speech_config['use_mel_layer']:
                if not self.streaming:
                    speech_feature = data / np.abs(data).max()
                    speech_feature = np.expand_dims(speech_feature, -1)
                    in_len = len(speech_feature) // (
                        self.speech_config['reduction_factor'] *
                        (self.speech_featurizer.sample_rate / 1000) *
                        self.speech_config['stride_ms'])
                else:
                    speech_feature = data
                    speech_feature = np.expand_dims(speech_feature, -1)
                    reduce = self.speech_config['reduction_factor'] * (
                        self.speech_featurizer.sample_rate /
                        1000) * self.speech_config['stride_ms']
                    in_len = len(speech_feature) // self.chunk
                    if len(speech_feature) % self.chunk != 0:
                        in_len += 1
                    chunk_times = self.chunk // reduce
                    if self.chunk % reduce != 0:
                        chunk_times += 1
                    in_len *= chunk_times
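                    # in_len = ceil(samples / chunk) * ceil(chunk / reduce),
                    # i.e. the number of streaming chunks times encoder frames
                    # per chunk: the padded output length in encoder frames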

            else:
                speech_feature = self.speech_featurizer.extract(data)
                in_len = int(speech_feature.shape[0] //
                             self.speech_config['reduction_factor'])
            max_input = max(max_input, speech_feature.shape[0])

            py = self.text_to_vocab(txt)
            if self.check_valid(py,
                                self.text_featurizer.vocab_array) is not True:
                logging.info(' {} txt pinyin {} not all in tokens, skip'.format(
                    txt, self.check_valid(py,
                                          self.text_featurizer.vocab_array)))
                continue
            text_feature = self.text_featurizer.extract(py)

            if in_len < len(text_feature):
                logging.info(
                    '{} feature length < pinyin length, skip'.format(wp))
                continue
            max_input = max(max_input, len(speech_feature))
            max_label1 = max(max_label1, len(text_feature))
            speech_features.append(speech_feature)
            input_length.append(in_len)
            y1.append(np.array(text_feature))
            label_length1.append(len(text_feature))

        if self.speech_config['use_mel_layer']:
            if self.streaming:
                max_input = max_input // self.chunk * self.chunk + self.chunk
            speech_features = self.speech_featurizer.pad_signal(
                speech_features, max_input)

        else:
            for i in range(len(speech_features)):

                if speech_features[i].shape[0] < max_input:
                    pad = np.ones([
                        max_input - speech_features[i].shape[0],
                        speech_features[i].shape[1],
                        speech_features[i].shape[2]
                    ]) * speech_features[i].min()
                    speech_features[i] = np.vstack((speech_features[i], pad))

        for i in range(len(y1)):
            if y1[i].shape[0] < max_label1:
                pad = np.ones(max_label1 -
                              y1[i].shape[0]) * self.text_featurizer.pad
                y1[i] = np.hstack((y1[i], pad))

        x = np.array(speech_features, 'float32')
        y1 = np.array(y1, 'int32')

        input_length = np.array(input_length, 'int32')
        label_length1 = np.array(label_length1, 'int32')

        return x, input_length, y1, label_length1

    def check_valid(self, txt, vocab_list):
        if len(txt) == 0:
            return False
        for n in txt:
            if n not in vocab_list:
                return n
        return True
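    # NOTE: check_valid is tri-state -- True when every token is in the vocab,
    # False for empty input, and the first missing token (a truthy string)
    # otherwise; call sites therefore compare the result with `is not True`.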

    def GuidedAttentionMatrix(self, N, T, g=0.2):
        W = np.zeros((N, T), dtype=np.float32)
        for n in range(N):
            for t in range(T):
                W[n, t] = 1 - np.exp(-(t / float(T) - n / float(N))**2 /
                                     (2 * g * g))
        return W

    def guided_attention(self, input_length, targets_length, inputs_shape,
                         mel_target_shape):
        att_targets = []
        for i, j in zip(input_length, targets_length):
            i = int(i)
            step = int(j)
            pad = np.ones([inputs_shape, mel_target_shape]) * -1.
            pad[i:, :step] = 1
            att_target = self.GuidedAttentionMatrix(i, step, 0.2)
            pad[:att_target.shape[0], :att_target.shape[1]] = att_target
            att_targets.append(pad)
        att_targets = np.array(att_targets)

        return att_targets.astype('float32')

    def generate(self, train=True):

        sample = []
        speech_features = []
        input_length = []
        y1 = []
        label_length1 = []

        max_input = 0
        max_label1 = 0
        if train:
            batch = self.batch // 2 if self.augment.available() else self.batch
        else:
            batch = self.batch

        for i in range(batch * 10):
            if train:
                line = self.train_list[self.train_offset]
                self.train_offset += 1
                if self.train_offset > len(self.train_list) - 1:
                    self.train_offset = 0
                    np.random.shuffle(self.train_list)
                    self.epochs += 1
            else:
                line = self.test_list[self.test_offset]
                self.test_offset += 1
                if self.test_offset > len(self.test_list) - 1:
                    self.test_offset = 0
            wp, txt = line.strip().split('\t')
            try:
                data = self.speech_featurizer.load_wav(wp)
            except Exception:
                logging.info('{} load data failed, skip'.format(wp))
                continue
            if len(data) < 400:
                continue
            elif len(
                    data
            ) > self.speech_featurizer.sample_rate * self.speech_config[
                    'wav_max_duration']:
                logging.info(
                    '{} duration out of wav_max_duration({}), skip'.format(
                        wp, self.speech_config['wav_max_duration']))
                continue
            if self.speech_config['only_chinese']:
                txt = self.only_chinese(txt)
            if self.speech_config['use_mel_layer']:
                if not self.streaming:
                    speech_feature = data / np.abs(data).max()
                    speech_feature = np.expand_dims(speech_feature, -1)
                    in_len = len(speech_feature) // (
                        self.speech_config['reduction_factor'] *
                        (self.speech_featurizer.sample_rate / 1000) *
                        self.speech_config['stride_ms'])
                else:
                    speech_feature = data
                    speech_feature = np.expand_dims(speech_feature, -1)
                    reduce = self.speech_config['reduction_factor'] * (self.speech_featurizer.sample_rate / 1000) * \
                             self.speech_config['stride_ms']
                    in_len = len(speech_feature) // self.chunk
                    if len(speech_feature) % self.chunk != 0:
                        in_len += 1
                    chunk_times = self.chunk // reduce
                    if self.chunk % reduce != 0:
                        chunk_times += 1
                    in_len *= chunk_times
            else:
                speech_feature = self.speech_featurizer.extract(data)
                in_len = int(speech_feature.shape[0] //
                             self.speech_config['reduction_factor'])

            py = self.text_to_vocab(txt)
            if self.check_valid(py,
                                self.text_featurizer.vocab_array) is not True:
                logging.info(
                    ' {} txt pinyin {} not all in tokens, skip'.format(
                        txt,
                        self.check_valid(py,
                                         self.text_featurizer.vocab_array)))
                continue
            text_feature = self.text_featurizer.extract(py)

            if in_len < len(text_feature):
                logging.info(
                    '{} feature length < pinyin length, skip'.format(wp))
                continue
            max_input = max(max_input, len(speech_feature))
            max_label1 = max(max_label1, len(text_feature))
            speech_features.append(speech_feature)
            input_length.append(in_len)
            y1.append(np.array(text_feature))
            label_length1.append(len(text_feature))
            sample.append(line)
            if len(sample) == batch:
                break
        if train and self.augment.available():
            for i in sample:
                wp, txt = i.strip().split('\t')
                try:
                    data = self.speech_featurizer.load_wav(wp)
                except Exception:
                    continue
                if len(data) < 400:
                    logging.info('{} wav too short (< 25 ms), skip'.format(wp))
                    continue
                elif len(
                        data
                ) > self.speech_featurizer.sample_rate * self.speech_config[
                        'wav_max_duration']:
                    continue
                data = self.augment.process(data)
                if self.speech_config['only_chinese']:
                    txt = self.only_chinese(txt)
                if self.speech_config['use_mel_layer']:
                    if not self.streaming:
                        speech_feature = data / np.abs(data).max()
                        speech_feature = np.expand_dims(speech_feature, -1)
                        in_len = len(speech_feature) // (
                            self.speech_config['reduction_factor'] *
                            (self.speech_featurizer.sample_rate / 1000) *
                            self.speech_config['stride_ms'])
                    else:
                        speech_feature = data
                        speech_feature = np.expand_dims(speech_feature, -1)
                        reduce = self.speech_config['reduction_factor'] * (self.speech_featurizer.sample_rate / 1000) * \
                                 self.speech_config['stride_ms']
                        in_len = len(speech_feature) // self.chunk
                        if len(speech_feature) % self.chunk != 0:
                            in_len += 1
                        chunk_times = self.chunk // reduce
                        if self.chunk % reduce != 0:
                            chunk_times += 1
                        in_len *= chunk_times
                else:
                    speech_feature = self.speech_featurizer.extract(data)
                    in_len = int(speech_feature.shape[0] //
                                 self.speech_config['reduction_factor'])

                py = self.text_to_vocab(txt)
                # check_valid returns the offending token (truthy) when a
                # token is missing, so `not ...` would let invalid samples
                # through; compare against True explicitly
                if self.check_valid(
                        py, self.text_featurizer.vocab_array) is not True:
                    continue

                text_feature = self.text_featurizer.extract(py)

                if in_len < len(text_feature):
                    continue
                max_input = max(max_input, len(speech_feature))
                max_label1 = max(max_label1, len(text_feature))
                speech_features.append(speech_feature)

                input_length.append(in_len)
                y1.append(np.array(text_feature))
                label_length1.append(len(text_feature))

        if self.speech_config['use_mel_layer']:
            if self.streaming:
                reduce = self.speech_config['reduction_factor'] * (self.speech_featurizer.sample_rate / 1000) * \
                         self.speech_config['stride_ms']
                max_input = max_input // self.chunk * self.chunk + self.chunk
                max_in_len = max_input // self.chunk
                chunk_times = self.chunk // reduce
                if self.chunk % reduce != 0:
                    chunk_times += 1
                max_in_len *= chunk_times
                input_length = np.clip(input_length, 0, max_in_len)
            speech_features = self.speech_featurizer.pad_signal(
                speech_features, max_input)

        else:
            for i in range(len(speech_features)):

                if speech_features[i].shape[0] < max_input:
                    pad = np.ones([
                        max_input - speech_features[i].shape[0],
                        speech_features[i].shape[1],
                        speech_features[i].shape[2]
                    ]) * speech_features[i].min()
                    speech_features[i] = np.vstack((speech_features[i], pad))

        for i in range(len(y1)):
            if y1[i].shape[0] < max_label1:
                pad = np.ones(max_label1 -
                              y1[i].shape[0]) * self.text_featurizer.pad
                y1[i] = np.hstack((y1[i], pad))

        x = np.array(speech_features, 'float32')
        y1 = np.array(y1, 'int32')

        input_length = np.array(input_length, 'int32')
        label_length1 = np.array(label_length1, 'int32')

        return x, input_length, y1, label_length1

    def generator(self, train=True):
        while 1:
            x, input_length, labels, label_length = self.generate(train)
            if x.shape[0] == 0:
                logging.info('loaded data length is zero, continue')
                continue
            if self.LAS:
                guide_matrix = self.guided_attention(input_length,
                                                     label_length,
                                                     np.max(input_length),
                                                     label_length.max())
                yield x, input_length, labels, label_length, guide_matrix
            else:
                yield x, input_length, labels, label_length
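
# A minimal usage sketch (not from the original source): this loader already
# exposes its generator signature via return_data_types()/return_data_shape(),
# so a tf.data pipeline can be declared directly. `config_dict` is a
# placeholder for a parsed config.
if __name__ == '__main__':
    import tensorflow as tf
    dg = AM_DataLoader(config_dict)  # config_dict: placeholder
    dataset = tf.data.Dataset.from_generator(
        lambda: dg.generator(train=True),
        output_types=dg.return_data_types(),
        output_shapes=dg.return_data_shape())
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)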
Example No. 7
class AM():
    def __init__(self, config):
        self.config = config
        self.update_model_type()
        self.speech_config = self.config['speech_config']
        if self.model_type != 'MultiTask':
            self.text_config = self.config['decoder_config']
        else:
            self.text_config = self.config['decoder3_config']
        self.model_config = self.config['model_config']
        self.text_feature = TextFeaturizer(self.text_config, True)
        self.speech_feature = SpeechFeaturizer(self.speech_config)

        self.init_steps = None

    def update_model_type(self):
        if 'Streaming' in self.config['model_config']['name']:
            assert self.config['speech_config']['streaming'] is True
            assert 'Conformer' in self.config['model_config']['name']
        else:
            assert self.config['speech_config']['streaming'] is False
        if 'CTC' in self.config['model_config'][
                'name'] and 'Multi' not in self.config['model_config']['name']:
            self.config['decoder_config'].update({'model_type': 'CTC'})
            self.model_type = 'CTC'
        elif 'Multi' in self.config['model_config']['name']:
            self.config['decoder1_config'].update({'model_type': 'CTC'})
            self.config['decoder2_config'].update({'model_type': 'CTC'})
            self.config['decoder3_config'].update({'model_type': 'CTC'})
            self.config['decoder_config'].update({'model_type': 'CTC'})
            self.model_type = 'MultiTask'
        elif 'LAS' in self.config['model_config']['name']:
            self.config['decoder_config'].update({'model_type': 'LAS'})
            self.model_type = 'LAS'
        else:
            self.config['decoder_config'].update({'model_type': 'Transducer'})
            self.model_type = 'Transducer'

    def conformer_model(self, training):
        from AMmodel.streaming_conformer import StreamingConformerCTC, StreamingConformerTransducer
        from AMmodel.conformer import ConformerCTC, ConformerLAS, ConformerTransducer
        self.model_config.update(
            {'vocabulary_size': self.text_feature.num_classes})
        if self.model_config['name'] == 'ConformerTransducer':
            self.model_config.pop('LAS_decoder')
            self.model_config.pop('enable_tflite_convertible')
            self.model_config.update({'speech_config': self.speech_config})
            self.model = ConformerTransducer(**self.model_config)
        elif self.model_config['name'] == 'ConformerCTC':
            self.model_config.update({'speech_config': self.speech_config})
            self.model = ConformerCTC(**self.model_config)
        elif self.model_config['name'] == 'ConformerLAS':
            self.config['model_config']['LAS_decoder'].update(
                {'n_classes': self.text_feature.num_classes})
            self.config['model_config']['LAS_decoder'].update(
                {'startid': self.text_feature.start})
            self.model = ConformerLAS(
                self.config['model_config'],
                training=training,
                enable_tflite_convertible=self.config['model_config']
                ['enable_tflite_convertible'],
                speech_config=self.speech_config)
        elif self.model_config['name'] == 'StreamingConformerCTC':
            self.model_config.update({'speech_config': self.speech_config})
            self.model = StreamingConformerCTC(**self.model_config)
        elif self.model_config['name'] == 'StreamingConformerTransducer':
            self.model_config.pop('enable_tflite_convertible')
            self.model_config.update({'speech_config': self.speech_config})
            self.model = StreamingConformerTransducer(**self.model_config)
        else:
            raise ValueError('not in supported model list')

    def ds2_model(self, training):
        from AMmodel.deepspeech2 import DeepSpeech2CTC, DeepSpeech2LAS, DeepSpeech2Transducer
        self.model_config['Transducer_decoder'][
            'vocabulary_size'] = self.text_feature.num_classes
        f, c = self.speech_feature.compute_feature_dim()
        input_shape = [None, f, c]
        self.model_config.update({'input_shape': input_shape})
        self.model_config.update(
            {'dmodel': self.model_config['rnn_conf']['rnn_units']})
        if self.model_config['name'] == 'DeepSpeech2Transducer':
            self.model_config.pop('LAS_decoder')
            self.model_config.pop('enable_tflite_convertible')
            self.model = DeepSpeech2Transducer(
                input_shape,
                self.model_config,
                speech_config=self.speech_config)
        elif self.model_config['name'] == 'DeepSpeech2CTC':
            self.model = DeepSpeech2CTC(input_shape,
                                        self.model_config,
                                        self.text_feature.num_classes,
                                        speech_config=self.speech_config)
        elif self.model_config['name'] == 'DeepSpeech2LAS':
            self.model_config['LAS_decoder'].update(
                {'n_classes': self.text_feature.num_classes})
            self.model_config['LAS_decoder'].update(
                {'startid': self.text_feature.start})
            self.model = DeepSpeech2LAS(
                self.model_config,
                input_shape,
                training=training,
                enable_tflite_convertible=self.model_config[
                    'enable_tflite_convertible'],
                speech_config=self.speech_config)
        else:
            raise ValueError('not in supported model list')

    def multi_task_model(self, training):
        from AMmodel.MultiConformer import ConformerMultiTaskCTC
        token1_feature = TextFeaturizer(self.config['decoder1_config'])
        token2_feature = TextFeaturizer(self.config['decoder2_config'])
        token3_feature = TextFeaturizer(self.config['decoder3_config'])

        self.model_config.update({
            'classes1': token1_feature.num_classes,
            'classes2': token2_feature.num_classes,
            'classes3': token3_feature.num_classes,
        })

        self.model = ConformerMultiTaskCTC(self.model_config,
                                           training=training,
                                           speech_config=self.speech_config)

    def load_model(self, training=True):

        if 'Multi' in self.model_config['name']:
            self.multi_task_model(training)

        elif 'Conformer' in self.model_config['name']:
            self.conformer_model(training)
        else:
            self.ds2_model(training)
        self.model.add_featurizers(self.text_feature)
        f, c = self.speech_feature.compute_feature_dim()

        if not training:
            if self.text_config['model_type'] != 'LAS':
                if self.model.mel_layer is not None:
                    self.model._build([
                        3, 16000 if self.speech_config['streaming'] is False
                        else self.model.chunk_size * 2, 1
                    ])
                    self.model.return_pb_function([None, None, 1])
                else:
                    self.model._build([3, 80, f, c])
                    self.model.return_pb_function([None, None, f, c])

            else:
                if self.model.mel_layer is not None:
                    self.model._build([
                        3, 16000 if self.speech_config['streaming'] is False
                        else self.model.chunk_size * 2, 1
                    ], training)
                    self.model.return_pb_function([None, None, 1])
                else:

                    self.model._build([2, 80, f, c], training)
                    self.model.return_pb_function([None, None, f, c])

            self.load_checkpoint(self.config)

    def convert_to_pb(self, export_path):
        import tensorflow as tf
        concrete_func = self.model.recognize_pb.get_concrete_function()
        tf.saved_model.save(self.model, export_path, signatures=concrete_func)
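        # the SavedModel written here can be reloaded elsewhere with
        # tf.saved_model.load(export_path) and called through its signature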

    def decode_result(self, word):
        de = []
        for i in word:
            if i != self.text_feature.stop:
                de.append(self.text_feature.index_to_token[int(i)])
            else:
                break
        return de

    def predict(self, fp):
        if '.pcm' in fp:
            data = np.fromfile(fp, 'int16')
            data = np.array(data, 'float32')
            data /= 32768
        else:
            data = self.speech_feature.load_wav(fp)
        if self.model.mel_layer is None:
            mel = self.speech_feature.extract(data)
            mel = np.expand_dims(mel, 0)

            input_length = np.array(
                [[mel.shape[1] // self.model.time_reduction_factor]], 'int32')
        else:
            mel = data.reshape([1, -1, 1])
            input_length = np.array([[
                mel.shape[1] // self.model.time_reduction_factor //
                (self.speech_config['sample_rate'] *
                 self.speech_config['stride_ms'] / 1000)
            ]], 'int32')
        if self.speech_config['streaming']:
            chunk_size = self.model.chunk_size
            if mel.shape[1] % chunk_size != 0:
                T = mel.shape[1] // chunk_size * chunk_size + chunk_size
                pad_T = T - mel.shape[1]

                mel = np.hstack((mel, np.zeros([1, pad_T, 1])))
            mel = mel.reshape([1, -1, chunk_size, 1])
            mel = mel.astype('float32')
            if 'CTC' in self.model_type:
                enc_outputs = None
                result = []
                for i in range(mel.shape[1]):
                    input_wav = mel[:, i]

                    es = time.time()
                    enc_output = self.model.extract_feature(input_wav)
                    ee = time.time()
                    enc_output = enc_output.numpy()
                    if enc_outputs is not None:
                        enc_outputs = np.hstack((enc_outputs, enc_output))
                    else:
                        enc_outputs = enc_output

                    ds = time.time()
                    result_ = self.model.ctc_decode(
                        enc_outputs, np.array([[enc_outputs.shape[1]]],
                                              'int32'))
                    de = time.time()
                    print('extract cost time:', ee - es, 'ctc decode time:',
                          de - ds)
                    result_ = result_.numpy()[0]
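                # result_ from the final iteration spans the fully accumulated
                # encoder output, so only that last decode is converted below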
                for n in result_:
                    if n != -1:
                        result.append(n)
                result = np.array(result)
                result = np.expand_dims(result, 0)
            else:
                states, result = self.model.initial_states(mel)
                enc_outputs = None
                start_B = 0
                for i in range(mel.shape[1]):
                    input_wav = mel[:, i]
                    es = time.time()
                    enc_output = self.model.extract_feature(input_wav)
                    ee = time.time()
                    enc_output = enc_output.numpy()
                    if enc_outputs is not None:
                        enc_outputs = np.hstack((enc_outputs, enc_output))
                    else:
                        enc_outputs = enc_output
                    ds = time.time()

                    result, states, start_B = self.model.perform_greedy(
                        enc_outputs, states, result, start_B)
                    de = time.time()
                    print('extract cost time:', ee - es, 'lstm decode time:',
                          de - ds, 'next start T:', int(start_B))
                result = result.numpy()
                result = np.hstack((result, np.ones([1])))
        else:

            result = self.model.recognize_pb(mel, input_length)[0]

        return result

    def load_checkpoint(self, config):
        """Load checkpoint."""

        self.checkpoint_dir = os.path.join(
            config['learning_config']['running_config']["outdir"],
            "checkpoints")
        files = os.listdir(self.checkpoint_dir)
        files.sort(key=lambda x: int(x.split('_')[-1].replace('.h5', '')))
        self.model.load_weights(os.path.join(self.checkpoint_dir, files[-1]))
        self.init_steps = int(files[-1].split('_')[-1].replace('.h5', ''))
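
# A minimal usage sketch (not from the original source): loading the acoustic
# model for inference. The config paths mirror the script in Example No. 5;
# 'test.wav' is a hypothetical input and the [batch, time] result shape is an
# assumption.
if __name__ == '__main__':
    import numpy as np
    from utils.user_config import UserConfig
    config = UserConfig(r'D:\TF2-ASR\configs\am_data.yml',
                        r'D:\TF2-ASR\configs\conformer.yml')
    am = AM(config)
    am.load_model(training=False)  # builds the model and loads the newest checkpoint
    result = am.predict('test.wav')  # hypothetical wav path
    print(am.decode_result(np.array(result)[0]))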