def save(self):
    """
    Saves the current model and related training parameters into the
    current working directory.

    Two files are written with ``torch.save``:
        - ``TRAINER_STATE_NAME``: dict holding optimizer, trainset_list, validset and epoch
        - ``MODEL_NAME``: the model object itself
    """
    save_dir = os.getcwd()
    trainer_state_path = os.path.join(save_dir, self.TRAINER_STATE_NAME)
    model_path = os.path.join(save_dir, self.MODEL_NAME)

    trainer_states = {
        'optimizer': self.optimizer,
        'trainset_list': self.trainset_list,
        'validset': self.validset,
        'epoch': self.epoch
    }
    torch.save(trainer_states, trainer_state_path)
    torch.save(self.model, model_path)
    logger.info('save checkpoints\n%s\n%s' % (trainer_state_path, model_path))
def extract_noise(self, audio_path):
    """
    Extract the noise portion of an audio file.

    The file is memory-mapped as signed short ('h') samples and copied to
    float32. Every interval returned by ``split(signal, top_db=30)`` is
    zeroed out, and the remaining non-zero samples are returned scaled by
    1 / 32767.

    Args:
        audio_path (str): path of the audio file

    Returns:
        The scaled noise samples, or None if a RuntimeError or ValueError
        was raised while reading or splitting the signal.
    """
    try:
        # astype() copies, so the writes below never touch the read-only memmap
        signal = np.memmap(audio_path, dtype='h', mode='r').astype('float32')
        non_silence_indices = split(signal, top_db=30)

        for (start, end) in non_silence_indices:
            signal[start:end] = 0

        noise = signal[signal != 0]
        return noise / 32767

    except (RuntimeError, ValueError) as error:
        # Report the real exception type (ValueError used to be logged as "RuntimeError")
        logger.info("{0} in {1}".format(type(error).__name__, audio_path))
        return None
def _validate(self, model: nn.Module, queue: queue.Queue) -> float:
    """
    Run one validation pass over the batches delivered through `queue`.

    Args:
        model (torch.nn.Module): model to validate
        queue (queue.Queue): validation queue, containing input, targets, input_lengths, target_lengths.
            A batch whose input has zero rows marks the end of the data.

    Returns: cer
        - **cer** (float): value returned by ``self.metric`` for the last processed batch
          (1.0 if no batch was processed)
    """
    target_list, predict_list = list(), list()
    cer = 1.0

    model.eval()
    model.to(self.device)  # loop-invariant, so done once instead of per batch
    # recognize() is defined on the wrapped module when the model is DataParallel
    recognizer = model.module if isinstance(model, nn.DataParallel) else model
    logger.info('validate() start')

    # Inference only: no autograd graph is needed
    with torch.no_grad():
        while True:
            inputs, targets, input_lengths, target_lengths = queue.get()

            if inputs.shape[0] == 0:
                break

            inputs = inputs.to(self.device)
            targets = targets[:, 1:].to(self.device)  # drop the first target column

            y_hats = recognizer.recognize(inputs, input_lengths)

            for idx in range(targets.size(0)):
                target_list.append(self.vocab.label_to_string(targets[idx]))
                predict_list.append(self.vocab.label_to_string(y_hats[idx].cpu().detach().numpy()))

            cer = self.metric(targets, y_hats)

    self._save_result(target_list, predict_list)
    logger.info('validate() completed')

    return cer
def validate(self, model, queue):
    """
    Run one validation pass over the batches delivered through `queue`.

    Args:
        model (torch.nn.Module): model to validate
        queue (queue.Queue): validation queue, containing input, targets, input_lengths, target_lengths.
            A batch whose input has zero rows marks the end of the data.

    Returns: cer
        - **cer** (float): value returned by ``self.metric`` for the last processed batch
          (1.0 if no batch was processed)
    """
    cer = 1.0

    model.eval()
    logger.info('validate() start')

    # flatten_parameters() sits on `.module` for a wrapped model; fall back to the
    # model itself so an unwrapped model no longer raises AttributeError.
    rnn_owner = getattr(model, 'module', model)

    with torch.no_grad():
        while True:
            inputs, scripts, input_lengths, script_lengths = queue.get()

            if inputs.shape[0] == 0:
                break

            inputs = inputs.to(self.device)
            scripts = scripts.to(self.device)
            targets = scripts[:, 1:]  # drop the first script column

            rnn_owner.flatten_parameters()
            output = model(inputs, input_lengths, teacher_forcing_ratio=0.0)[0]

            logit = torch.stack(output, dim=1).to(self.device)
            hypothesis = logit.max(-1)[1]  # indices of the max along the last dimension

            cer = self.metric(targets, hypothesis)
            del inputs, input_lengths, scripts, targets, output, logit, hypothesis

    logger.info('validate() completed')
    return cer
def search(self, model, queue, device, print_every):
    """
    Decode every batch from `queue` by taking the arg-max of the model output.

    Target and hypothesis strings are appended to ``self.target_list`` and
    ``self.hypothesis_list``. The metric value is logged whenever the step
    counter is a multiple of `print_every`.

    Returns:
        The value returned by ``self.metric`` for the last batch (0 if there was none).
    """
    cer = 0
    num_sentences = 0
    step = 0

    model.eval()

    with torch.no_grad():
        while True:
            batch = queue.get()
            inputs, scripts, input_lengths, target_lengths = batch

            # A batch with zero rows is the end-of-data marker
            if inputs.shape[0] == 0:
                break

            inputs, scripts = inputs.to(device), scripts.to(device)
            targets = scripts[:, 1:]

            decoder_outputs, _ = model(
                inputs,
                input_lengths,
                teacher_forcing_ratio=0.0,
                language_model=self.language_model,
            )
            stacked = torch.stack(decoder_outputs, dim=1).to(device)
            hypothesis = stacked.max(-1)[1]

            batch_size = targets.size(0)
            for row in range(batch_size):
                target_text = label_to_string(scripts[row], id2char, EOS_token)
                self.target_list.append(target_text)

                hypothesis_ids = hypothesis[row].cpu().detach().numpy()
                self.hypothesis_list.append(label_to_string(hypothesis_ids, id2char, EOS_token))

            cer = self.metric(targets, hypothesis)
            num_sentences += scripts.size(0)

            should_log = (step % print_every == 0)
            if should_log:
                logger.info('cer: {:.2f}'.format(cer))
            step += 1

    return cer
def search(self, model: nn.Module, queue: Queue, device: str, print_every: int) -> float:
    """
    Decode every batch from `queue` by taking the arg-max of the model output.

    Target and predicted strings are appended to ``self.target_list`` and
    ``self.y_hats``. The metric value is logged whenever the step counter
    is a multiple of `print_every`.

    Returns:
        The value returned by ``self.metric`` for the last batch (0 if there was none).
    """
    cer = 0
    num_sentences = 0
    step = 0

    model.eval()

    with torch.no_grad():
        while True:
            batch = queue.get()
            inputs, targets, input_lengths, target_lengths = batch

            # A batch with zero rows is the end-of-data marker
            if inputs.shape[0] == 0:
                break

            inputs, targets = inputs.to(device), targets.to(device)

            decoder_outputs = model(
                inputs,
                input_lengths,
                teacher_forcing_ratio=0.0,
                language_model=self.language_model,
                return_decode_dict=False,
            )
            stacked = torch.stack(decoder_outputs, dim=1).to(device)
            y_hat = stacked.max(-1)[1]

            batch_size = targets.size(0)
            for row in range(batch_size):
                target_text = label_to_string(targets[row], id2char, EOS_token)
                self.target_list.append(target_text)

                predicted_ids = y_hat[row].cpu().detach().numpy()
                self.y_hats.append(label_to_string(predicted_ids, id2char, EOS_token))

            # The metric compares against targets without their first column
            cer = self.metric(targets[:, 1:], y_hat)
            num_sentences += targets.size(0)

            should_log = (step % print_every == 0)
            if should_log:
                logger.info('cer: {:.2f}'.format(cer))
            step += 1

    return cer
def load(self, path):
    """
    Loads a Checkpoint object that was previously saved to disk.

    Args:
        path (str): path to the checkpoint subdirectory

    Returns:
        checkpoint (Checkpoint): checkpoint object with fields copied from those stored on disk
    """
    trainer_state_path = os.path.join(path, self.TRAINER_STATE_NAME)
    model_path = os.path.join(path, self.MODEL_NAME)
    logger.info('load checkpoints\n%s\n%s' % (trainer_state_path, model_path))

    # Without CUDA, keep every storage where it is deserialized instead of
    # restoring it to the device it was saved from.
    if torch.cuda.is_available():
        map_location = None
    else:
        map_location = lambda storage, loc: storage

    # NOTE(security): torch.load unpickles arbitrary objects; only load trusted checkpoints.
    resume_checkpoint = torch.load(trainer_state_path, map_location=map_location)
    model = torch.load(model_path, map_location=map_location)

    # Unwrap DataParallel before the type check: the previous code tested for
    # DataParallel inside the ListenAttendSpell branch, which could never match.
    inner_model = model.module if isinstance(model, nn.DataParallel) else model
    if isinstance(inner_model, ListenAttendSpell):
        inner_model.flatten_parameters()  # make RNN parameters contiguous

    return Checkpoint(
        model=model,
        optimizer=resume_checkpoint['optimizer'],
        epoch=resume_checkpoint['epoch'],
        trainset_list=resume_checkpoint['trainset_list'],
        validset=resume_checkpoint['validset'],
    )
def search(self, model: nn.Module, queue: Queue, device: str, print_every: int) -> float:
    """
    Decode every batch from `queue` with ``model.recognize``.

    Target and predicted strings are appended to ``self.target_list`` and
    ``self.predict_list``. The metric value is logged whenever the step
    counter is a multiple of `print_every`.

    Args:
        model (torch.nn.Module): model to decode with (a DataParallel wrapper is unwrapped)
        queue (Queue): queue containing input, targets, input_lengths, target_lengths.
            A batch whose input has zero rows marks the end of the data.
        device (str): device the model and batches are moved to
        print_every (int): logging interval in steps

    Returns:
        The value returned by ``self.metric`` for the last batch (0 if there was none).
    """
    cer = 0
    timestep = 0

    if isinstance(model, nn.DataParallel):
        model = model.module

    model.eval()
    model.to(device)

    # Inference only: no autograd graph is needed
    with torch.no_grad():
        while True:
            inputs, targets, input_lengths, target_lengths = queue.get()
            if inputs.shape[0] == 0:
                break

            inputs = inputs.to(device)
            targets = targets.to(device)

            y_hats = model.recognize(inputs, input_lengths)

            for idx in range(targets.size(0)):
                self.target_list.append(self.vocab.label_to_string(targets[idx]))
                self.predict_list.append(self.vocab.label_to_string(y_hats[idx].cpu().detach().numpy()))

            # The metric compares against targets without their first column
            cer = self.metric(targets[:, 1:], y_hats)

            if timestep % print_every == 0:
                logger.info('cer: {:.2f}'.format(cer))

            timestep += 1

    return cer
def validate(self, model: nn.Module, queue: queue.Queue) -> float:
    """
    Run one validation pass over the batches delivered through `queue`.

    Args:
        model (torch.nn.Module): model to validate
        queue (queue.Queue): validation queue, containing input, targets, input_lengths, target_lengths.
            A batch whose input has zero rows marks the end of the data.

    Returns: cer
        - **cer** (float): value returned by ``self.metric`` for the last processed batch
          (1.0 if no batch was processed)
    """
    cer = 1.0

    model.eval()
    model.to(self.device)  # loop-invariant, so done once instead of per batch
    # greedy_decode() is defined on the wrapped module when the model is DataParallel
    decoder = model.module if isinstance(model, nn.DataParallel) else model
    logger.info('validate() start')

    # Inference only: no autograd graph is needed
    with torch.no_grad():
        while True:
            inputs, targets, input_lengths, target_lengths = queue.get()

            if inputs.shape[0] == 0:
                break

            inputs = inputs.to(self.device)
            targets = targets[:, 1:].to(self.device)  # drop the first target column

            y_hats = decoder.greedy_decode(inputs, input_lengths, self.device)
            cer = self.metric(targets, y_hats)

    logger.info('validate() completed')

    return cer
def save(self):
    """
    Saves the current model and related training parameters into a subdirectory of the checkpoint directory.
    The name of the subdirectory is the current local time in Y_M_D_H_M_S format.
    An already existing subdirectory of that name is removed first.
    """
    timestamp = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime())
    checkpoint_dir = os.path.join(self.SAVE_PATH, self.CHECKPOINT_DIR_NAME, timestamp)

    # Start from an empty directory: drop any previous content, then recreate it
    if os.path.exists(checkpoint_dir):
        shutil.rmtree(checkpoint_dir)
    os.makedirs(checkpoint_dir)

    trainer_state_file = os.path.join(checkpoint_dir, self.TRAINER_STATE_NAME)
    model_file = os.path.join(checkpoint_dir, self.MODEL_NAME)

    trainer_states = dict(
        optimizer=self.optimizer,
        trainset_list=self.trainset_list,
        validset=self.validset,
        epoch=self.epoch,
    )

    torch.save(trainer_states, trainer_state_file)
    torch.save(self.model, model_file)
    logger.info('save checkpoints\n%s\n%s' % (trainer_state_file, model_file))
def parse_audio(self, audio_path: str, augment_method: int) -> Tensor:
    """
    Load an audio file and turn it into a feature tensor.

    Args:
        audio_path (str): path of audio file
        augment_method (int): flag indicating which augmentation method to use.

    Returns: feature_vector
        - **feature_vector** (torch.FloatTensor): feature from audio file,
          or None when the audio could not be loaded.
    """
    signal = load_audio(audio_path, self.del_silence, extension=self.audio_extension)

    if signal is None:
        logger.info("Audio is None : {0}".format(audio_path))
        return None

    feature = self.transforms(signal)

    if self.normalize:
        # In-place shift and scale over the whole feature array
        feature -= feature.mean()
        feature /= np.std(feature)

    if not self.input_reverse:
        feature = FloatTensor(feature).transpose(0, 1)
    else:
        # Refer to "Sequence to Sequence Learning with Neural Network" paper
        flipped = feature[:, ::-1]
        swapped = np.swapaxes(flipped, 0, 1)
        feature = FloatTensor(np.ascontiguousarray(swapped))

    wants_spec_augment = (augment_method == SpectrogramParser.SPEC_AUGMENT)
    if wants_spec_augment:
        feature = self.spec_augment(feature)

    return feature
def __init__(self, dataset_path, noiseset_size, sample_rate=16000, noise_level=0.7):
    """
    Create a noise injector from the audio found under `dataset_path`.

    Args:
        dataset_path (str): directory holding the noise audio; must exist
        noiseset_size (int): stored as ``self.noiseset_size``
        sample_rate (int): stored as ``self.sample_rate`` (default 16000)
        noise_level (float): stored as ``self.noise_level`` (default 0.7)

    Raises:
        IOError: if `dataset_path` does not exist
    """
    if not os.path.exists(dataset_path):
        message = "Directory doesn`t exist: {0}".format(dataset_path)
        logger.info(message)
        # Carry the path in the exception instead of raising a bare IOError
        raise IOError(message)

    logger.info("Create Noise injector...")

    self.noiseset_size = noiseset_size
    self.sample_rate = sample_rate
    self.noise_level = noise_level
    self.audio_paths = self.create_audio_paths(dataset_path)
    self.dataset = self.create_noiseset(dataset_path)

    logger.info("Create Noise injector complete !!")
def evaluate(self, model):
    """
    Evaluate a model on given dataset and return performance.

    Args:
        model: model handed to ``self.decoder.search``

    Returns:
        The CER value returned by ``self.decoder.search``.
    """
    logger.info('evaluate() start')

    eval_queue = queue.Queue(self.num_workers << 1)
    eval_loader = AudioLoader(self.dataset, eval_queue, self.batch_size, 0)
    eval_loader.start()

    cer = self.decoder.search(model, eval_queue, self.device, self.print_every)
    # The result file is named after the decoder class
    self.decoder.save_result('../data/train_result/%s.csv' % type(self.decoder).__name__)

    logger.info('Evaluate CER: %s' % cer)
    logger.info('evaluate() completed')
    eval_loader.join()

    # Previously the docstring promised a result but nothing was returned
    return cer
def split_dataset(config: DictConfig, transcripts_path: str, vocab: Vocabulary):
    """
    split into training set and validation set.

    The first `train_num` entries loaded from `transcripts_path` form the training
    set (shuffled, then divided between the workers); all remaining entries form
    the validation set.

    Args:
        config (DictConfig): configuration; reads train.dataset, train.batch_size,
            train.num_workers, train.dataset_path, audio.spec_augment, audio.audio_extension
        transcripts_path (str): path of transcripts
        vocab (Vocabulary): vocabulary providing sos_id and eos_id

    Returns: train_time_step, trainset_list, validset
        - **train_time_step** (int): number of time step for training
        - **trainset_list** (list): list of training dataset
        - **validset** (SpectrogramDataset): validation dataset

    Raises:
        NotImplementedError: if config.train.dataset is neither 'kspon' nor 'libri'
    """
    logger.info("split dataset start !!")
    trainset_list = list()

    if config.train.dataset == 'kspon':
        train_num = 620000
        valid_num = 2545
    elif config.train.dataset == 'libri':
        train_num = 281241
        valid_num = 5567
    else:
        raise NotImplementedError("Unsupported Dataset : {0}".format(config.train.dataset))

    audio_paths, transcripts = load_dataset(transcripts_path)

    total_time_step = math.ceil(len(audio_paths) / config.train.batch_size)
    valid_time_step = math.ceil(valid_num / config.train.batch_size)
    train_time_step = total_time_step - valid_time_step

    # Split exactly at train_num. Slicing at train_num + 1 put one extra sample in
    # the training slice that no worker ever received (worker ranges are capped at
    # train_num) and left the validation set one sample short.
    train_audio_paths = audio_paths[:train_num]
    train_transcripts = transcripts[:train_num]

    valid_audio_paths = audio_paths[train_num:]
    valid_transcripts = transcripts[train_num:]

    if config.audio.spec_augment:
        train_time_step <<= 1

    train_num_per_worker = math.ceil(train_num / config.train.num_workers)

    # audio paths & transcripts are shuffled in the same order
    tmp = list(zip(train_audio_paths, train_transcripts))
    random.shuffle(tmp)
    train_audio_paths, train_transcripts = zip(*tmp)

    # separating the train dataset by the number of workers
    for idx in range(config.train.num_workers):
        train_begin_idx = train_num_per_worker * idx
        train_end_idx = min(train_num_per_worker * (idx + 1), train_num)

        trainset_list.append(
            SpectrogramDataset(
                train_audio_paths[train_begin_idx:train_end_idx],
                train_transcripts[train_begin_idx:train_end_idx],
                vocab.sos_id, vocab.eos_id,
                config=config,
                spec_augment=config.audio.spec_augment,
                dataset_path=config.train.dataset_path,
                audio_extension=config.audio.audio_extension,
            ))

    validset = SpectrogramDataset(
        audio_paths=valid_audio_paths,
        transcripts=valid_transcripts,
        sos_id=vocab.sos_id,
        eos_id=vocab.eos_id,
        config=config,
        spec_augment=False,
        dataset_path=config.train.dataset_path,
        audio_extension=config.audio.audio_extension,
    )

    logger.info("split dataset complete !!")
    return train_time_step, trainset_list, validset
def _train_epoches(
        self,
        model: nn.Module,
        epoch: int,
        epoch_time_step: int,
        train_begin_time: float,
        queue: queue.Queue,
        teacher_forcing_ratio: float,
) -> Tuple[nn.Module, float, float]:
    """
    Run training one epoch

    Args:
        model (torch.nn.Module): model to train
        epoch (int): number of current epoch
        epoch_time_step (int): total time step in one epoch
        train_begin_time (float): time of train begin
        queue (queue.Queue): training queue, containing input, targets, input_lengths, target_lengths
        teacher_forcing_ratio (float): teaching forcing ratio (default 0.99)

    Returns: model, loss, cer
        - **model** (torch.nn.Module): the trained model
        - **loss** (float): accumulated loss of the epoch divided by the summed input lengths
        - **cer** (float): character error rate of the last processed batch
          (stays 1.0 for the 'rnnt' / 'conformer_t' architectures, which never update it)
    """
    # 'conformer' is refined into its transducer / CTC flavour depending on whether a decoder exists
    architecture = self.architecture
    if self.architecture == 'conformer':
        if isinstance(model, nn.DataParallel):
            architecture = 'conformer_t' if model.module.decoder is not None else 'conformer_ctc'
        else:
            architecture = 'conformer_t' if model.decoder is not None else 'conformer_ctc'

    cer = 1.0
    epoch_loss_total = 0.
    total_num = 0
    timestep = 0

    model.train()

    begin_time = epoch_begin_time = time.time()
    # Local countdown of loaders that are still producing; self.num_workers is left untouched
    num_workers = self.num_workers

    while True:
        inputs, targets, input_lengths, target_lengths = queue.get()

        if inputs.shape[0] == 0:
            # Empty feats means closing one loader
            num_workers -= 1
            logger.debug('left train_loader: %d' % num_workers)

            if num_workers == 0:
                break
            else:
                continue

        self.optimizer.zero_grad()

        inputs = inputs.to(self.device)
        targets = targets.to(self.device)
        input_lengths = input_lengths.to(self.device)
        target_lengths = torch.as_tensor(target_lengths).to(self.device)

        model = model.to(self.device)
        output, loss, ctc_loss, cross_entropy_loss = self._model_forward(
            teacher_forcing_ratio=teacher_forcing_ratio,
            inputs=inputs,
            input_lengths=input_lengths,
            targets=targets,
            target_lengths=target_lengths,
            model=model,
            architecture=architecture,
        )

        # CER is only computed for the non-transducer architectures
        if architecture not in ('rnnt', 'conformer_t'):
            y_hats = output.max(-1)[1]
            cer = self.metric(targets[:, 1:], y_hats)

        loss.backward()
        self.optimizer.step(model)

        total_num += int(input_lengths.sum())
        epoch_loss_total += loss.item()

        timestep += 1
        torch.cuda.empty_cache()

        if timestep % self.print_every == 0:
            current_time = time.time()
            elapsed = current_time - begin_time  # seconds since the previous log line
            epoch_elapsed = (current_time - epoch_begin_time) / 60.0  # minutes
            train_elapsed = (current_time - train_begin_time) / 3600.0  # hours

            if architecture in ('rnnt', 'conformer_t'):
                logger.info(
                    self.rnnt_log_format.format(
                        timestep, epoch_time_step, loss,
                        elapsed, epoch_elapsed, train_elapsed,
                        self.optimizer.get_lr(),
                    ))
            else:
                if self.joint_ctc_attention:
                    logger.info(
                        self.log_format.format(
                            timestep, epoch_time_step, loss,
                            ctc_loss, cross_entropy_loss, cer,
                            elapsed, epoch_elapsed, train_elapsed,
                            self.optimizer.get_lr(),
                        ))
                else:
                    logger.info(
                        self.log_format.format(
                            timestep, epoch_time_step, loss,
                            cer, elapsed, epoch_elapsed, train_elapsed,
                            self.optimizer.get_lr(),
                        ))
            begin_time = time.time()

        if timestep % self.save_result_every == 0:
            self._save_step_result(self.train_step_result, epoch_loss_total / total_num, cer)

        if timestep % self.checkpoint_every == 0:
            Checkpoint(model, self.optimizer, self.trainset_list, self.validset, epoch).save()

        del inputs, input_lengths, targets, output, loss

    Checkpoint(model, self.optimizer, self.trainset_list, self.validset, epoch).save()
    logger.info('train() completed')

    # NOTE(review): total_num is still 0 if no batch was processed, in which case this division raises ZeroDivisionError
    return model, epoch_loss_total / total_num, cer
def train(
        self,
        model: nn.Module,                           # model to train
        batch_size: int,                            # batch size for experiment
        epoch_time_step: int,                       # number of time step for training
        num_epochs: int,                            # number of epochs (iteration) for training
        teacher_forcing_ratio: float = 0.99,        # teacher forcing ratio
        resume: bool = False,                       # resume training with the latest checkpoint
) -> nn.Module:
    """
    Run training for a given model.

    Every epoch shuffles the training sets, trains through a MultiDataLoader,
    saves a checkpoint, lowers the teacher forcing ratio, validates through an
    AudioDataLoader and stores the epoch result.

    Args:
        model (torch.nn.Module): model to train
        batch_size (int): batch size for experiment
        epoch_time_step (int): number of time step for training
        num_epochs (int): number of epochs for training
        teacher_forcing_ratio (float): teacher forcing ratio (default 0.99)
        resume(bool, optional): resume training with the latest checkpoint, (default False)

    Returns:
        model (torch.nn.Module): the trained model
    """
    start_epoch = 0

    if resume:
        # Restore model, optimizer and datasets from the latest checkpoint and
        # continue with the epoch after the stored one
        checkpoint = Checkpoint()
        latest_checkpoint_path = checkpoint.get_latest_checkpoint()
        resume_checkpoint = checkpoint.load(latest_checkpoint_path)
        model = resume_checkpoint.model
        self.optimizer = resume_checkpoint.optimizer
        self.trainset_list = resume_checkpoint.trainset_list
        self.validset = resume_checkpoint.validset
        start_epoch = resume_checkpoint.epoch + 1
        # Recompute the steps per epoch from the restored training sets
        epoch_time_step = 0

        for trainset in self.trainset_list:
            epoch_time_step += len(trainset)

        epoch_time_step = math.ceil(epoch_time_step / batch_size)

    logger.info('start')
    train_begin_time = time.time()

    for epoch in range(start_epoch, num_epochs):
        logger.info('Epoch %d start' % epoch)
        # Queue capacity is twice the number of workers
        train_queue = queue.Queue(self.num_workers << 1)

        for trainset in self.trainset_list:
            trainset.shuffle()

        # Training
        train_loader = MultiDataLoader(self.trainset_list, train_queue, batch_size, self.num_workers, self.vocab.pad_id)
        train_loader.start()

        model, train_loss, train_cer = self._train_epoches(
            model=model,
            epoch=epoch,
            epoch_time_step=epoch_time_step,
            train_begin_time=train_begin_time,
            queue=train_queue,
            teacher_forcing_ratio=teacher_forcing_ratio,
        )
        train_loader.join()

        Checkpoint(model, self.optimizer, self.trainset_list, self.validset, epoch).save()
        logger.info('Epoch %d (Training) Loss %0.4f CER %0.4f' % (epoch, train_loss, train_cer))

        # Decay the teacher forcing ratio, but never below the configured minimum
        teacher_forcing_ratio -= self.teacher_forcing_step
        teacher_forcing_ratio = max(self.min_teacher_forcing_ratio, teacher_forcing_ratio)

        # Validation
        valid_queue = queue.Queue(self.num_workers << 1)
        valid_loader = AudioDataLoader(self.validset, valid_queue, batch_size, 0, self.vocab.pad_id)
        valid_loader.start()

        valid_cer = self._validate(model, valid_queue)
        valid_loader.join()

        logger.info('Epoch %d CER %0.4f' % (epoch, valid_cer))
        # NOTE(review): valid_result carries train_loss; _validate() only returns a CER,
        # so no validation loss exists here - confirm this placeholder is intended
        self._save_epoch_result(
            train_result=[self.train_dict, train_loss, train_cer],
            valid_result=[self.valid_dict, train_loss, valid_cer])
        logger.info(
            'Epoch %d Training result saved as a csv file complete !!' % epoch)
        torch.cuda.empty_cache()

    Checkpoint(model, self.optimizer, self.trainset_list, self.validset, num_epochs).save()
    return model
def split_dataset(opt, audio_paths, script_paths):
    """
    Shuffle the data and divide it into per-worker training sets and one validation set.

    Args:
        opt (ArgumentParser): set of options
        audio_paths (list): set of audio path
        script_paths (list): set of script path

    Returns: train_time_step, trainset_list, validset
        - **train_time_step** (int): number of time step for training
        - **trainset_list** (list): list of training dataset
        - **validset** (SpectrogramDataset): validation dataset
    """
    target_dict = load_targets(script_paths)
    logger.info("split dataset start !!")

    num_samples = len(audio_paths)
    train_num = math.ceil(num_samples * (1 - opt.valid_ratio))
    total_time_step = math.ceil(num_samples / opt.batch_size)
    base_time_step = total_time_step - math.ceil(total_time_step * opt.valid_ratio)

    # Each enabled augmentation adds one more pass over the training data
    train_time_step = base_time_step
    for augment_enabled in (opt.spec_augment, opt.noise_augment):
        if augment_enabled:
            train_time_step += base_time_step

    train_num_per_worker = math.ceil(train_num / opt.num_workers)

    # Shuffle audio and script paths together so that pairs stay aligned
    paired = list(zip(audio_paths, script_paths))
    random.shuffle(paired)
    audio_paths, script_paths = zip(*paired)

    # One training dataset per worker, each covering a consecutive slice
    trainset_list = []
    for worker_idx in range(opt.num_workers):
        begin = train_num_per_worker * worker_idx
        end = min(begin + train_num_per_worker, train_num)
        worker_set = SpectrogramDataset(
            audio_paths[begin:end],
            script_paths[begin:end],
            SOS_token,
            EOS_token,
            target_dict=target_dict,
            opt=opt,
            spec_augment=opt.spec_augment,
            noise_augment=opt.noise_augment,
            dataset_path=opt.dataset_path,
            noiseset_size=opt.noiseset_size,
            noise_level=opt.noise_level,
        )
        trainset_list.append(worker_set)

    # Everything past train_num is held out for validation, without augmentation
    validset = SpectrogramDataset(
        audio_paths=audio_paths[train_num:],
        script_paths=script_paths[train_num:],
        sos_id=SOS_token,
        eos_id=EOS_token,
        target_dict=target_dict,
        opt=opt,
        spec_augment=False,
        noise_augment=False,
    )

    logger.info("split dataset complete !!")
    return train_time_step, trainset_list, validset
def train(
        self,
        model: nn.Module,
        batch_size: int,
        epoch_time_step: int,
        num_epochs: int,
        teacher_forcing_ratio: float = 0.99,
        resume: bool = False
) -> nn.Module:
    """
    Run training for a given model.

    Every epoch shuffles the training sets, adjusts the learning rate for
    epochs 1-3, trains through a MultiDataLoader, saves a checkpoint, lowers
    the teacher forcing ratio, validates through an AudioDataLoader and stores
    the epoch result.

    Args:
        model (torch.nn.Module): model to train
        batch_size (int): batch size for experiment
        epoch_time_step (int): number of time step for training
        num_epochs (int): number of epochs for training
        teacher_forcing_ratio (float): teaching forcing ratio (default 0.99)
        resume(bool, optional): resume training with the latest checkpoint, (default False)

    Returns:
        model (torch.nn.Module): the trained model
    """
    start_epoch = 0

    if resume:
        # Restore model, optimizer and datasets from the latest checkpoint and
        # continue with the epoch after the stored one
        checkpoint = Checkpoint()
        latest_checkpoint_path = checkpoint.get_latest_checkpoint()
        resume_checkpoint = checkpoint.load(latest_checkpoint_path)
        model = resume_checkpoint.model
        self.optimizer = resume_checkpoint.optimizer
        self.trainset_list = resume_checkpoint.trainset_list
        self.validset = resume_checkpoint.validset
        start_epoch = resume_checkpoint.epoch + 1
        # Recompute the steps per epoch from the restored training sets
        epoch_time_step = 0

        for trainset in self.trainset_list:
            epoch_time_step += len(trainset)

        epoch_time_step = math.ceil(epoch_time_step / batch_size)

    logger.info('start')
    train_begin_time = time.time()

    for epoch in range(start_epoch, num_epochs):
        logger.info('Epoch %d start' % epoch)
        # Queue capacity is twice the number of workers
        train_queue = queue.Queue(self.num_workers << 1)

        for trainset in self.trainset_list:
            trainset.shuffle()

        # Training
        train_loader = MultiDataLoader(self.trainset_list, train_queue, batch_size, self.num_workers)
        train_loader.start()

        # Hard-coded schedule: fixed learning rates for epochs 1 and 2, then
        # ReduceLROnPlateau is installed at epoch 3
        if epoch == 1:
            self.optimizer.set_lr(1e-04)
        elif epoch == 2:
            self.optimizer.set_lr(5e-05)
        elif epoch == 3:
            self.optimizer.set_scheduler(ReduceLROnPlateau(self.optimizer.optimizer, patience=1, factor=0.5), 999999)

        train_loss, train_cer = self.__train_epoches(model, epoch, epoch_time_step, train_begin_time,
                                                     train_queue, teacher_forcing_ratio)
        train_loader.join()

        # NOTE(review): this call passes five arguments while the final save below also
        # passes self.criterion - confirm which form matches the Checkpoint constructor
        Checkpoint(model, self.optimizer, self.trainset_list, self.validset, epoch).save()
        logger.info('Epoch %d (Training) Loss %0.4f CER %0.4f' % (epoch, train_loss, train_cer))

        # Decay the teacher forcing ratio, but never below the configured minimum
        teacher_forcing_ratio -= self.teacher_forcing_step
        teacher_forcing_ratio = max(self.min_teacher_forcing_ratio, teacher_forcing_ratio)

        # Validation
        valid_queue = queue.Queue(self.num_workers << 1)
        valid_loader = AudioDataLoader(self.validset, valid_queue, batch_size, 0)
        valid_loader.start()

        valid_loss, valid_cer = self.validate(model, valid_queue)
        valid_loader.join()

        # Once the plateau scheduler is installed it is driven by the validation loss
        if isinstance(self.optimizer.scheduler, ReduceLROnPlateau):
            self.optimizer.scheduler.step(valid_loss)

        logger.info('Epoch %d (Validate) Loss %0.4f CER %0.4f' % (epoch, valid_loss, valid_cer))
        self.__save_epoch_result(train_result=[self.train_dict, train_loss, train_cer],
                                 valid_result=[self.valid_dict, valid_loss, valid_cer])
        logger.info('Epoch %d Training result saved as a csv file complete !!' % epoch)

    Checkpoint(model, self.optimizer, self.criterion, self.trainset_list, self.validset, num_epochs).save()
    return model
def print_train_opts(opt):
    """ Log every training option of `opt` as one ``--name: value`` line. """
    option_names = (
        'dataset_path',
        'data_list_path',
        'label_path',
        'spec_augment',
        'noise_augment',
        'noiseset_size',
        'noise_level',
        'use_cuda',
        'batch_size',
        'num_workers',
        'num_epochs',
        'init_lr',
        'high_plateau_lr',
        'low_plateau_lr',
        'decay_threshold',
        'rampup_period',
        'exp_decay_period',
        'valid_ratio',
        'max_len',
        'max_grad_norm',
        'teacher_forcing_step',
        'min_teacher_forcing_ratio',
        'seed',
        'save_result_every',
        'checkpoint_every',
        'print_every',
        'resume',
    )
    # Options are read and logged one at a time, in the order listed above
    for option_name in option_names:
        logger.info('--%s: %s' % (option_name, str(getattr(opt, option_name))))
def train_epoches(self, model, epoch, epoch_time_step, train_begin_time, queue, teacher_forcing_ratio):
    """
    Run training one epoch

    Args:
        model (torch.nn.Module): model to train
        epoch (int): number of current epoch
        epoch_time_step (int): total time step in one epoch
        train_begin_time (int): time of train begin
        queue (queue.Queue): training queue, containing input, targets, input_lengths, target_lengths
        teacher_forcing_ratio (float): teaching forcing ratio (default 0.99)

    Returns: loss, cer
        - **loss** (float): accumulated loss of the epoch divided by the summed input lengths
        - **cer** (float): character error rate of the last processed batch
    """
    cer = 1.0
    epoch_loss_total = 0.
    total_num = 0
    timestep = 0

    model.train()
    begin_time = epoch_begin_time = time.time()

    # Count finished loaders in a local variable. Decrementing self.num_workers
    # left it at 0 after the first epoch, so the end-of-epoch condition could
    # never be met again in the following epochs.
    num_workers = self.num_workers

    while True:
        inputs, scripts, input_lengths, target_lengths = queue.get()

        if inputs.shape[0] == 0:
            # Empty feats means closing one loader
            num_workers -= 1
            logger.debug('left train_loader: %d' % num_workers)

            if num_workers == 0:
                break
            else:
                continue

        inputs = inputs.to(self.device)
        scripts = scripts.to(self.device)
        targets = scripts[:, 1:]  # drop the first script column

        model.module.flatten_parameters()
        output = model(inputs, input_lengths, scripts, teacher_forcing_ratio=teacher_forcing_ratio)[0]

        logit = torch.stack(output, dim=1).to(self.device)
        hypothesis = logit.max(-1)[1]  # indices of the max along the last dimension

        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), targets.contiguous().view(-1))
        epoch_loss_total += loss.item()

        cer = self.metric(targets, hypothesis)
        total_num += int(input_lengths.sum())

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step(model, loss.item())

        timestep += 1
        torch.cuda.empty_cache()

        if timestep % self.print_every == 0:
            current_time = time.time()
            elapsed = current_time - begin_time  # seconds since the previous log line
            epoch_elapsed = (current_time - epoch_begin_time) / 60.0  # minutes
            train_elapsed = (current_time - train_begin_time) / 3600.0  # hours

            logger.info(
                'timestep: {:4d}/{:4d}, loss: {:.4f}, cer: {:.2f}, elapsed: {:.2f}s {:.2f}m {:.2f}h'
                .format(timestep, epoch_time_step, epoch_loss_total / total_num, cer,
                        elapsed, epoch_elapsed, train_elapsed))
            begin_time = time.time()

        if timestep % self.save_result_every == 0:
            self._save_step_result(self.train_step_result, epoch_loss_total / total_num, cer)

        if timestep % self.checkpoint_every == 0:
            Checkpoint(model, self.optimizer, self.criterion, self.trainset_list, self.validset, epoch).save()

        del inputs, input_lengths, scripts, targets, output, logit, loss, hypothesis

    Checkpoint(model, self.optimizer, self.criterion, self.trainset_list, self.validset, epoch).save()
    logger.info('train() completed')

    return epoch_loss_total / total_num, cer
def search(self, model: nn.Module, queue: Queue, device: str, print_every: int) -> float:
    """
    Decode every batch from `queue` with ``model.greedy_search``.

    Target and predicted strings are appended to ``self.target_list`` and
    ``self.predict_list``. The metric value is logged whenever the step
    counter is a multiple of `print_every`.

    Args:
        model (torch.nn.Module): ListenAttendSpell, SpeechTransformer or DeepSpeech2 model,
            optionally wrapped in DataParallel
        queue (Queue): queue containing input, targets, input_lengths, target_lengths.
            A batch whose input has zero rows marks the end of the data.
        device (str): device the model and batches are moved to
        print_every (int): logging interval in steps

    Returns:
        The value returned by ``self.metric`` for the last batch (0 if there was none).

    Raises:
        ValueError: if the (unwrapped) model is not one of the supported types
    """
    cer = 0
    timestep = 0

    if isinstance(model, nn.DataParallel):
        model = model.module

    # All supported architectures are decoded through the same greedy_search()
    # call, so one type check replaces the former per-architecture dispatch.
    if not isinstance(model, (ListenAttendSpell, SpeechTransformer, DeepSpeech2)):
        raise ValueError("Unsupported model : {0}".format(type(model)))

    model.eval()
    model.to(device)

    # Inference only: no autograd graph is needed
    with torch.no_grad():
        while True:
            inputs, targets, input_lengths, target_lengths = queue.get()
            if inputs.shape[0] == 0:
                break

            inputs = inputs.to(device)
            targets = targets.to(device)

            y_hats = model.greedy_search(inputs, input_lengths, device)

            for idx in range(targets.size(0)):
                self.target_list.append(self.vocab.label_to_string(targets[idx]))
                self.predict_list.append(self.vocab.label_to_string(y_hats[idx].cpu().detach().numpy()))

            # The metric compares against targets without their first column
            cer = self.metric(targets[:, 1:], y_hats)

            if timestep % print_every == 0:
                logger.info('cer: {:.2f}'.format(cer))

            timestep += 1

    return cer
def main(config: DictConfig) -> None:
    """ Silence warnings, log the configuration as YAML and run inference. """
    warnings.filterwarnings('ignore')
    config_yaml = OmegaConf.to_yaml(config)
    logger.info(config_yaml)
    inference(config)
def main(config: DictConfig):
    """ Silence warnings, log the configuration as YAML and run training. """
    warnings.filterwarnings('ignore')
    config_yaml = OmegaConf.to_yaml(config)
    logger.info(config_yaml)
    train(config)
def print_train_opts(opt):
    """ Log every training option of `opt` as one ``--name: value`` line. """
    option_names = (
        'spec_augment',
        'use_cuda',
        'batch_size',
        'num_workers',
        'num_epochs',
        'init_lr',
        'warmup_steps',
        'max_len',
        'max_grad_norm',
        'teacher_forcing_step',
        'min_teacher_forcing_ratio',
        'seed',
        'save_result_every',
        'checkpoint_every',
        'print_every',
        'resume',
    )
    # Options are read and logged one at a time, in the order listed above
    for option_name in option_names:
        logger.info('--%s: %s' % (option_name, str(getattr(opt, option_name))))
def print_preprocess_opts(opt):
    """ Log every preprocessing option of `opt` as one ``--name: value`` line. """
    option_names = (
        'mode',
        'transform_method',
        'sample_rate',
        'frame_length',
        'frame_shift',
        'n_mels',
        'normalize',
        'del_silence',
        'input_reverse',
        'feature_extract_by',
        'time_mask_para',
        'freq_mask_para',
        'time_mask_num',
        'freq_mask_num',
    )
    # Options are read and logged one at a time, in the order listed above
    for option_name in option_names:
        logger.info('--%s: %s' % (option_name, str(getattr(opt, option_name))))
def main(config: DictConfig) -> None:
    """
    Entry point for training that also persists the final result.

    Silences Python warnings, logs the configuration rendered as YAML, runs
    ``train`` and saves whatever it returns with ``torch.save`` to
    ``last_model_checkpoint.pt`` in the current working directory.

    Args:
        config (DictConfig): configuration forwarded unchanged to ``train``
    """
    warnings.filterwarnings('ignore')

    config_yaml = OmegaConf.to_yaml(config)
    logger.info(config_yaml)

    last_model_checkpoint = train(config)

    # The working directory is read only after training has finished.
    output_path = os.path.join(os.getcwd(), "last_model_checkpoint.pt")
    torch.save(last_model_checkpoint, output_path)
def print_model_opts(opt):
    """
    Log the model options, one ``--name: value`` line per option.

    Args:
        opt: object exposing the model options as attributes
    """

    def _log(template, value):
        # Same rendering for every option: str() the value, then %-format.
        logger.info(template % str(value))

    _log('--architecture: %s', opt.architecture)
    _log('--use_bidirectional: %s', opt.use_bidirectional)
    _log('--mask_conv: %s', opt.mask_conv)
    _log('--hidden_dim: %s', opt.hidden_dim)
    _log('--dropout: %s', opt.dropout)
    _log('--attn_mechanism: %s', opt.attn_mechanism)
    _log('--num_heads: %s', opt.num_heads)
    _log('--label_smoothing: %s', opt.label_smoothing)
    _log('--num_encoder_layers: %s', opt.num_encoder_layers)
    _log('--num_decoder_layers: %s', opt.num_decoder_layers)
    _log('--extractor: %s', opt.extractor)
    _log('--activation: %s', opt.activation)
    _log('--rnn_type: %s', opt.rnn_type)
    _log('--teacher_forcing_ratio: %s', opt.teacher_forcing_ratio)
def __train_epoches(self, model: nn.Module, epoch: int, epoch_time_step: int, train_begin_time: float, queue: queue.Queue, teacher_forcing_ratio: float) -> Tuple[float, float]:
    """
    Run training for one epoch, consuming batches from ``queue``.

    Batches are pulled until every loader has signalled completion by
    sending a batch whose ``inputs`` has a first dimension of 0
    (``self.num_workers`` such batches are expected).

    Args:
        model (torch.nn.Module): model to train
        epoch (int): number of current epoch (only stored in checkpoints)
        epoch_time_step (int): total time step in one epoch (only used in the progress log)
        train_begin_time (float): time of train begin, as returned by ``time.time()``
        queue (queue.Queue): training queue, yielding
            (inputs, targets, input_lengths, target_lengths) tuples
        teacher_forcing_ratio (float): teacher forcing ratio, forwarded to the
            model when ``self.architecture == 'las'``

    Returns: loss, cer
        - **loss** (float): summed batch losses divided by the summed input lengths
        - **cer** (float): value of ``self.metric`` after the last processed batch
          (1.0 if no batch was processed)

    Raises:
        ValueError: if ``self.architecture`` is neither 'las' nor 'transformer'
    """
    cer = 1.0
    epoch_loss_total = 0.
    total_num = 0
    timestep = 0

    model.train()

    # begin_time is reset after every progress log; epoch_begin_time is not.
    begin_time = epoch_begin_time = time.time()
    num_workers = self.num_workers

    while True:
        # target_lengths is unpacked but not used below.
        inputs, targets, input_lengths, target_lengths = queue.get()

        if inputs.shape[0] == 0:
            # Empty feats means closing one loader
            num_workers -= 1
            logger.debug('left train_loader: %d' % num_workers)

            if num_workers == 0:
                break
            else:
                continue

        inputs = inputs.to(self.device)
        targets = targets.to(self.device)
        model = model.to(self.device)

        if self.architecture == 'las':
            if isinstance(model, nn.DataParallel):
                model.module.flatten_parameters()
            else:
                model.flatten_parameters()

            logit = model(inputs=inputs, input_lengths=input_lengths, targets=targets, teacher_forcing_ratio=teacher_forcing_ratio)
            # The model output is stacked along dim 1 into a single tensor.
            logit = torch.stack(logit, dim=1).to(self.device)
            # Drop the first target column — presumably the start-of-sequence
            # token; confirm against the vocabulary.
            targets = targets[:, 1:]

        elif self.architecture == 'transformer':
            logit = model(inputs, input_lengths, targets, return_attns=False)

        else:
            raise ValueError("Unsupported architecture : {0}".format(
                self.architecture))

        # Indices of the maximum along the last dimension.
        hypothesis = logit.max(-1)[1]
        # Flatten logits to 2-D and targets to 1-D before calling the criterion.
        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), targets.contiguous().view(-1))
        epoch_loss_total += loss.item()

        cer = self.metric(targets, hypothesis)
        total_num += int(input_lengths.sum())

        self.optimizer.zero_grad()
        loss.backward()
        # self.optimizer is a project wrapper whose step() takes the model.
        self.optimizer.step(model)

        timestep += 1
        torch.cuda.empty_cache()

        if timestep % self.print_every == 0:
            current_time = time.time()
            elapsed = current_time - begin_time  # seconds since the last log
            epoch_elapsed = (current_time - epoch_begin_time) / 60.0  # minutes
            train_elapsed = (current_time - train_begin_time) / 3600.0  # hours

            logger.info(
                'timestep: {:4d}/{:4d}, loss: {:.4f}, cer: {:.2f}, elapsed: {:.2f}s {:.2f}m {:.2f}h, lr: {:.5f}'
                .format(timestep, epoch_time_step, epoch_loss_total / total_num, cer, elapsed, epoch_elapsed, train_elapsed, self.optimizer.get_lr()))
            begin_time = time.time()

        if timestep % self.save_result_every == 0:
            self.__save_step_result(self.train_step_result, epoch_loss_total / total_num, cer)

        if timestep % self.checkpoint_every == 0:
            Checkpoint(model, self.optimizer, self.trainset_list, self.validset, epoch).save()

        # Drop references to this batch's tensors before fetching the next one.
        del inputs, input_lengths, targets, logit, loss, hypothesis

    # Always checkpoint once more at the end of the epoch.
    Checkpoint(model, self.optimizer, self.trainset_list, self.validset, epoch).save()
    logger.info('train() completed')

    # NOTE(review): if no non-empty batch was processed, total_num is still 0
    # and this division raises ZeroDivisionError.
    return epoch_loss_total / total_num, cer
def print_eval_opts(opt):
    """
    Log the evaluation options, one ``--name: value`` line per option.

    Args:
        opt: object exposing the evaluation options as attributes
    """

    def _log(template, value):
        # Same rendering for every option: str() the value, then %-format.
        logger.info(template % str(value))

    _log('--dataset_path: %s', opt.dataset_path)
    _log('--data_list_path: %s', opt.data_list_path)
    _log('--label_path: %s', opt.label_path)
    _log('--num_workers: %s', opt.num_workers)
    _log('--use_cuda: %s', opt.use_cuda)
    _log('--model_path: %s', opt.model_path)
    _log('--batch_size: %s', opt.batch_size)
    _log('--decode: %s', opt.decode)
    _log('--k: %s', opt.k)
    _log('--print_every: %s', opt.print_every)
def train(self, model, batch_size, epoch_time_step, num_epochs, teacher_forcing_ratio=0.99, resume=False):
    """
    Run training for a given model.

    For every epoch the training sets are shuffled, one training pass and
    one validation pass are run, a checkpoint is saved and the per-epoch
    results are written out via ``self._save_epoch_result``.

    Args:
        model (torch.nn.Module): model to train
        batch_size (int): batch size for experiment
        epoch_time_step (int): number of time step for training; recomputed
            below from ``self.trainset_list`` and ``batch_size``
        num_epochs (int): number of epochs for training
        teacher_forcing_ratio (float): teaching forcing ratio (default 0.99);
            reduced by ``self.teacher_forcing_step`` after every epoch, but
            never below ``self.min_teacher_forcing_ratio``
        resume (bool, optional): resume training with the latest checkpoint
            (default False)

    Returns:
        model (torch.nn.Module): the trained model (the checkpointed model
        when ``resume`` is True)
    """
    start_epoch = 0
    prev_train_cer = 1.

    if resume:
        # Restore model, optimizer, criterion, datasets and epoch counter
        # from the most recent checkpoint.
        checkpoint = Checkpoint()
        latest_checkpoint_path = checkpoint.get_latest_checkpoint()
        resume_checkpoint = checkpoint.load(latest_checkpoint_path)
        model = resume_checkpoint.model
        self.optimizer = resume_checkpoint.optimizer
        self.criterion = resume_checkpoint.criterion
        self.trainset_list = resume_checkpoint.trainset_list
        self.validset = resume_checkpoint.validset
        start_epoch = resume_checkpoint.epoch

    # NOTE(review): the original repeated this computation twice in a row and
    # the reviewed source had lost its indentation. This assumes the
    # recomputation and the learning-rate override below run on every call,
    # not only when resuming — confirm against the intended behaviour.
    total_items = sum(len(trainset) for trainset in self.trainset_list)
    epoch_time_step = math.ceil(total_items / batch_size)

    # Hard-coded learning-rate override applied to every parameter group.
    for g in self.optimizer.optimizer.param_groups:
        g['lr'] = 1e-04

    # The original passed the format string and the value to print() as two
    # separate arguments, so "%f" was emitted literally.
    logger.info('Learning rate : %f', self.optimizer.get_lr())

    logger.info('start')
    train_begin_time = time.time()

    for epoch in range(start_epoch, num_epochs):
        # Queue capacity is twice the number of workers.
        train_queue = queue.Queue(self.num_workers << 1)
        for trainset in self.trainset_list:
            trainset.shuffle()

        # Training
        train_loader = MultiAudioLoader(self.trainset_list, train_queue, batch_size, self.num_workers)
        train_loader.start()
        train_loss, train_cer = self.train_epoches(model, epoch, epoch_time_step, train_begin_time,
                                                   train_queue, teacher_forcing_ratio)
        train_loader.join()

        Checkpoint(model, self.optimizer, self.criterion, self.trainset_list, self.validset, epoch).save()
        logger.info('Epoch %d (Training) Loss %0.4f CER %0.4f' % (epoch, train_loss, train_cer))

        # Training CER improved by less than the threshold: install an
        # exponential-decay schedule starting from the current learning rate.
        if prev_train_cer - train_cer < self.decay_threshold:
            self.optimizer.set_scheduler(
                ExponentialDecayLR(self.optimizer.optimizer, self.optimizer.get_lr(),
                                   self.low_plateau_lr, self.exp_decay_period),
                self.exp_decay_period)

        prev_train_cer = train_cer
        teacher_forcing_ratio -= self.teacher_forcing_step
        teacher_forcing_ratio = max(self.min_teacher_forcing_ratio, teacher_forcing_ratio)

        # Validation
        valid_queue = queue.Queue(self.num_workers << 1)
        valid_loader = AudioLoader(self.validset, valid_queue, batch_size, 0)
        valid_loader.start()
        valid_cer = self.validate(model, valid_queue)
        valid_loader.join()

        # No validation loss is computed; 0.0 is logged and stored in its place.
        logger.info('Epoch %d (Validate) Loss %0.4f CER %0.4f' % (epoch, 0.0, valid_cer))
        self._save_epoch_result(
            train_result=[self.train_dict, train_loss, train_cer],
            valid_result=[self.valid_dict, 0.0, valid_cer])
        logger.info('Epoch %d Training result saved as a csv file complete !!' % epoch)

    return model