def __init__(self, config):
    """Assemble the acoustic model, data loader, trainer and optimizer.

    Args:
        config: Mapping that must provide ``'learning_config'`` (whose
            ``'running_config'`` entry supplies ``'outdir'`` and
            ``'num_epochs'``) and ``'optimizer_config'`` (keyword
            arguments forwarded to ``tf.keras.optimizers.Adamax``).
            The whole mapping is also handed to ``AM`` and to the
            data loader.
    """
    self.config = config['learning_config']

    # Build the acoustic model first: its type decides which loader and
    # which trainer are used below.
    self.am = AM(config)
    self.am.load_model(training=True)
    model_type = self.am.model_type

    loader_cls = (MultiTask_DataLoader if model_type == 'MultiTask'
                  else AM_DataLoader)
    self.dg = loader_cls(config)
    # The loader is told the model's time reduction factor before its
    # saved state is restored from the output directory.
    self.dg.speech_config['reduction_factor'] = \
        self.am.model.time_reduction_factor
    running_config = self.config['running_config']
    self.dg.load_state(running_config['outdir'])

    # Pick the trainer matching the model type; anything that is not
    # CTC / LAS / MultiTask falls through to the transducer trainer.
    if model_type == 'MultiTask':
        self.runner = multi_runners.MultiTaskLASTrainer(
            self.dg.speech_featurizer,
            self.dg.token4_featurizer,
            running_config,
        )
    elif model_type == 'CTC':
        self.runner = ctc_runners.CTCTrainer(
            self.dg.speech_featurizer,
            self.dg.text_featurizer,
            running_config,
        )
    elif model_type == 'LAS':
        self.runner = las_runners.LASTrainer(
            self.dg.speech_featurizer,
            self.dg.text_featurizer,
            running_config,
        )
        self.dg.LAS = True
    else:
        self.runner = transducer_runners.TransducerTrainer(
            self.dg.speech_featurizer,
            self.dg.text_featurizer,
            running_config,
        )

    self.STT = self.am.model

    # Total steps are doubled when the loader reports augmentation as
    # available (presumably each batch is also trained in augmented
    # form -- confirm against the runner).
    factor = 2 if self.dg.augment.available() else 1
    self.opt = tf.keras.optimizers.Adamax(**config['optimizer_config'])
    total_steps = (self.dg.get_per_epoch_steps()
                   * running_config['num_epochs'] * factor)
    self.runner.set_total_train_steps(total_steps)
    self.runner.compile(self.STT, self.opt)

    # The loader's batch size follows the runner's global batch size.
    self.dg.batch = self.runner.global_batch_size
def __init__(self, config):
    """Assemble the model, loader, trainer and a scheduled Adam optimizer.

    Args:
        config: Mapping that must provide ``'learning_config'`` (with a
            ``'running_config'`` entry holding ``'outdir'`` and
            ``'num_epochs'``), ``'speech_config'`` (``'streaming'``),
            ``'model_config'`` (``'dmodel'``) and ``'optimizer_config'``.

    Note:
        ``config`` is modified in place: ``'streaming'`` is copied into
        the running config, and ``'learning_rate'`` in
        ``config['optimizer_config']`` is replaced by a
        ``CustomSchedule`` instance.
    """
    self.config = config['learning_config']
    running_config = self.config['running_config']
    # Copy the streaming flag from the speech config into the running
    # config, which is what the trainers receive.
    running_config.update(streaming=config['speech_config']['streaming'])

    self.am = AM(config)
    self.am.load_model(training=True)
    model_type = self.am.model_type

    loader_cls = (MultiTask_DataLoader if model_type == 'MultiTask'
                  else AM_DataLoader)
    self.dg = loader_cls(config)
    self.dg.speech_config['reduction_factor'] = \
        self.am.model.time_reduction_factor
    self.dg.load_state(running_config['outdir'])

    # Pick the trainer matching the model type; anything that is not
    # CTC / LAS / MultiTask falls through to the transducer trainer.
    if model_type == 'MultiTask':
        # Unlike the other trainers, this one is given no text featurizer.
        self.runner = multi_runners.MultiTaskCTCTrainer(
            self.dg.speech_featurizer,
            running_config,
        )
    elif model_type == 'CTC':
        self.runner = ctc_runners.CTCTrainer(
            self.dg.speech_featurizer,
            self.dg.text_featurizer,
            running_config,
        )
    elif model_type == 'LAS':
        self.runner = las_runners.LASTrainer(
            self.dg.speech_featurizer,
            self.dg.text_featurizer,
            running_config,
        )
        self.dg.LAS = True
    else:
        self.runner = transducer_runners.TransducerTrainer(
            self.dg.speech_featurizer,
            self.dg.text_featurizer,
            running_config,
        )

    self.STT = self.am.model

    # Total steps are doubled when the loader reports augmentation as
    # available (presumably each batch is also trained in augmented
    # form -- confirm against the runner).
    factor = 2 if self.dg.augment.available() else 1
    all_train_step = (self.dg.get_per_epoch_steps()
                      * running_config['num_epochs'] * factor)

    # Warm-up covers the first 10% of all training steps.
    warmup_steps = int(all_train_step * 0.1)
    schedule = CustomSchedule(config['model_config']['dmodel'],
                              warmup_steps=warmup_steps)
    optimizer_config = config['optimizer_config']
    optimizer_config['learning_rate'] = schedule
    self.opt = tf.keras.optimizers.Adam(**optimizer_config)

    self.runner.set_total_train_steps(all_train_step)
    self.runner.compile(self.STT, self.opt)

    # The loader's batch size follows the runner's global batch size.
    self.dg.batch = self.runner.global_batch_size
def __init__(self, config):
    """Assemble the acoustic model, evaluation loader and tester.

    Args:
        config: Mapping that must provide ``'learning_config'`` with a
            ``'running_config'`` entry. The whole mapping is also handed
            to ``AM`` and to the data loader.
    """
    self.config = config['learning_config']
    running_config = self.config['running_config']

    self.am = AM(config)
    self.am.load_model(training=False)

    # Forward the two values reported by the speech feature extractor to
    # the model's return_pb_function (their meaning is not visible here;
    # presumably feature and channel dimensions -- confirm).
    dim_f, dim_c = self.am.speech_feature.compute_feature_dim()
    self.am.model.return_pb_function(dim_f, dim_c)

    if self.am.model_type == 'MultiTask':
        self.dg = MultiTask_DataLoader(config, training=False)
        self.runner = multi_task_tester.MultiTaskTester(
            running_config,
            self.dg.token3_featurizer,
            self.dg.token4_featurizer,
        )
    else:
        self.dg = AM_DataLoader(config, training=False)
        self.runner = am_tester.AMTester(running_config,
                                         self.dg.text_featurizer)

    self.STT = self.am.model
    # The progress bar is sized from the loader's evaluation steps per epoch.
    self.runner.set_progbar(self.dg.eval_per_epoch_steps())
    self.runner.compile(self.STT)
def __init__(self, config):
    """Assemble the acoustic model, evaluation loader and tester.

    Args:
        config: Mapping that must provide ``'learning_config'`` with a
            ``'running_config'`` entry. ``config['speech_config']
            ['streaming']`` is read only when the model type is not
            ``'MultiTask'``.
    """
    self.config = config['learning_config']
    running_config = self.config['running_config']

    self.am = AM(config)
    self.am.load_model(training=False)

    if self.am.model_type == 'MultiTask':
        self.dg = MultiTask_DataLoader(config, training=False)
        self.runner = multi_task_tester.MultiTaskTester(
            running_config,
            self.dg.token3_featurizer,
        )
    else:
        self.dg = AM_DataLoader(config, training=False)
        self.runner = am_tester.AMTester(
            running_config,
            self.dg.text_featurizer,
            streaming=config['speech_config']['streaming'],
        )

    self.STT = self.am.model
    # Both the progress bar and the runner's step count are fed from the
    # loader; the loader is queried once for each call.
    self.runner.set_progbar(self.dg.eval_per_epoch_steps())
    self.runner.set_all_steps(self.dg.eval_per_epoch_steps())
    self.runner.compile(self.STT)