Example 1
    def _train_log(self, trainer: Trainer, pl_module: LightningModule):
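        # Compute metrics on the combined training report when metric evaluation is enabled.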
        if self.training_config.evaluate_metrics:
            self.train_combined_report.metrics = pl_module.metrics(
                self.train_combined_report, self.train_combined_report)

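        # Fold the latest combined report into the running training meter.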
        pl_module.train_meter.update_from_report(self.train_combined_report)

        extra = {}
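        # Record peak GPU memory (in MiB) when the model lives on a CUDA device.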
        if "cuda" in str(trainer.model.device):
            extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
            extra["max mem"] //= 1024

        if self.training_config.experiment_name:
            extra["experiment"] = self.training_config.experiment_name

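        # Pull the optimizer and the update/iteration counters used in the log entry below.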
        optimizer = self.get_optimizer(trainer)
        num_updates = self._get_num_updates_for_logging(trainer)
        current_iteration = self._get_iterations_for_logging(trainer)
        extra.update({
            "epoch": self._get_current_epoch_for_logging(trainer),
            "iterations": current_iteration,
            "num_updates": num_updates,
            "max_updates": trainer.max_steps,
            "lr": "{:.5f}".format(optimizer.param_groups[0]["lr"]).rstrip("0"),
            "ups": "{:.2f}".format(
                self.trainer_config.log_every_n_steps
                / self.train_timer.unix_time_since_start()),
            "time": self.train_timer.get_time_since_start(),
            "time_since_start": self.total_timer.get_time_since_start(),
            "eta": calculate_time_left(
                max_updates=trainer.max_steps,
                num_updates=num_updates,
                timer=self.train_timer,
                num_snapshot_iterations=self.snapshot_iterations,
                log_interval=self.trainer_config.log_every_n_steps,
                eval_interval=self.trainer_config.val_check_interval,
            ),
        })
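        # Reset the interval timer so the next "ups"/"eta" values cover only the upcoming logging window.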
        self.train_timer.reset()
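        # Hand the meter and extra fields to summarize_report together with the TensorBoard writer.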
        summarize_report(
            current_iteration=current_iteration,
            num_updates=num_updates,
            max_updates=trainer.max_steps,
            meter=pl_module.train_meter,
            extra=extra,
            tb_writer=self.lightning_trainer.tb_writer,
        )

Example 2
    def on_update_end(self, **kwargs):
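        # Only log when the trainer marks this update as a logging step.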
        if not kwargs["should_log"]:
            return
        extra = {}
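        # Record peak GPU memory (in MiB) when training on a CUDA device.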
        if "cuda" in str(self.trainer.device):
            extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
            extra["max mem"] //= 1024

        if self.training_config.experiment_name:
            extra["experiment"] = self.training_config.experiment_name

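        # max_updates/num_updates may not be set on the trainer, so fall back to None.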
        max_updates = getattr(self.trainer, "max_updates", None)
        num_updates = getattr(self.trainer, "num_updates", None)
        extra.update({
            "epoch": self.trainer.current_epoch,
            "num_updates": num_updates,
            "iterations": self.trainer.current_iteration,
            "max_updates": max_updates,
            "lr": "{:.5f}".format(
                self.trainer.optimizer.param_groups[0]["lr"]).rstrip("0"),
            "ups": "{:.2f}".format(
                self.log_interval / self.train_timer.unix_time_since_start()),
            "time": self.train_timer.get_time_since_start(),
            "time_since_start": self.total_timer.get_time_since_start(),
            "eta": calculate_time_left(
                max_updates=max_updates,
                num_updates=num_updates,
                timer=self.train_timer,
                num_snapshot_iterations=self.snapshot_iterations,
                log_interval=self.log_interval,
                eval_interval=self.evaluation_interval,
            ),
        })
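        # Reset the interval timer so the next "ups"/"eta" values cover only the upcoming logging window.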
        self.train_timer.reset()
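        # Hand the meter passed in via kwargs and the extra fields to summarize_report.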
        summarize_report(
            current_iteration=self.trainer.current_iteration,
            num_updates=num_updates,
            max_updates=max_updates,
            meter=kwargs["meter"],
            extra=extra,
            tb_writer=self.tb_writer,
        )