def execute_training_loop(self, trn, dev, devices, epochs, logger, patience, save_dir, optimizer,
                          gradient_accumulation, **kwargs):
    """Train for up to ``epochs`` epochs, checkpointing whenever the dev metric improves.

    After the loop, the best checkpoint is restored if training continued past it.

    Args:
        trn: Training dataloader; ``len(trn)`` sizes the progress counter.
        dev: Dataloader evaluated after every epoch.
        devices: Not used by this method.
        epochs: Maximum number of epochs to run.
        logger: Receives per-epoch and summary reports.
        patience: Number of consecutive non-improving epochs after which training
            stops early, or ``None`` to disable early stopping. When it equals
            ``epochs``, the per-epoch report omits the ``/patience`` denominator.
        save_dir: Location handed to ``save_weights`` / ``load_weights``.
        optimizer: A 4-tuple ``(optimizer, scheduler, transformer_optimizer,
            transformer_scheduler)`` that is unpacked and forwarded to ``fit_dataloader``.
        gradient_accumulation: Forwarded to ``fit_dataloader``; ``len(trn)`` is also
            divided by it to size the progress counter.
        **kwargs: Not used by this method.
    """
    optimizer, scheduler, transformer_optimizer, transformer_scheduler = optimizer
    criterion = self.build_criterion()
    best_e, best_metric = 0, self.build_metric()
    timer = CountdownTimer(epochs)
    history = History()
    # Width of the widest "step/total" progress string, so dev output lines up with it.
    steps = len(trn) // gradient_accumulation
    ratio_width = len(f'{steps}/{steps}')
    for epoch in range(1, epochs + 1):
        # train one epoch and update the parameters
        logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
        self.fit_dataloader(trn, optimizer, scheduler, criterion, epoch, logger, history,
                            transformer_optimizer, transformer_scheduler,
                            gradient_accumulation=gradient_accumulation)
        _, dev_metric = self.evaluate_dataloader(dev, criterion, ratio_width=ratio_width, logger=logger)
        timer.update()
        report = f"{timer.elapsed_human} / {timer.total_time_human} ETA: {timer.eta_human}"
        if dev_metric > best_metric:
            # save the model if it is the best so far
            best_e, best_metric = epoch, dev_metric
            self.save_weights(save_dir)
            report += ' ([red]saved[/red])'
        elif patience is not None and patience != epochs:
            # Only show "k/patience" when early stopping is actually enabled.
            report += f' ({epoch - best_e}/{patience})'
        else:
            report += f' ({epoch - best_e})'
        logger.info(report)
        if patience is not None and epoch - best_e >= patience:
            logger.info(f'LAS has stopped improving for {patience} epochs, early stop.')
            break
    timer.stop()
    if not best_e:
        # No epoch beat the initial metric (or no epoch ran): persist current weights anyway.
        self.save_weights(save_dir)
    elif best_e != epoch:
        # Training continued past the best epoch; restore the best checkpoint.
        self.load_weights(save_dir)
    logger.info(f"Max score of dev is {best_metric.score:.2%} at epoch {best_e}")
    logger.info(f"Average time of each epoch is {timer.elapsed_average_human}")
    logger.info(f"{timer.elapsed_human} elapsed")
def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion, optimizer, metric,
                          save_dir, logger: logging.Logger, devices, ratio_width=None, patience=0.5,
                          eval_trn=True, **kwargs):
    """Train for up to ``epochs`` epochs with early stopping on the dev metric.

    Args:
        trn: Training dataloader.
        dev: Dataloader evaluated after every epoch.
        epochs: Maximum number of epochs to run.
        criterion: Loss passed to ``fit_dataloader`` and ``evaluate_dataloader``.
        optimizer: Forwarded to ``fit_dataloader``.
        metric: Forwarded to ``fit_dataloader``.
        save_dir: Location handed to ``save_weights`` / ``load_weights``.
        logger: Receives per-epoch and summary reports.
        devices: Not used by this method.
        ratio_width: Forwarded to the fit/evaluate progress displays.
        patience: Number of non-improving epochs before stopping. A float is read
            as a fraction of ``epochs`` and is clamped to at least 1.
        eval_trn: Forwarded to ``fit_dataloader``.
        **kwargs: Not used; note that ``self.config`` is what gets forwarded to
            ``fit_dataloader`` as extra keyword arguments.
    """
    if isinstance(patience, float):
        # Clamp to 1: a small fraction such as 0.1 * 3 epochs truncates to 0, and with
        # the ``>= patience`` test below that would always stop right after epoch 1.
        patience = max(1, int(patience * epochs))
    best_epoch, best_metric = 0, -1
    timer = CountdownTimer(epochs)
    history = History()
    for epoch in range(1, epochs + 1):
        logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
        self.fit_dataloader(trn, criterion, optimizer, metric, logger, history=history,
                            ratio_width=ratio_width, eval_trn=eval_trn, **self.config)
        _, dev_metric = self.evaluate_dataloader(dev, criterion, logger=logger, ratio_width=ratio_width)
        timer.update()
        report = f"{timer.elapsed_human} / {timer.total_time_human} ETA: {timer.eta_human}"
        if dev_metric > best_metric:
            # New best dev score: checkpoint it.
            best_epoch, best_metric = epoch, dev_metric
            self.save_weights(save_dir)
            report += ' [red](saved)[/red]'
        else:
            report += f' ({epoch - best_epoch})'
            if epoch - best_epoch >= patience:
                report += ' early stop'
        logger.info(report)
        if epoch - best_epoch >= patience:
            break
    if not best_epoch:
        # No epoch beat the initial -1 (or no epoch ran): persist current weights anyway.
        self.save_weights(save_dir)
    elif best_epoch != epoch:
        # Training continued past the best epoch; restore the best checkpoint.
        self.load_weights(save_dir)
    logger.info(f"Max score of dev is {best_metric} at epoch {best_epoch}")
    logger.info(f"Average time of each epoch is {timer.elapsed_average_human}")
    logger.info(f"{timer.elapsed_human} elapsed")
def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion, optimizer, metric,
                          save_dir, logger: logging.Logger, devices, ratio_width=None, dev_data=None,
                          eval_after=None, **kwargs):
    """Train for ``epochs`` epochs, evaluating on dev once past ``eval_after`` epochs.

    Each evaluation writes predictions to ``<save_dir>/dev.pred.txt``. The best-scoring
    weights are checkpointed and restored at the end if training went past them.

    Args:
        trn: Training dataloader.
        dev: Dataloader used for evaluation.
        epochs: Number of epochs to run (there is no early stopping).
        criterion: Loss passed to ``fit_dataloader`` and ``evaluate_dataloader``.
        optimizer: Forwarded to ``fit_dataloader``.
        metric: Forwarded to ``fit_dataloader``.
        save_dir: Checkpoint location; also the directory of ``dev.pred.txt``.
        logger: Receives per-epoch and summary reports.
        devices: Not used by this method.
        ratio_width: Forwarded to the fit/evaluate progress displays.
        dev_data: Passed to ``evaluate_dataloader`` as ``input``.
        eval_after: Dev evaluation happens only for epochs strictly greater than this.
            A float is read as a fraction of ``epochs``. ``None`` (the default) means
            evaluate every epoch.
        **kwargs: Not used; ``self.config`` is what gets forwarded to ``fit_dataloader``.

    Returns:
        The best dev metric seen, or ``-1`` if no evaluated epoch scored above ``-1``.
    """
    if eval_after is None:
        # ``epoch > None`` raises TypeError in Python 3; None means "evaluate every epoch".
        eval_after = 0
    elif isinstance(eval_after, float):
        eval_after = int(epochs * eval_after)
    best_epoch, best_metric = 0, -1
    timer = CountdownTimer(epochs)
    history = History()
    for epoch in range(1, epochs + 1):
        logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
        self.fit_dataloader(trn, criterion, optimizer, metric, logger, history=history,
                            ratio_width=ratio_width, **self.config)
        evaluated = epoch > eval_after
        if evaluated:
            dev_metric = self.evaluate_dataloader(dev, criterion, logger=logger, ratio_width=ratio_width,
                                                  output=os.path.join(save_dir, 'dev.pred.txt'),
                                                  input=dev_data, use_fast=True)
        timer.update()
        report = f"{timer.elapsed_human} / {timer.total_time_human} ETA: {timer.eta_human}"
        if evaluated:
            if dev_metric > best_metric:
                # New best dev score: checkpoint it.
                best_epoch, best_metric = epoch, dev_metric
                self.save_weights(save_dir)
                report += ' [red](saved)[/red]'
            else:
                report += f' ({epoch - best_epoch})'
        logger.info(report)
    if not best_epoch:
        # Never improved (or never evaluated / no epochs): persist current weights anyway.
        self.save_weights(save_dir)
    elif best_epoch != epoch:
        # Training continued past the best epoch; restore the best checkpoint.
        self.load_weights(save_dir)
    logger.info(f"Max score of dev is {best_metric} at epoch {best_epoch}")
    logger.info(f"Average time of each epoch is {timer.elapsed_average_human}")
    logger.info(f"{timer.elapsed_human} elapsed")
    return best_metric