def train(
        self,
        model: nn.Module,
        batch_size: int,
        epoch_time_step: int,
        num_epochs: int,
        teacher_forcing_ratio: float = 0.99,
        resume: bool = False,
) -> nn.Module:
    """
    Run training for a given model.

    Args:
        model (torch.nn.Module): model to train
        batch_size (int): batch size for experiment
        epoch_time_step (int): number of time steps per training epoch
        num_epochs (int): number of epochs for training
        teacher_forcing_ratio (float): teacher forcing ratio (default: 0.99)
        resume (bool, optional): resume training from the latest checkpoint (default: False)

    Returns:
        model (torch.nn.Module): the trained model
    """
    start_epoch = 0

    if resume:
        # Restore model, optimizer, and datasets from the latest checkpoint.
        checkpoint = Checkpoint()
        latest_checkpoint_path = checkpoint.get_latest_checkpoint()
        resume_checkpoint = checkpoint.load(latest_checkpoint_path)
        model = resume_checkpoint.model
        self.optimizer = resume_checkpoint.optimizer
        self.trainset_list = resume_checkpoint.trainset_list
        self.validset = resume_checkpoint.validset
        start_epoch = resume_checkpoint.epoch + 1

        # Recompute the number of steps per epoch from the restored datasets.
        epoch_time_step = 0
        for trainset in self.trainset_list:
            epoch_time_step += len(trainset)
        epoch_time_step = math.ceil(epoch_time_step / batch_size)

    logger.info('start')
    train_begin_time = time.time()

    for epoch in range(start_epoch, num_epochs):
        logger.info('Epoch %d start' % epoch)
        train_queue = queue.Queue(self.num_workers << 1)

        for trainset in self.trainset_list:
            trainset.shuffle()

        # Training
        train_loader = MultiDataLoader(self.trainset_list, train_queue, batch_size, self.num_workers)
        train_loader.start()
        train_loss, train_cer = self.__train_epoches(
            model, epoch, epoch_time_step, train_begin_time, train_queue, teacher_forcing_ratio,
        )
        train_loader.join()

        Checkpoint(model, self.optimizer, self.trainset_list, self.validset, epoch).save()
        logger.info('Epoch %d (Training) Loss %0.4f CER %0.4f' % (epoch, train_loss, train_cer))

        # Anneal the teacher forcing ratio, clamped at the configured minimum.
        teacher_forcing_ratio -= self.teacher_forcing_step
        teacher_forcing_ratio = max(self.min_teacher_forcing_ratio, teacher_forcing_ratio)

        # Validation
        valid_queue = queue.Queue(self.num_workers << 1)
        valid_loader = AudioDataLoader(self.validset, valid_queue, batch_size, 0)
        valid_loader.start()
        valid_loss, valid_cer = self.validate(model, valid_queue)
        valid_loader.join()

        logger.info('Epoch %d (Validate) Loss %0.4f CER %0.4f' % (epoch, valid_loss, valid_cer))
        self.__save_epoch_result(
            train_result=[self.train_dict, train_loss, train_cer],
            valid_result=[self.valid_dict, valid_loss, valid_cer],
        )
        logger.info('Epoch %d training result saved as a csv file' % epoch)
        torch.cuda.empty_cache()

    # Final checkpoint after the last epoch (same arguments as the per-epoch saves).
    Checkpoint(model, self.optimizer, self.trainset_list, self.validset, num_epochs).save()
    return model
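# The loop above hands batches between threads through a bounded queue: a loader
# thread (MultiDataLoader / AudioDataLoader) fills the queue while the training
# loop consumes it, and start()/join() bracket each epoch. Below is a minimal,
# self-contained sketch of that producer-consumer pattern. `ToyLoader` and its
# None-sentinel protocol are illustrative assumptions for this sketch, not the
# actual loader API of this codebase.
import queue
import threading


class ToyLoader(threading.Thread):
    """Pushes batches from a dataset into a shared queue, then a sentinel."""

    def __init__(self, dataset, batch_queue, batch_size):
        super().__init__(daemon=True)
        self.dataset = dataset
        self.batch_queue = batch_queue
        self.batch_size = batch_size

    def run(self):
        for i in range(0, len(self.dataset), self.batch_size):
            self.batch_queue.put(self.dataset[i:i + self.batch_size])
        self.batch_queue.put(None)  # sentinel: no more batches


if __name__ == '__main__':
    batch_queue = queue.Queue(maxsize=4)  # bounded, like Queue(num_workers << 1)
    loader = ToyLoader(list(range(10)), batch_queue, batch_size=3)
    loader.start()
    while True:
        batch = batch_queue.get()
        if batch is None:  # loader thread is exhausted
            break
        print('consumed batch:', batch)
    loader.join()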
def train(self, model, batch_size, epoch_time_step, num_epochs, teacher_forcing_ratio=0.99, resume=False):
    """
    Run training for a given model.

    Args:
        model (torch.nn.Module): model to train
        batch_size (int): batch size for experiment
        epoch_time_step (int): number of time steps per training epoch
        num_epochs (int): number of epochs for training
        teacher_forcing_ratio (float): teacher forcing ratio (default: 0.99)
        resume (bool, optional): resume training from the latest checkpoint (default: False)
    """
    start_epoch = 0
    prev_train_cer = 1.

    if resume:
        # Restore model, optimizer, criterion, and datasets from the latest checkpoint.
        checkpoint = Checkpoint()
        latest_checkpoint_path = checkpoint.get_latest_checkpoint()
        resume_checkpoint = checkpoint.load(latest_checkpoint_path)
        model = resume_checkpoint.model
        self.optimizer = resume_checkpoint.optimizer
        self.criterion = resume_checkpoint.criterion
        self.trainset_list = resume_checkpoint.trainset_list
        self.validset = resume_checkpoint.validset
        start_epoch = resume_checkpoint.epoch + 1

        # Recompute the number of steps per epoch from the restored datasets.
        epoch_time_step = 0
        for trainset in self.trainset_list:
            epoch_time_step += len(trainset)
        epoch_time_step = math.ceil(epoch_time_step / batch_size)

        # Reset the learning rate before resuming.
        for g in self.optimizer.optimizer.param_groups:
            g['lr'] = 1e-04
        logger.info('Learning rate : %f' % self.optimizer.get_lr())

    logger.info('start')
    train_begin_time = time.time()

    for epoch in range(start_epoch, num_epochs):
        train_queue = queue.Queue(self.num_workers << 1)

        for trainset in self.trainset_list:
            trainset.shuffle()

        # Training
        train_loader = MultiAudioLoader(self.trainset_list, train_queue, batch_size, self.num_workers)
        train_loader.start()
        train_loss, train_cer = self.train_epoches(
            model, epoch, epoch_time_step, train_begin_time, train_queue, teacher_forcing_ratio,
        )
        train_loader.join()

        Checkpoint(model, self.optimizer, self.criterion, self.trainset_list, self.validset, epoch).save()
        logger.info('Epoch %d (Training) Loss %0.4f CER %0.4f' % (epoch, train_loss, train_cer))

        # Switch to exponential learning-rate decay once the training CER plateaus.
        if prev_train_cer - train_cer < self.decay_threshold:
            self.optimizer.set_scheduler(
                ExponentialDecayLR(
                    self.optimizer.optimizer,
                    self.optimizer.get_lr(),
                    self.low_plateau_lr,
                    self.exp_decay_period,
                ),
                self.exp_decay_period,
            )

        prev_train_cer = train_cer

        # Anneal the teacher forcing ratio, clamped at the configured minimum.
        teacher_forcing_ratio -= self.teacher_forcing_step
        teacher_forcing_ratio = max(self.min_teacher_forcing_ratio, teacher_forcing_ratio)

        # Validation (this variant of validate() returns CER only, so loss is logged as 0.0).
        valid_queue = queue.Queue(self.num_workers << 1)
        valid_loader = AudioLoader(self.validset, valid_queue, batch_size, 0)
        valid_loader.start()
        valid_cer = self.validate(model, valid_queue)
        valid_loader.join()

        logger.info('Epoch %d (Validate) Loss %0.4f CER %0.4f' % (epoch, 0.0, valid_cer))
        self._save_epoch_result(
            train_result=[self.train_dict, train_loss, train_cer],
            valid_result=[self.valid_dict, 0.0, valid_cer],
        )
        logger.info('Epoch %d training result saved as a csv file' % epoch)

    return model
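# The plateau branch above installs an ExponentialDecayLR scheduler that takes
# the current learning rate down to `low_plateau_lr` over `exp_decay_period`
# steps. Its implementation is not shown here; the sketch below is an
# illustrative assumption of what such a scheduler could look like (geometric
# interpolation with a fixed per-step factor), not the actual class used above.
import torch


class ToyExponentialDecayLR:
    """Multiplies the learning rate by a fixed factor each step so that it
    reaches final_lr after `period` steps: factor ** period == final_lr / init_lr."""

    def __init__(self, optimizer, init_lr, final_lr, period):
        self.optimizer = optimizer
        self.lr = init_lr
        self.factor = (final_lr / init_lr) ** (1.0 / period)

    def step(self):
        self.lr *= self.factor
        for group in self.optimizer.param_groups:
            group['lr'] = self.lr


if __name__ == '__main__':
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=1e-4)
    scheduler = ToyExponentialDecayLR(optimizer, init_lr=1e-4, final_lr=1e-5, period=10)
    for _ in range(10):
        scheduler.step()
    print('final lr: %.2e' % optimizer.param_groups[0]['lr'])  # ~1e-5 after 10 steps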