import math
import queue
import random
import time

import torch
import torch.nn as nn
import torch.optim as optim

# Project-local helpers (check_envirionment, load_data_list, split_dataset,
# build_model/build_ensemble, SupervisedTrainer, Checkpoint, Optimizer,
# loaders, PAD_token, char2id, ...) are assumed importable from the project.
import logging
logger = logging.getLogger(__name__)


def train(opt):
    random.seed(opt.seed)
    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed_all(opt.seed)
    device = check_envirionment(opt.use_cuda)

    audio_paths, script_paths = load_data_list(opt.data_list_path, opt.dataset_path)
    epoch_time_step, trainset_list, validset = split_dataset(opt, audio_paths, script_paths)

    model = build_ensemble(['model_path1', 'model_path2', 'model_path3'], opt.ensemble_method, device)

    optimizer = optim.Adam(model.module.parameters(), lr=opt.init_lr)
    optimizer = Optimizer(optimizer, None, 0, opt.max_grad_norm)
    criterion = nn.NLLLoss(reduction='sum', ignore_index=PAD_token).to(device)

    trainer = SupervisedTrainer(
        optimizer=optimizer,
        criterion=criterion,
        trainset_list=trainset_list,
        validset=validset,
        num_workers=opt.num_workers,
        high_plateau_lr=opt.high_plateau_lr,
        low_plateau_lr=opt.low_plateau_lr,
        decay_threshold=opt.decay_threshold,
        exp_decay_period=opt.exp_decay_period,
        device=device,
        teacher_forcing_step=opt.teacher_forcing_step,
        min_teacher_forcing_ratio=opt.min_teacher_forcing_ratio,
        print_every=opt.print_every,
        save_result_every=opt.save_result_every,
        checkpoint_every=opt.checkpoint_every
    )
    model = trainer.train(
        model=model,
        batch_size=opt.batch_size,
        epoch_time_step=epoch_time_step,
        num_epochs=opt.num_epochs,
        teacher_forcing_ratio=opt.teacher_forcing_ratio,
        resume=opt.resume
    )
    # The returned nn.Module has no optimizer/dataset attributes of its own;
    # save the final checkpoint from the trainer's state instead.
    Checkpoint(model, trainer.optimizer, trainer.trainset_list,
               trainer.validset, opt.num_epochs).save()
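# The `Optimizer` wrapper above is a project class not shown in this file.
# Below is a hypothetical minimal sketch, consistent with how it is called
# here (zero_grad(), step(model), get_lr(), set_scheduler()): it owns a torch
# optimizer, clips gradients before each step, and advances an optional LR
# scheduler for a fixed number of steps. The real class may differ in detail.
class Optimizer(object):
    def __init__(self, optim, scheduler=None, scheduler_period=0, max_grad_norm=0):
        self.optimizer = optim
        self.scheduler = scheduler
        self.scheduler_period = scheduler_period
        self.max_grad_norm = max_grad_norm
        self.count = 0

    def step(self, model, loss=None):
        # Clip gradients to stabilise training, then take one optimizer step.
        if self.max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), self.max_grad_norm)
        self.optimizer.step()

        # Advance the scheduler, dropping it once its period has elapsed.
        if self.scheduler is not None:
            self.scheduler.step()
            self.count += 1
            if self.count == self.scheduler_period:
                self.scheduler, self.scheduler_period, self.count = None, 0, 0

    def set_scheduler(self, scheduler, scheduler_period):
        self.scheduler = scheduler
        self.scheduler_period = scheduler_period
        self.count = 0

    def zero_grad(self):
        self.optimizer.zero_grad()

    def get_lr(self):
        # All parameter groups share one learning rate in this setup.
        for g in self.optimizer.param_groups:
            return g['lr']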
def train(self, model: nn.Module, batch_size: int, epoch_time_step: int,
          num_epochs: int, teacher_forcing_ratio: float = 0.99, resume: bool = False) -> nn.Module:
    """
    Run training for a given model.

    Args:
        model (torch.nn.Module): model to train
        batch_size (int): batch size for experiment
        epoch_time_step (int): number of time steps in one epoch
        num_epochs (int): number of epochs for training
        teacher_forcing_ratio (float): teacher forcing ratio (default: 0.99)
        resume (bool, optional): resume training from the latest checkpoint (default: False)
    """
    start_epoch = 0

    if resume:
        checkpoint = Checkpoint()
        latest_checkpoint_path = checkpoint.get_latest_checkpoint()
        resume_checkpoint = checkpoint.load(latest_checkpoint_path)
        model = resume_checkpoint.model
        self.optimizer = resume_checkpoint.optimizer
        self.trainset_list = resume_checkpoint.trainset_list
        self.validset = resume_checkpoint.validset
        start_epoch = resume_checkpoint.epoch + 1

        epoch_time_step = 0
        for trainset in self.trainset_list:
            epoch_time_step += len(trainset)
        epoch_time_step = math.ceil(epoch_time_step / batch_size)

    logger.info('start')
    train_begin_time = time.time()

    for epoch in range(start_epoch, num_epochs):
        logger.info('Epoch %d start' % epoch)
        train_queue = queue.Queue(self.num_workers << 1)

        for trainset in self.trainset_list:
            trainset.shuffle()

        # Training
        train_loader = MultiDataLoader(self.trainset_list, train_queue, batch_size, self.num_workers)
        train_loader.start()

        train_loss, train_cer = self.__train_epoches(model, epoch, epoch_time_step,
                                                     train_begin_time, train_queue, teacher_forcing_ratio)
        train_loader.join()

        Checkpoint(model, self.optimizer, self.trainset_list, self.validset, epoch).save()
        logger.info('Epoch %d (Training) Loss %0.4f CER %0.4f' % (epoch, train_loss, train_cer))

        teacher_forcing_ratio -= self.teacher_forcing_step
        teacher_forcing_ratio = max(self.min_teacher_forcing_ratio, teacher_forcing_ratio)

        # Validation
        valid_queue = queue.Queue(self.num_workers << 1)
        valid_loader = AudioDataLoader(self.validset, valid_queue, batch_size, 0)
        valid_loader.start()

        valid_loss, valid_cer = self.validate(model, valid_queue)
        valid_loader.join()

        logger.info('Epoch %d (Validate) Loss %0.4f CER %0.4f' % (epoch, valid_loss, valid_cer))
        self.__save_epoch_result(train_result=[self.train_dict, train_loss, train_cer],
                                 valid_result=[self.valid_dict, valid_loss, valid_cer])
        logger.info('Epoch %d training result saved as a csv file' % epoch)
        torch.cuda.empty_cache()

    # Keep the checkpoint arguments consistent with the saves above
    # (criterion is not part of this variant's checkpoint state).
    Checkpoint(model, self.optimizer, self.trainset_list, self.validset, num_epochs).save()
    return model
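# `Checkpoint` is likewise a project class. A hypothetical sketch matching
# the calls above (save(), get_latest_checkpoint(), load()) follows: save()
# serialises the trainer state into a timestamped directory and load()
# restores it. Directory layout and file names are illustrative assumptions,
# and the second trainer variant below adds `criterion` as a third argument.
import os

class Checkpoint(object):
    SAVE_DIR = 'checkpoints'          # assumed location, not from the source
    STATE_FILE = 'trainer_states.pt'  # assumed file name

    def __init__(self, model=None, optimizer=None, trainset_list=None,
                 validset=None, epoch=0):
        self.model = model
        self.optimizer = optimizer
        self.trainset_list = trainset_list
        self.validset = validset
        self.epoch = epoch

    def save(self):
        # One directory per save, named by wall-clock time so the newest
        # checkpoint sorts last.
        path = os.path.join(self.SAVE_DIR, time.strftime('%Y_%m_%d_%H_%M_%S'))
        os.makedirs(path, exist_ok=True)
        torch.save(self.__dict__, os.path.join(path, self.STATE_FILE))

    def get_latest_checkpoint(self):
        return os.path.join(self.SAVE_DIR, sorted(os.listdir(self.SAVE_DIR))[-1])

    def load(self, path):
        states = torch.load(os.path.join(path, self.STATE_FILE))
        self.__dict__.update(states)
        return self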
def __train_epoches(self, model: nn.Module, epoch: int, epoch_time_step: int,
                    train_begin_time: float, queue: queue.Queue,
                    teacher_forcing_ratio: float) -> Tuple[float, float]:
    """
    Run training for one epoch.

    Args:
        model (torch.nn.Module): model to train
        epoch (int): number of current epoch
        epoch_time_step (int): total time steps in one epoch
        train_begin_time (float): time when training began
        queue (queue.Queue): training queue, containing input, targets, input_lengths, target_lengths
        teacher_forcing_ratio (float): teacher forcing ratio (default: 0.99)

    Returns: loss, cer
        - **loss** (float): loss of current epoch
        - **cer** (float): character error rate of current epoch
    """
    cer = 1.0
    epoch_loss_total = 0.
    total_num = 0
    timestep = 0

    model.train()
    begin_time = epoch_begin_time = time.time()
    num_workers = self.num_workers

    while True:
        inputs, targets, input_lengths, target_lengths = queue.get()

        if inputs.shape[0] == 0:
            # Empty feats means one loader has closed.
            num_workers -= 1
            logger.debug('left train_loader: %d' % num_workers)

            if num_workers == 0:
                break
            else:
                continue

        inputs = inputs.to(self.device)
        targets = targets.to(self.device)
        model = model.to(self.device)

        if self.architecture == 'las':
            if isinstance(model, nn.DataParallel):
                model.module.flatten_parameters()
            else:
                model.flatten_parameters()

            logit = model(inputs=inputs, input_lengths=input_lengths,
                          targets=targets, teacher_forcing_ratio=teacher_forcing_ratio)
            logit = torch.stack(logit, dim=1).to(self.device)
            targets = targets[:, 1:]

        elif self.architecture == 'transformer':
            logit = model(inputs, input_lengths, targets, return_attns=False)

        else:
            raise ValueError("Unsupported architecture : {0}".format(self.architecture))

        hypothesis = logit.max(-1)[1]
        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), targets.contiguous().view(-1))
        epoch_loss_total += loss.item()

        cer = self.metric(targets, hypothesis)
        total_num += int(input_lengths.sum())

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step(model)

        timestep += 1
        torch.cuda.empty_cache()

        if timestep % self.print_every == 0:
            current_time = time.time()
            elapsed = current_time - begin_time
            epoch_elapsed = (current_time - epoch_begin_time) / 60.0
            train_elapsed = (current_time - train_begin_time) / 3600.0

            logger.info('timestep: {:4d}/{:4d}, loss: {:.4f}, cer: {:.2f}, '
                        'elapsed: {:.2f}s {:.2f}m {:.2f}h, lr: {:.5f}'.format(
                            timestep, epoch_time_step, epoch_loss_total / total_num,
                            cer, elapsed, epoch_elapsed, train_elapsed, self.optimizer.get_lr()))
            begin_time = time.time()

        if timestep % self.save_result_every == 0:
            self.__save_step_result(self.train_step_result, epoch_loss_total / total_num, cer)

        if timestep % self.checkpoint_every == 0:
            Checkpoint(model, self.optimizer, self.trainset_list, self.validset, epoch).save()

        del inputs, input_lengths, targets, logit, loss, hypothesis

    Checkpoint(model, self.optimizer, self.trainset_list, self.validset, epoch).save()
    logger.info('train() completed')

    return epoch_loss_total / total_num, cer
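# The loader classes (MultiDataLoader / AudioDataLoader) are project code.
# To make the `inputs.shape[0] == 0` sentinel above concrete, here is a
# hypothetical sketch of the producer side of that protocol: each worker
# thread pushes collated batches into the shared queue and, once its data
# is exhausted, a zero-size batch telling the consumer one loader closed.
import threading

class _SentinelLoader(threading.Thread):
    """Illustrative stand-in for the project's loader threads."""

    def __init__(self, batches, out_queue):
        super().__init__(daemon=True)
        self.batches = batches      # iterable of (inputs, targets, input_lengths, target_lengths)
        self.out_queue = out_queue

    def run(self):
        for batch in self.batches:
            self.out_queue.put(batch)
        # Zero-size inputs are the "this worker is done" sentinel.
        self.out_queue.put((torch.zeros(0, 1),
                            torch.zeros(0, dtype=torch.long),
                            torch.zeros(0, dtype=torch.long),
                            torch.zeros(0, dtype=torch.long)))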
def train(opt):
    random.seed(opt.seed)
    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed_all(opt.seed)
    device = check_envirionment(opt.use_cuda)

    if not opt.resume:
        audio_paths, script_paths = load_data_list(opt.data_list_path, opt.dataset_path)
        epoch_time_step, trainset_list, validset = split_dataset(opt, audio_paths, script_paths)
        model = build_model(opt, device)

        optimizer = optim.Adam(model.module.parameters(), lr=opt.init_lr, weight_decay=1e-05)

        if opt.rampup_period > 0:
            scheduler = RampUpLR(optimizer, opt.init_lr, opt.high_plateau_lr, opt.rampup_period)
            optimizer = Optimizer(optimizer, scheduler, opt.rampup_period, opt.max_grad_norm)
        else:
            optimizer = Optimizer(optimizer, None, 0, opt.max_grad_norm)

        if opt.label_smoothing == 0.0:
            criterion = nn.NLLLoss(reduction='sum', ignore_index=PAD_token).to(device)
        else:
            criterion = LabelSmoothingLoss(len(char2id), PAD_token, opt.label_smoothing, dim=-1).to(device)

    else:
        # On resume, the trainer restores all of these from the latest checkpoint.
        trainset_list = None
        validset = None
        model = None
        optimizer = None
        criterion = None
        epoch_time_step = None

    trainer = SupervisedTrainer(
        optimizer=optimizer,
        criterion=criterion,
        trainset_list=trainset_list,
        validset=validset,
        num_workers=opt.num_workers,
        high_plateau_lr=opt.high_plateau_lr,
        low_plateau_lr=opt.low_plateau_lr,
        decay_threshold=opt.decay_threshold,
        exp_decay_period=opt.exp_decay_period,
        device=device,
        teacher_forcing_step=opt.teacher_forcing_step,
        min_teacher_forcing_ratio=opt.min_teacher_forcing_ratio,
        print_every=opt.print_every,
        save_result_every=opt.save_result_every,
        checkpoint_every=opt.checkpoint_every
    )
    model = trainer.train(
        model=model,
        batch_size=opt.batch_size,
        epoch_time_step=epoch_time_step,
        num_epochs=opt.num_epochs,
        teacher_forcing_ratio=opt.teacher_forcing_ratio,
        resume=opt.resume
    )
    # As above, save the final checkpoint from the trainer's state rather
    # than from attributes the model object does not have.
    Checkpoint(model, trainer.optimizer, trainer.criterion, trainer.trainset_list,
               trainer.validset, opt.num_epochs).save()
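# `RampUpLR` above is referenced but not defined in this file. A minimal
# sketch under the assumption that it linearly raises the learning rate
# from `init_lr` to `high_plateau_lr` over `period` optimizer steps:
class RampUpLR(object):
    def __init__(self, optimizer, init_lr, high_plateau_lr, period):
        self.optimizer = optimizer
        self.init_lr = init_lr
        self.slope = (high_plateau_lr - init_lr) / period
        self.period = period
        self.count = 0

    def step(self):
        # Clamp at the plateau once the ramp-up period has elapsed.
        self.count = min(self.count + 1, self.period)
        lr = self.init_lr + self.slope * self.count
        for g in self.optimizer.param_groups:
            g['lr'] = lr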
def train(self, model, batch_size, epoch_time_step, num_epochs,
          teacher_forcing_ratio=0.99, resume=False):
    """
    Run training for a given model.

    Args:
        model (torch.nn.Module): model to train
        batch_size (int): batch size for experiment
        epoch_time_step (int): number of time steps in one epoch
        num_epochs (int): number of epochs for training
        teacher_forcing_ratio (float): teacher forcing ratio (default: 0.99)
        resume (bool, optional): resume training from the latest checkpoint (default: False)
    """
    start_epoch = 0
    prev_train_cer = 1.

    if resume:
        checkpoint = Checkpoint()
        latest_checkpoint_path = checkpoint.get_latest_checkpoint()
        resume_checkpoint = checkpoint.load(latest_checkpoint_path)
        model = resume_checkpoint.model
        self.optimizer = resume_checkpoint.optimizer
        self.criterion = resume_checkpoint.criterion
        self.trainset_list = resume_checkpoint.trainset_list
        self.validset = resume_checkpoint.validset
        start_epoch = resume_checkpoint.epoch

        epoch_time_step = 0
        for trainset in self.trainset_list:
            epoch_time_step += len(trainset)
        epoch_time_step = math.ceil(epoch_time_step / batch_size)

        # Reset the learning rate after resuming.
        for g in self.optimizer.optimizer.param_groups:
            g['lr'] = 1e-04

        logger.info('Learning rate : %f' % self.optimizer.get_lr())

    logger.info('start')
    train_begin_time = time.time()

    for epoch in range(start_epoch, num_epochs):
        train_queue = queue.Queue(self.num_workers << 1)

        for trainset in self.trainset_list:
            trainset.shuffle()

        # Training
        train_loader = MultiAudioLoader(self.trainset_list, train_queue, batch_size, self.num_workers)
        train_loader.start()

        train_loss, train_cer = self.train_epoches(model, epoch, epoch_time_step,
                                                   train_begin_time, train_queue, teacher_forcing_ratio)
        train_loader.join()

        Checkpoint(model, self.optimizer, self.criterion,
                   self.trainset_list, self.validset, epoch).save()
        logger.info('Epoch %d (Training) Loss %0.4f CER %0.4f' % (epoch, train_loss, train_cer))

        # Switch to exponential decay once the CER improvement falls below the threshold.
        if prev_train_cer - train_cer < self.decay_threshold:
            self.optimizer.set_scheduler(
                ExponentialDecayLR(self.optimizer.optimizer, self.optimizer.get_lr(),
                                   self.low_plateau_lr, self.exp_decay_period),
                self.exp_decay_period)

        prev_train_cer = train_cer
        teacher_forcing_ratio -= self.teacher_forcing_step
        teacher_forcing_ratio = max(self.min_teacher_forcing_ratio, teacher_forcing_ratio)

        # Validation (validate() returns only the CER; no validation loss is computed)
        valid_queue = queue.Queue(self.num_workers << 1)
        valid_loader = AudioLoader(self.validset, valid_queue, batch_size, 0)
        valid_loader.start()

        valid_cer = self.validate(model, valid_queue)
        valid_loader.join()

        logger.info('Epoch %d (Validate) Loss %0.4f CER %0.4f' % (epoch, 0.0, valid_cer))
        self._save_epoch_result(train_result=[self.train_dict, train_loss, train_cer],
                                valid_result=[self.valid_dict, 0.0, valid_cer])
        logger.info('Epoch %d training result saved as a csv file' % epoch)

    return model
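# `ExponentialDecayLR` is also project code. Assuming it decays the learning
# rate geometrically from `init_lr` down to `low_plateau_lr` over `period`
# steps (so lr * factor**period == low_plateau_lr), a minimal sketch:
class ExponentialDecayLR(object):
    def __init__(self, optimizer, init_lr, low_plateau_lr, period):
        self.optimizer = optimizer
        self.factor = (low_plateau_lr / init_lr) ** (1.0 / period)

    def step(self):
        for g in self.optimizer.param_groups:
            g['lr'] *= self.factor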
def train_epoches(self, model, epoch, epoch_time_step, train_begin_time,
                  queue, teacher_forcing_ratio):
    """
    Run training for one epoch.

    Args:
        model (torch.nn.Module): model to train
        epoch (int): number of current epoch
        epoch_time_step (int): total time steps in one epoch
        train_begin_time (float): time when training began
        queue (queue.Queue): training queue, containing input, targets, input_lengths, target_lengths
        teacher_forcing_ratio (float): teacher forcing ratio (default: 0.99)

    Returns: loss, cer
        - **loss** (float): loss of current epoch
        - **cer** (float): character error rate of current epoch
    """
    cer = 1.0
    epoch_loss_total = 0.
    total_num = 0
    timestep = 0

    model.train()
    begin_time = epoch_begin_time = time.time()
    # Count down on a local copy so the worker count is intact for the next epoch.
    num_workers = self.num_workers

    while True:
        inputs, scripts, input_lengths, target_lengths = queue.get()

        if inputs.shape[0] == 0:
            # Empty feats means one loader has closed.
            num_workers -= 1
            logger.debug('left train_loader: %d' % num_workers)

            if num_workers == 0:
                break
            else:
                continue

        inputs = inputs.to(self.device)
        scripts = scripts.to(self.device)
        targets = scripts[:, 1:]

        model.module.flatten_parameters()
        output = model(inputs, input_lengths, scripts, teacher_forcing_ratio=teacher_forcing_ratio)[0]

        logit = torch.stack(output, dim=1).to(self.device)
        hypothesis = logit.max(-1)[1]

        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), targets.contiguous().view(-1))
        epoch_loss_total += loss.item()

        cer = self.metric(targets, hypothesis)
        total_num += int(input_lengths.sum())

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step(model, loss.item())

        timestep += 1
        torch.cuda.empty_cache()

        if timestep % self.print_every == 0:
            current_time = time.time()
            elapsed = current_time - begin_time
            epoch_elapsed = (current_time - epoch_begin_time) / 60.0
            train_elapsed = (current_time - train_begin_time) / 3600.0

            logger.info('timestep: {:4d}/{:4d}, loss: {:.4f}, cer: {:.2f}, '
                        'elapsed: {:.2f}s {:.2f}m {:.2f}h'.format(
                            timestep, epoch_time_step, epoch_loss_total / total_num,
                            cer, elapsed, epoch_elapsed, train_elapsed))
            begin_time = time.time()

        if timestep % self.save_result_every == 0:
            self._save_step_result(self.train_step_result, epoch_loss_total / total_num, cer)

        if timestep % self.checkpoint_every == 0:
            Checkpoint(model, self.optimizer, self.criterion,
                       self.trainset_list, self.validset, epoch).save()

        del inputs, input_lengths, scripts, targets, output, logit, loss, hypothesis

    Checkpoint(model, self.optimizer, self.criterion,
               self.trainset_list, self.validset, epoch).save()
    logger.info('train() completed')

    return epoch_loss_total / total_num, cer
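# `self.metric(targets, hypothesis)` above computes the character error rate.
# A hypothetical sketch of such a metric: strip padding, sum Levenshtein edit
# distances over the batch, and normalise by total reference length. `pad_id`
# and the function names are illustrative assumptions (PAD_token in this project).
def _levenshtein(ref, hyp):
    # Classic two-row dynamic programme over token sequences.
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        curr = [i]
        for j, h in enumerate(hyp, 1):
            curr.append(min(prev[j] + 1,                 # deletion
                            curr[j - 1] + 1,             # insertion
                            prev[j - 1] + (r != h)))     # substitution
        prev = curr
    return prev[-1]

def character_error_rate(targets, hypotheses, pad_id=0):
    total_dist, total_len = 0, 0
    for ref, hyp in zip(targets.tolist(), hypotheses.tolist()):
        ref = [t for t in ref if t != pad_id]
        hyp = [t for t in hyp if t != pad_id]
        total_dist += _levenshtein(ref, hyp)
        total_len += len(ref)
    return total_dist / max(total_len, 1)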