class ASR():
    """Recognizer that runs an acoustic model (AM) followed by a language model (LM).

    The AM turns a wav file into token ids, which are decoded to vocabulary
    entries; the LM then predicts over that sequence and its output is decoded
    with ``self.lm.lm_featurizer``.
    """

    def __init__(self, am_config, lm_config):
        self.am = AM(am_config)
        self.am.load_model(False)
        self.lm = LM(lm_config)
        self.lm.load_model(False)

    def decode_am_result(self, result):
        """Convert AM token ids to vocabulary entries via ``AM.decode_result``."""
        return self.am.decode_result(result)

    def _decode_lm_output(self, lm_output):
        """Decode element 0 of an LM prediction with the LM featurizer.

        ``lm_output[0]`` must expose ``.numpy()`` (presumably a tensor --
        confirm against ``LM.predict``).
        """
        return self.lm.decode(lm_output[0].numpy(), self.lm.lm_featurizer)

    def stt(self, wav_path):
        """Run the AM and then the LM on ``wav_path``.

        Returns:
            ``(am_result, lm_result)``: the decoded AM output and the decoded
            LM output computed from it.
        """
        # Reuse am_test so both entry points select/decode AM output identically.
        am_result = self.am_test(wav_path)
        lm_result = self._decode_lm_output(self.lm.predict(am_result))
        return am_result, lm_result

    def am_test(self, wav_path):
        """Run only the acoustic model and return its decoded output."""
        # am.predict returns token ids.
        am_result = self.am.predict(wav_path)
        if self.am.model_type == 'Transducer':
            # Drop the first and last positions (presumably start/end markers --
            # confirm against AM.predict).
            return self.decode_am_result(am_result[1:-1])
        # Other model types: decode only element 0.
        return self.decode_am_result(am_result[0])

    def lm_test(self, txt):
        """Convert ``txt`` to pinyin, run the LM on it and return the decoded output."""
        if self.lm.config['am_token']['for_multi_task']:
            # Style 8 is pypinyin.Style.TONE3 (tone number after the syllable);
            # neutral tones are written as 5.
            pys = pypinyin.pinyin(txt, 8, neutral_tone_with_five=True)
        else:
            pys = pypinyin.pinyin(txt)
        # pinyin() yields one candidate list per character/segment; keep the first.
        input_py = [candidates[0] for candidates in pys]
        return self._decode_lm_output(self.lm.predict(input_py))
class AM_Tester():
    """Evaluate a trained acoustic model on the test split.

    Non-MultiTask models use ``AM_DataLoader`` with ``am_tester.AMTester``;
    ``MultiTask`` models use ``MultiTask_DataLoader`` with
    ``multi_task_tester.MultiTaskTester``.
    """

    def __init__(self, config):
        self.config = config['learning_config']
        self.am = AM(config)
        self.am.load_model(training=False)
        if self.am.model_type != 'MultiTask':
            self.dg = AM_DataLoader(config, training=False)
            self.runner = am_tester.AMTester(
                self.config['running_config'],
                self.dg.text_featurizer,
                streaming=config['speech_config']['streaming'])
        else:
            self.dg = MultiTask_DataLoader(config, training=False)
            self.runner = multi_task_tester.MultiTaskTester(
                self.config['running_config'], self.dg.token3_featurizer)
        self.STT = self.am.model
        self.runner.set_progbar(self.dg.eval_per_epoch_steps())
        self.runner.set_all_steps(self.dg.eval_per_epoch_steps())
        self.runner.compile(self.STT)

    def make_eval_batch_data(self):
        """Pull ``eval_steps_per_batches`` samples from the eval generator.

        Returns:
            list of tuples. Non-MultiTask: ``(features, input_length, labels,
            label_length)``. MultiTask: ``(speech_features, input_length,
            py_label)``. In both cases ``input_length`` gains a trailing axis
            via ``np.expand_dims(..., -1)``.
        """
        multi_task = self.am.model_type == 'MultiTask'
        batches = []
        for _ in range(self.config['running_config']['eval_steps_per_batches']):
            sample = self.dg.eval_data_generator()
            if multi_task:
                # The MultiTask generator yields 8 fields; only speech features,
                # input lengths and py_label are passed to the tester.
                speech_features, input_length, _, _, _, _, py_label, _ = sample
                batches.append(
                    (speech_features, np.expand_dims(input_length, -1), py_label))
            else:
                features, input_length, labels, label_length = sample
                batches.append(
                    (features, np.expand_dims(input_length, -1), labels, label_length))
        return batches

    def test(self):
        """Run evaluation batches until the loader's offset passes the test list.

        The offset is checked after each run, so at least one run happens.
        """
        while True:
            self.runner.run(self.make_eval_batch_data())
            if self.dg.offset > len(self.dg.test_list) - 1:
                break
class ASR():
    """Acoustic-model-only recognizer: wav file in, decoded AM output out."""

    def __init__(self, am_config):
        self.am = AM(am_config)
        self.am.load_model(False)

    def decode_am_result(self, result):
        """Translate AM token ids into vocabulary entries."""
        decoder = self.am.decode_result
        return decoder(result)

    def am_test(self, wav_path):
        """Predict token ids for ``wav_path`` and decode them.

        Transducer output is trimmed of its first and last positions; for any
        other model type only element 0 is decoded.
        """
        token_ids = self.am.predict(wav_path)
        is_transducer = self.am.model_type == 'Transducer'
        selected = token_ids[1:-1] if is_transducer else token_ids[0]
        return self.decode_am_result(selected)
class ASR():
    """Recognizer that feeds the decoded acoustic-model output to a language model."""

    def __init__(self, am_config, lm_config):
        self.am = AM(am_config)
        self.am.load_model(False)
        self.lm = LM(lm_config)
        # NOTE(review): the LM is loaded with load_model()'s default arguments
        # while the AM is loaded with False -- confirm this is intended.
        self.lm.load_model()

    def decode_am_result(self, result):
        """Decode element 0 of an AM prediction into vocabulary entries."""
        first = result[0]
        return self.am.decode_result(first)

    def stt(self, wav_path):
        """Transcribe ``wav_path``.

        Returns ``(am_ids, lm_result)``: ``am_ids`` is the raw (undecoded)
        output of ``self.am.predict``; ``lm_result`` is whatever
        ``self.lm.predict`` returns for the decoded AM output.
        """
        am_ids = self.am.predict(wav_path)
        lm_input = self.decode_am_result(am_ids)
        return am_ids, self.lm.predict(lm_input)
class AM_Trainer():
    """Train an acoustic model with the runner that matches its ``model_type``.

    ``CTC``, ``LAS`` and ``MultiTask`` have dedicated trainers; any other model
    type uses the transducer trainer. The optimizer is Adamax built from
    ``config['optimizer_config']``.
    """

    def __init__(self, config):
        self.config = config['learning_config']
        self.am = AM(config)
        self.am.load_model(training=True)
        if self.am.model_type != 'MultiTask':
            self.dg = AM_DataLoader(config)
        else:
            self.dg = MultiTask_DataLoader(config)
        # Propagate the model's time reduction factor to the data loader.
        self.dg.speech_config['reduction_factor'] = self.am.model.time_reduction_factor
        running_config = self.config['running_config']
        self.dg.load_state(running_config['outdir'])

        if self.am.model_type == 'CTC':
            self.runner = ctc_runners.CTCTrainer(
                self.dg.speech_featurizer, self.dg.text_featurizer, running_config)
        elif self.am.model_type == 'LAS':
            self.runner = las_runners.LASTrainer(
                self.dg.speech_featurizer, self.dg.text_featurizer, running_config)
            self.dg.LAS = True
        elif self.am.model_type == 'MultiTask':
            self.runner = multi_runners.MultiTaskLASTrainer(
                self.dg.speech_featurizer, self.dg.token4_featurizer, running_config)
        else:
            self.runner = transducer_runners.TransducerTrainer(
                self.dg.speech_featurizer, self.dg.text_featurizer, running_config)
        self.STT = self.am.model

        # The total step budget is doubled when the loader's augmentation is available.
        factor = 2 if self.dg.augment.available() else 1
        self.opt = tf.keras.optimizers.Adamax(**config['optimizer_config'])
        self.runner.set_total_train_steps(
            self.dg.get_per_epoch_steps() * running_config['num_epochs'] * factor)
        self.runner.compile(self.STT, self.opt)
        self.dg.batch = self.runner.global_batch_size

    @staticmethod
    def _checkpoint_step(filename):
        """Parse the step number: last ``_``-separated field with ``.h5`` removed."""
        return int(filename.split('_')[-1].replace('.h5', ''))

    def load_checkpoint(self, config, model):
        """Load the highest-step checkpoint from ``<outdir>/checkpoints`` into *model*.

        Sets ``self.checkpoint_dir`` and ``self.init_steps`` (the step parsed
        from the chosen file name). Every directory entry must parse via
        ``_checkpoint_step`` or ``int`` raises ``ValueError``; an empty
        directory raises ``IndexError``.
        """
        self.checkpoint_dir = os.path.join(
            config['learning_config']['running_config']["outdir"], "checkpoints")
        files = sorted(os.listdir(self.checkpoint_dir), key=self._checkpoint_step)
        latest = files[-1]
        model.load_weights(os.path.join(self.checkpoint_dir, latest))
        self.init_steps = self._checkpoint_step(latest)

    def recevie_data(self, r):
        """Pop one serialized sample from the ``config['data_name']`` list and decode it.

        ``r`` is presumably a Redis client (anything with ``rpop``). The popped
        value is evaluated into a dict holding, for each key in
        ``config['data_dict_key']``, the raw buffer under ``key`` plus
        ``'<key>_dtype'`` and ``'<key>_shape'``. Returns the reshaped arrays in
        ``data_dict_key`` order.
        """
        # SECURITY(review): eval() executes arbitrary code taken from the queue.
        # If the producer only writes Python literals (bytes/str/tuple), switch
        # to ast.literal_eval; left as-is because the payload format is not
        # visible from this file.
        data = eval(r.rpop(self.config['data_name']))
        trains = []
        for key in self.config['data_dict_key']:
            buf = data[key]
            dtype = data['%s_dtype' % key]
            shape = data['%s_shape' % key]
            trains.append(np.frombuffer(buf, dtype).reshape(shape))
        return trains

    def train(self):
        """Attach train/eval data to the runner and call ``fit`` until it finishes.

        After each ``fit`` that does not finish training, the loader state is
        saved to ``outdir`` when the runner's step count is a multiple of
        ``save_interval_steps``.
        """
        if self.am.model_type != 'MultiTask':
            def make_dataset(training):
                return tf.data.Dataset.from_generator(
                    self.dg.generator,
                    self.dg.return_data_types(),
                    self.dg.return_data_shape(),
                    args=(training, ))

            self.runner.set_datasets(make_dataset(True), make_dataset(False))
        else:
            # MultiTask runners are given the generator objects directly
            # rather than tf.data datasets.
            self.runner.set_datasets(self.dg.generator(True), self.dg.generator(False))
        running_config = self.config['running_config']
        while True:
            self.runner.fit(epoch=self.dg.epochs)
            if self.runner._finished():
                self.runner.save_checkpoint()
                logging.info('Finish training!')
                break
            if self.runner.steps % running_config['save_interval_steps'] == 0:
                self.dg.save_state(running_config['outdir'])
class StreamingASR(object):
    """Greedy, frame-by-frame transducer decoding on top of a loaded AM.

    Emitted token ids accumulate in ``self.decoded`` (seeded with the text
    featurizer's ``start`` id); nothing in this class resets it between calls.
    """

    def __init__(self, config):
        self.am = AM(config)
        self.am.load_model(False)
        self.speech_config = config['speech_config']
        self.text_config = config['decoder_config']
        self.speech_feature = SpeechFeaturizer(self.speech_config)
        self.text_featurizer = TextFeaturizer(self.text_config)
        self.decoded = tf.constant([self.text_featurizer.start])

    def _encode(self, audio):
        """Run *audio* through feature extraction and the model encoder.

        Without a ``mel_layer`` the featurizer's ``extract`` output gets a
        leading batch axis; otherwise the raw audio is reshaped to
        ``[1, -1, 1]`` and passed to ``mel_layer``.
        """
        if self.am.model.mel_layer is None:
            features = np.expand_dims(self.speech_feature.extract(audio), 0)
        else:
            features = self.am.model.mel_layer(audio.reshape([1, -1, 1]))
        return self.am.model.encoder(features)

    def _greedy_step(self, step_input):
        """Decode one encoder frame.

        Returns the emitted id as a 1-element tensor (after appending it to
        ``self.decoded``), or ``None`` when id 0 or the blank id wins.
        """
        enc = tf.reshape(step_input, [1, 1, -1])
        # The prediction net sees the whole decoded history; only its last
        # position is combined with the current frame.
        y = self.am.model.predict_net(inputs=tf.reshape(self.decoded, [1, -1]),
                                      p_memory_states=None,
                                      training=False)
        y = y[:, -1:]
        z = self.am.model.joint_net([enc, y], training=False)
        log_probs = tf.squeeze(tf.nn.log_softmax(z))
        pred = tf.reshape(tf.argmax(log_probs, axis=-1, output_type=tf.int32), [1])
        if pred != 0 and pred != self.text_featurizer.blank:
            self.decoded = tf.concat([self.decoded, pred], axis=0)
            return pred
        return None

    def _token(self, pred):
        """Map a 1-element id tensor to its vocabulary token."""
        return self.text_featurizer.index_to_token[pred.numpy().tolist()[0]]

    def stream_detect(self, inputs):
        """Load ``inputs`` as audio, encode it, and greedily decode every encoder frame."""
        x = self._encode(self.speech_feature.load_wav(inputs))
        for i in range(x.shape[1]):
            self.step_decode(x[:, i])

    def step_decode(self, step_input):
        """Decode one encoder frame and print the token if one was emitted."""
        pred = self._greedy_step(step_input)
        if pred is not None:
            print("pred: {}".format(self._token(pred)))

    def predict_stack_buffer(self, wavfile):
        """Decode ``wavfile`` in successive, overlapping windows.

        With ``n = len(data) // 16000``, window ``j`` (``0 <= j < n``) covers
        ``data[j * 11000:(j + 1) * 16000]``, so each window is 5000 elements
        longer than the previous one and trailing data past ``n * 16000`` is
        not decoded.
        NOTE(review): 16000 presumably means one second at 16 kHz; confirm the
        growing window is intended rather than a fixed 16000-element window
        with 5000 elements of overlap.
        """
        data = self.speech_feature.load_wav(wavfile)
        for j in range(len(data) // 16000):
            buffer = data[j * 16000 - j * 5000: (j + 1) * 16000]
            x = self._encode(buffer)
            for i in range(x.shape[1]):
                pred = self._greedy_step(x[:, i])
                if pred is not None:
                    print("buffer_step: {}, "
                          "step: {}, "
                          "pred: {}".format(j, i, self._token(pred)))
class AM_Trainer():
    """Train an acoustic model with a ``CustomSchedule`` learning rate.

    ``CTC``, ``LAS`` and ``MultiTask`` have dedicated trainers; any other model
    type uses the transducer trainer. The optimizer is Adamax built from
    ``config['optimizer_config']`` with ``learning_rate`` set to the schedule.
    """

    def __init__(self, config):
        self.config = config['learning_config']
        self.am = AM(config)
        self.am.load_model(training=True)
        if self.am.model_type != 'MultiTask':
            self.dg = AM_DataLoader(config)
        else:
            self.dg = MultiTask_DataLoader(config)
        # Propagate the model's time reduction factor to the data loader.
        self.dg.speech_config['reduction_factor'] = self.am.model.time_reduction_factor
        running_config = self.config['running_config']
        self.dg.load_state(running_config['outdir'])

        if self.am.model_type == 'CTC':
            self.runner = ctc_runners.CTCTrainer(
                self.dg.speech_featurizer, self.dg.text_featurizer, running_config)
        elif self.am.model_type == 'LAS':
            self.runner = las_runners.LASTrainer(
                self.dg.speech_featurizer, self.dg.text_featurizer, running_config)
            self.dg.LAS = True
        elif self.am.model_type == 'MultiTask':
            self.runner = multi_runners.MultiTaskLASTrainer(
                self.dg.speech_featurizer, self.dg.token4_featurizer, running_config)
        else:
            self.runner = transducer_runners.TransducerTrainer(
                self.dg.speech_featurizer, self.dg.text_featurizer, running_config)
        self.STT = self.am.model

        # The total step budget is doubled when the loader's augmentation is available.
        factor = 2 if self.dg.augment.available() else 1
        all_train_step = self.dg.get_per_epoch_steps() * running_config['num_epochs'] * factor
        # warmup_steps is 10% of the total step budget.
        lr = CustomSchedule(config['model_config']['dmodel'],
                            warmup_steps=int(all_train_step * 0.1))
        # Build the optimizer kwargs from a copy instead of writing lr into
        # config['optimizer_config'], so the caller's config is left unmodified.
        optimizer_config = dict(config['optimizer_config'], learning_rate=lr)
        self.opt = tf.keras.optimizers.Adamax(**optimizer_config)
        self.runner.set_total_train_steps(all_train_step)
        self.runner.compile(self.STT, self.opt)
        self.dg.batch = self.runner.global_batch_size

    def recevie_data(self, r):
        """Pop one serialized sample from the ``config['data_name']`` list and decode it.

        ``r`` is presumably a Redis client (anything with ``rpop``). The popped
        value is evaluated into a dict holding, for each key in
        ``config['data_dict_key']``, the raw buffer under ``key`` plus
        ``'<key>_dtype'`` and ``'<key>_shape'``. Returns the reshaped arrays in
        ``data_dict_key`` order.
        """
        # SECURITY(review): eval() executes arbitrary code taken from the queue.
        # If the producer only writes Python literals (bytes/str/tuple), switch
        # to ast.literal_eval; left as-is because the payload format is not
        # visible from this file.
        data = eval(r.rpop(self.config['data_name']))
        trains = []
        for key in self.config['data_dict_key']:
            buf = data[key]
            dtype = data['%s_dtype' % key]
            shape = data['%s_shape' % key]
            trains.append(np.frombuffer(buf, dtype).reshape(shape))
        return trains

    def train(self):
        """Attach train/eval data to the runner and call ``fit`` until it finishes.

        After each ``fit`` that does not finish training, the loader state is
        saved to ``outdir`` when the runner's step count is a multiple of
        ``save_interval_steps``.
        """
        if self.am.model_type != 'MultiTask':
            def make_dataset(training):
                return tf.data.Dataset.from_generator(
                    self.dg.generator,
                    self.dg.return_data_types(),
                    self.dg.return_data_shape(),
                    args=(training, ))

            self.runner.set_datasets(make_dataset(True), make_dataset(False))
        else:
            # MultiTask runners are given the generator objects directly
            # rather than tf.data datasets.
            self.runner.set_datasets(self.dg.generator(True), self.dg.generator(False))
        running_config = self.config['running_config']
        while True:
            self.runner.fit(epoch=self.dg.epochs)
            if self.runner._finished():
                self.runner.save_checkpoint()
                logging.info('Finish training!')
                break
            if self.runner.steps % running_config['save_interval_steps'] == 0:
                self.dg.save_state(running_config['outdir'])