def _train_log(self, trainer: Trainer, pl_module: LightningModule):
    """Fold the combined train report into the meter and log a progress line.

    Computes metrics if configured, updates the train meter, and emits a
    summarized report (epoch, lr, updates-per-second, ETA, timers) to the
    logger / tensorboard writer.
    """
    report = self.train_combined_report
    if self.training_config.evaluate_metrics:
        report.metrics = pl_module.metrics(report, report)
    pl_module.train_meter.update_from_report(report)

    extra = {}
    if "cuda" in str(trainer.model.device):
        # bytes -> KiB (float division), then floored down to MiB
        extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
        extra["max mem"] //= 1024
    if self.training_config.experiment_name:
        extra["experiment"] = self.training_config.experiment_name

    opt = self.get_optimizer(trainer)
    updates = self._get_num_updates_for_logging(trainer)
    iteration = self._get_iterations_for_logging(trainer)

    # Insertion order below matches the desired display order of fields.
    extra["epoch"] = self._get_current_epoch_for_logging(trainer)
    extra["iterations"] = iteration
    extra["num_updates"] = updates
    extra["max_updates"] = trainer.max_steps
    extra["lr"] = "{:.5f}".format(opt.param_groups[0]["lr"]).rstrip("0")
    extra["ups"] = "{:.2f}".format(
        self.trainer_config.log_every_n_steps
        / self.train_timer.unix_time_since_start()
    )
    extra["time"] = self.train_timer.get_time_since_start()
    extra["time_since_start"] = self.total_timer.get_time_since_start()
    extra["eta"] = calculate_time_left(
        max_updates=trainer.max_steps,
        num_updates=updates,
        timer=self.train_timer,
        num_snapshot_iterations=self.snapshot_iterations,
        log_interval=self.trainer_config.log_every_n_steps,
        eval_interval=self.trainer_config.val_check_interval,
    )

    # Reset after the ETA/ups computations so they cover the full interval.
    self.train_timer.reset()
    summarize_report(
        current_iteration=iteration,
        num_updates=updates,
        max_updates=trainer.max_steps,
        meter=pl_module.train_meter,
        extra=extra,
        tb_writer=self.lightning_trainer.tb_writer,
    )
def on_test_end(self, **kwargs):
    """Emit the final test summary for the finished dataset and total runtime.

    Expects ``kwargs`` to carry ``report`` (with dataset name/type) and
    ``meter`` holding the accumulated test metrics.
    """
    report = kwargs["report"]
    heading = "{}: full {}".format(report.dataset_name, report.dataset_type)
    summarize_report(
        current_iteration=self.trainer.current_iteration,
        num_updates=self.trainer.num_updates,
        max_updates=self.trainer.max_updates,
        meter=kwargs["meter"],
        should_print=heading,
        tb_writer=self.tb_writer,
    )
    logger.info(
        f"Finished run in {self.total_timer.get_time_since_start()}")
def on_update_end(self, **kwargs):
    """Log training progress after a parameter update, when logging is due.

    Bails out immediately unless ``kwargs["should_log"]`` is truthy, then
    summarizes epoch/lr/ups/ETA and the current meter state.
    """
    if not kwargs["should_log"]:
        return

    trainer = self.trainer
    extra = {}
    if "cuda" in str(trainer.device):
        # bytes -> KiB (float division), then floored down to MiB
        extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
        extra["max mem"] //= 1024
    if self.training_config.experiment_name:
        extra["experiment"] = self.training_config.experiment_name

    # These attributes may be absent on some trainer implementations.
    max_updates = getattr(trainer, "max_updates", None)
    num_updates = getattr(trainer, "num_updates", None)

    # Insertion order below matches the desired display order of fields.
    extra["epoch"] = trainer.current_epoch
    extra["num_updates"] = num_updates
    extra["iterations"] = trainer.current_iteration
    extra["max_updates"] = max_updates
    extra["lr"] = "{:.5f}".format(
        trainer.optimizer.param_groups[0]["lr"]).rstrip("0")
    extra["ups"] = "{:.2f}".format(
        self.log_interval / self.train_timer.unix_time_since_start())
    extra["time"] = self.train_timer.get_time_since_start()
    extra["time_since_start"] = self.total_timer.get_time_since_start()
    extra["eta"] = calculate_time_left(
        max_updates=max_updates,
        num_updates=num_updates,
        timer=self.train_timer,
        num_snapshot_iterations=self.snapshot_iterations,
        log_interval=self.log_interval,
        eval_interval=self.evaluation_interval,
    )

    # Reset after the ETA/ups computations so they cover the full interval.
    self.train_timer.reset()
    summarize_report(
        current_iteration=trainer.current_iteration,
        num_updates=num_updates,
        max_updates=max_updates,
        meter=kwargs["meter"],
        extra=extra,
        tb_writer=self.tb_writer,
    )
def on_validation_end(self, **kwargs):
    """Report validation results plus early-stopping status, then reset timers."""
    trainer = self.trainer
    extra = {
        "num_updates": trainer.num_updates,
        "epoch": trainer.current_epoch,
        "iterations": trainer.current_iteration,
        "max_updates": trainer.max_updates,
        "val_time": self.snapshot_timer.get_time_since_start(),
    }
    # Merge in early-stopping bookkeeping (e.g. best metric so far).
    stop_info = trainer.early_stop_callback.early_stopping.get_info()
    extra.update(stop_info)

    self.train_timer.reset()
    summarize_report(
        current_iteration=trainer.current_iteration,
        num_updates=trainer.num_updates,
        max_updates=trainer.max_updates,
        meter=kwargs["meter"],
        extra=extra,
        tb_writer=self.tb_writer,
    )
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
    """Summarize validation metrics at the end of a Lightning validation loop."""
    iteration = self._get_iterations_for_logging(trainer)
    epoch = self._get_current_epoch_for_logging(trainer)
    updates = self._get_num_updates_for_logging(trainer)

    extra = {
        "num_updates": updates,
        "epoch": epoch,
        "iterations": iteration,
        "max_updates": trainer.max_steps,
        "val_time": self.snapshot_timer.get_time_since_start(),
    }
    # TODO: @sash populate early stop info for logging (next mvp)
    # extra.update(self.trainer.early_stop_callback.early_stopping.get_info())

    self.train_timer.reset()
    summarize_report(
        current_iteration=iteration,
        num_updates=updates,
        max_updates=trainer.max_steps,
        meter=pl_module.val_meter,
        extra=extra,
        tb_writer=self.lightning_trainer.tb_writer,
    )