Example #1
    def _train_log(self, trainer: Trainer, pl_module: LightningModule):
        if self.training_config.evaluate_metrics:
            # The combined report carries both the sample list and the model
            # output, so it is passed as both arguments to metrics().
            self.train_combined_report.metrics = pl_module.metrics(
                self.train_combined_report, self.train_combined_report
            )

        pl_module.train_meter.update_from_report(self.train_combined_report)

        extra = {}
        if "cuda" in str(trainer.model.device):
            # max_memory_allocated() reports bytes; the two divisions by 1024
            # convert the peak to MiB.
            extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
            extra["max mem"] //= 1024

        if self.training_config.experiment_name:
            extra["experiment"] = self.training_config.experiment_name

        optimizer = self.get_optimizer(trainer)
        num_updates = self._get_num_updates_for_logging(trainer)
        current_iteration = self._get_iterations_for_logging(trainer)
        extra.update({
            "epoch": self._get_current_epoch_for_logging(trainer),
            "iterations": current_iteration,
            "num_updates": num_updates,
            "max_updates": trainer.max_steps,
            "lr": "{:.5f}".format(optimizer.param_groups[0]["lr"]).rstrip("0"),
            "ups": "{:.2f}".format(
                self.trainer_config.log_every_n_steps
                / self.train_timer.unix_time_since_start()
            ),
            "time": self.train_timer.get_time_since_start(),
            "time_since_start": self.total_timer.get_time_since_start(),
            "eta": calculate_time_left(
                max_updates=trainer.max_steps,
                num_updates=num_updates,
                timer=self.train_timer,
                num_snapshot_iterations=self.snapshot_iterations,
                log_interval=self.trainer_config.log_every_n_steps,
                eval_interval=self.trainer_config.val_check_interval,
            ),
        })
        self.train_timer.reset()
        summarize_report(
            current_iteration=current_iteration,
            num_updates=num_updates,
            max_updates=trainer.max_steps,
            meter=pl_module.train_meter,
            extra=extra,
            tb_writer=self.lightning_trainer.tb_writer,
        )
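
Both this Lightning callback and the native-trainer version in Example #3 delegate the ETA estimate to calculate_time_left, whose implementation is not shown on this page. The arithmetic it needs is essentially "measured seconds per update times updates remaining, plus the cost of the validation snapshots still to come". A minimal sketch under those assumptions (a hypothetical stand-in, not the library's code):

def estimate_time_left(
    max_updates: int,
    num_updates: int,
    seconds_since_last_log: float,
    num_snapshot_iterations: int,
    log_interval: int,
    eval_interval: int,
) -> float:
    # Seconds spent per update in the window since the last log call.
    time_per_update = seconds_since_last_log / log_interval
    remaining_updates = max_updates - num_updates
    eta = remaining_updates * time_per_update
    # Each remaining evaluation adds num_snapshot_iterations iterations,
    # assumed to cost roughly the same per iteration as a training update.
    remaining_evals = remaining_updates / eval_interval
    eta += remaining_evals * num_snapshot_iterations * time_per_update
    return eta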
Example #2
    def on_test_end(self, **kwargs):
        prefix = "{}: full {}".format(
            kwargs["report"].dataset_name, kwargs["report"].dataset_type
        )
        summarize_report(
            current_iteration=self.trainer.current_iteration,
            num_updates=self.trainer.num_updates,
            max_updates=self.trainer.max_updates,
            meter=kwargs["meter"],
            should_print=prefix,
            tb_writer=self.tb_writer,
        )
        logger.info(f"Finished run in {self.total_timer.get_time_since_start()}")
Example #3
    def on_update_end(self, **kwargs):
        if not kwargs["should_log"]:
            return
        extra = {}
        if "cuda" in str(self.trainer.device):
            # max_memory_allocated() reports bytes; convert the peak to MiB.
            extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
            extra["max mem"] //= 1024

        if self.training_config.experiment_name:
            extra["experiment"] = self.training_config.experiment_name

        max_updates = getattr(self.trainer, "max_updates", None)
        num_updates = getattr(self.trainer, "num_updates", None)
        extra.update({
            "epoch": self.trainer.current_epoch,
            "num_updates": num_updates,
            "iterations": self.trainer.current_iteration,
            "max_updates": max_updates,
            "lr": "{:.5f}".format(
                self.trainer.optimizer.param_groups[0]["lr"]
            ).rstrip("0"),
            "ups": "{:.2f}".format(
                self.log_interval / self.train_timer.unix_time_since_start()
            ),
            "time": self.train_timer.get_time_since_start(),
            "time_since_start": self.total_timer.get_time_since_start(),
            "eta": calculate_time_left(
                max_updates=max_updates,
                num_updates=num_updates,
                timer=self.train_timer,
                num_snapshot_iterations=self.snapshot_iterations,
                log_interval=self.log_interval,
                eval_interval=self.evaluation_interval,
            ),
        })
        self.train_timer.reset()
        summarize_report(
            current_iteration=self.trainer.current_iteration,
            num_updates=num_updates,
            max_updates=max_updates,
            meter=kwargs["meter"],
            extra=extra,
            tb_writer=self.tb_writer,
        )
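
The "ups" (updates per second) field divides the log interval by the seconds the train timer has accumulated, and the self.train_timer.reset() at the end of the method is what keeps that rate per logging window rather than cumulative. The Timer class itself is not shown; a minimal stand-in with just the three methods these callbacks call (an assumption-based sketch):

import time


class Timer:
    # Hypothetical minimal timer matching the calls seen in these examples.

    def __init__(self):
        self.start = time.time()

    def reset(self):
        self.start = time.time()

    def unix_time_since_start(self) -> float:
        # Raw elapsed seconds, used for rate math like "ups".
        return time.time() - self.start

    def get_time_since_start(self) -> str:
        # Human-readable elapsed time, used for the "time" fields.
        seconds = int(self.unix_time_since_start())
        hours, rem = divmod(seconds, 3600)
        minutes, secs = divmod(rem, 60)
        return f"{hours:02d}:{minutes:02d}:{secs:02d}"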
Example #4
    def on_validation_end(self, **kwargs):
        extra = {
            "num_updates": self.trainer.num_updates,
            "epoch": self.trainer.current_epoch,
            "iterations": self.trainer.current_iteration,
            "max_updates": self.trainer.max_updates,
            "val_time": self.snapshot_timer.get_time_since_start(),
        }
        extra.update(self.trainer.early_stop_callback.early_stopping.get_info())
        self.train_timer.reset()
        summarize_report(
            current_iteration=self.trainer.current_iteration,
            num_updates=self.trainer.num_updates,
            max_updates=self.trainer.max_updates,
            meter=kwargs["meter"],
            extra=extra,
            tb_writer=self.tb_writer,
        )
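
Here the early-stopping state is folded into extra via get_info(), so the validation log line carries the best result seen so far; the TODO in Example #5 shows the Lightning port had not wired this up yet. For illustration only, a hypothetical get_info() with a plausible shape:

class EarlyStopping:
    # Hypothetical sketch: only the get_info() side used in Example #4.

    def __init__(self, monitored_metric: str = "val/total_loss"):
        self.monitored_metric = monitored_metric
        self.best_monitored_value = float("inf")
        self.best_monitored_update = 0

    def get_info(self) -> dict:
        # Merged into `extra`, so these keys ride along with the
        # validation log line produced by summarize_report.
        return {
            "best update": self.best_monitored_update,
            f"best {self.monitored_metric}": self.best_monitored_value,
        }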
Example #5
    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
        iterations = self._get_iterations_for_logging(trainer)
        current_epochs = self._get_current_epoch_for_logging(trainer)
        num_updates = self._get_num_updates_for_logging(trainer)
        extra = {
            "num_updates": num_updates,
            "epoch": current_epochs,
            "iterations": iterations,
            "max_updates": trainer.max_steps,
            "val_time": self.snapshot_timer.get_time_since_start(),
        }
        # TODO: @sash populate early stop info for logging (next mvp)
        # extra.update(self.trainer.early_stop_callback.early_stopping.get_info())
        self.train_timer.reset()
        summarize_report(
            current_iteration=iterations,
            num_updates=num_updates,
            max_updates=trainer.max_steps,
            meter=pl_module.val_meter,
            extra=extra,
            tb_writer=self.lightning_trainer.tb_writer,
        )
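
Example #5 is the Lightning port of Example #4: trainer.max_updates becomes trainer.max_steps, the validation meter moves onto the module as pl_module.val_meter, and the counters are read through small helper getters not shown on this page. A plausible sketch of those helpers, assuming they map onto Lightning's own counters (attribute names are assumptions and vary across Lightning versions):

    def _get_num_updates_for_logging(self, trainer) -> int:
        # Hypothetical: Lightning's global_step counts optimizer steps,
        # which plays the role of num_updates in the log line.
        return trainer.global_step

    def _get_current_epoch_for_logging(self, trainer) -> int:
        # Hypothetical: Lightning counts epochs from 0; the log line is
        # 1-based.
        return trainer.current_epoch + 1

    def _get_iterations_for_logging(self, trainer) -> int:
        # Hypothetical and version-dependent: batches seen so far, which can
        # exceed global_step when gradient accumulation is enabled.
        return trainer.fit_loop.total_batch_idx + 1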