class RichProgressBar(ProgressBarBase):
    """Create a progress bar with `rich text formatting <https://github.com/willmcgugan/rich>`_.

    Install it with pip:

    .. code-block:: bash

        pip install rich

    .. code-block:: python

        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import RichProgressBar

        trainer = Trainer(callbacks=RichProgressBar())

    Args:
        refresh_rate: Determines at which rate (in number of batches) the progress bars get updated.
            Set it to ``0`` to disable the display.
        leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False
        theme: Contains styles used to stylize the progress bar. When omitted (or ``None``), a new
            ``RichProgressBarTheme()`` is created for this instance.
        console_kwargs: Args for constructing a `Console`

    Raises:
        ModuleNotFoundError:
            If required `rich` package is not installed on the device.

    Note:
        PyCharm users will need to enable “emulate terminal” in output console option in
        run/debug configuration to see styled output.
        Reference: https://rich.readthedocs.io/en/latest/introduction.html#requirements
    """

    def __init__(
        self,
        refresh_rate: int = 1,
        leave: bool = False,
        theme: Optional[RichProgressBarTheme] = None,
        console_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        if not _RICH_AVAILABLE:
            raise ModuleNotFoundError(
                "`RichProgressBar` requires `rich` >= 10.2.2. Install it by running `pip install -U rich`."
            )
        super().__init__()
        self._refresh_rate: int = refresh_rate
        self._leave: bool = leave
        self._console_kwargs = console_kwargs or {}
        self._enabled: bool = True
        self.progress: Optional[Progress] = None
        self.val_sanity_progress_bar_id: Optional[int] = None
        self._reset_progress_bar_ids()
        self._metric_component = None
        self._progress_stopped: bool = False
        # Build the default theme per instance: a theme instance used as a default argument would
        # be shared by every progress bar, and ``_update_for_light_colab_theme`` mutates it in place.
        self.theme = theme if theme is not None else RichProgressBarTheme()
        self._update_for_light_colab_theme()

    @property
    def refresh_rate(self) -> int:
        """Number of batches between two progress bar updates (``0`` disables the display)."""
        return self._refresh_rate

    @property
    def is_enabled(self) -> bool:
        """``True`` when the bar has not been disabled and ``refresh_rate`` is positive."""
        return self._enabled and self.refresh_rate > 0

    @property
    def is_disabled(self) -> bool:
        return not self.is_enabled

    def _update_for_light_colab_theme(self) -> None:
        """Replace ``"white"`` styles with ``"black"`` when a light Colab theme is detected."""
        if _detect_light_colab_theme():
            attributes = ["description", "batch_progress", "metrics"]
            for attr in attributes:
                if getattr(self.theme, attr) == "white":
                    setattr(self.theme, attr, "black")

    def disable(self) -> None:
        self._enabled = False

    def enable(self) -> None:
        self._enabled = True

    def _init_progress(self, trainer):
        """Create and start a new ``CustomProgress`` when enabled and none is currently running.

        A new progress is built on first use and again after ``_stop_progress`` has been called;
        all stored task ids are reset because they belong to the previous progress instance.
        """
        if self.is_enabled and (self.progress is None or self._progress_stopped):
            self._reset_progress_bar_ids()
            self._console = Console(**self._console_kwargs)
            self._console.clear_live()
            self._metric_component = MetricsTextColumn(trainer, self.theme.metrics)
            self.progress = CustomProgress(
                *self.configure_columns(trainer),
                self._metric_component,
                auto_refresh=False,
                disable=self.is_disabled,
                console=self._console,
            )
            self.progress.start()
            # progress has started
            self._progress_stopped = False

    def refresh(self) -> None:
        """Redraw the progress display if one exists (``auto_refresh`` is off)."""
        if self.progress:
            self.progress.refresh()

    def on_train_start(self, trainer, pl_module):
        self._init_progress(trainer)

    def on_predict_start(self, trainer, pl_module):
        self._init_progress(trainer)

    def on_test_start(self, trainer, pl_module):
        self._init_progress(trainer)

    def on_validation_start(self, trainer, pl_module):
        self._init_progress(trainer)

    def on_sanity_check_start(self, trainer, pl_module):
        self._init_progress(trainer)

    def on_sanity_check_end(self, trainer, pl_module):
        """Hide the sanity-check task, if one was created."""
        # The sanity task is only added in ``on_validation_batch_start``; if no sanity batch ran,
        # its id is still ``None`` and ``Progress.update`` would fail looking it up.
        if self.progress is not None and self.val_sanity_progress_bar_id is not None:
            self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False)
        self.refresh()

    def on_train_epoch_start(self, trainer, pl_module):
        """Create (or reset) the main task, sized for training plus interleaved validation batches."""
        total_train_batches = self.total_train_batches
        total_val_batches = self.total_val_batches
        if total_train_batches != float("inf"):
            # val can be checked multiple times per epoch
            val_checks_per_epoch = total_train_batches // trainer.val_check_batch
            total_val_batches = total_val_batches * val_checks_per_epoch

        total_batches = total_train_batches + total_val_batches

        train_description = self._get_train_description(trainer.current_epoch)
        if self.main_progress_bar_id is not None and self._leave:
            # Keep the finished bar on screen by stopping this progress and starting a new one.
            self._stop_progress()
            self._init_progress(trainer)
        if self.main_progress_bar_id is None:
            self.main_progress_bar_id = self._add_task(total_batches, train_description)
        elif self.progress is not None:
            self.progress.reset(
                self.main_progress_bar_id, total=total_batches, description=train_description, visible=True
            )
        self.refresh()

    def on_validation_batch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
    ) -> None:
        """When a new validation dataloader starts, hide the previous task and add a new one."""
        if not self.has_dataloader_changed(dataloader_idx):
            return

        if trainer.sanity_checking:
            if self.val_sanity_progress_bar_id is not None:
                self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False)

            self.val_sanity_progress_bar_id = self._add_task(
                self.total_val_batches_current_dataloader, self.sanity_check_description, visible=False
            )
        else:
            if self.val_progress_bar_id is not None:
                self.progress.update(self.val_progress_bar_id, advance=0, visible=False)

            # TODO: remove old tasks when new onces are created
            self.val_progress_bar_id = self._add_task(
                self.total_val_batches_current_dataloader, self.validation_description, visible=False
            )

        self.refresh()

    def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
        """Add a task styled with the theme's description style; returns ``None`` without a progress."""
        if self.progress is not None:
            return self.progress.add_task(
                f"[{self.theme.description}]{description}", total=total_batches, visible=visible
            )

    def _update(self, progress_bar_id: int, current: int, total: Union[int, float], visible: bool = True) -> None:
        """Advance a task by ``refresh_rate`` steps (or the remainder on the last batch)."""
        if self.progress is not None and self._should_update(current, total):
            leftover = current % self.refresh_rate
            # On the final batch only the steps not yet counted are added.
            advance = leftover if (current == total and leftover != 0) else self.refresh_rate
            self.progress.update(progress_bar_id, advance=advance, visible=visible)
            self.refresh()

    def _should_update(self, current: int, total: Union[int, float]) -> bool:
        return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

    def on_validation_epoch_end(self, trainer, pl_module):
        if self.val_progress_bar_id is not None and trainer.state.fn == "fit":
            self.progress.update(self.val_progress_bar_id, advance=0, visible=False)
            self.refresh()

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if trainer.state.fn == "fit":
            self._update_metrics(trainer, pl_module)
        self.reset_dataloader_idx_tracker()

    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.reset_dataloader_idx_tracker()

    def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.reset_dataloader_idx_tracker()

    def on_test_batch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
    ) -> None:
        """When a new test dataloader starts, hide the previous test task and add a new one."""
        if not self.has_dataloader_changed(dataloader_idx):
            return

        if self.test_progress_bar_id is not None:
            self.progress.update(self.test_progress_bar_id, advance=0, visible=False)
        self.test_progress_bar_id = self._add_task(self.total_test_batches_current_dataloader, self.test_description)
        self.refresh()

    def on_predict_batch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
    ) -> None:
        """When a new predict dataloader starts, hide the previous predict task and add a new one."""
        if not self.has_dataloader_changed(dataloader_idx):
            return

        if self.predict_progress_bar_id is not None:
            self.progress.update(self.predict_progress_bar_id, advance=0, visible=False)
        self.predict_progress_bar_id = self._add_task(
            self.total_predict_batches_current_dataloader, self.predict_description
        )
        self.refresh()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches)
        self._update_metrics(trainer, pl_module)
        self.refresh()

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self._update_metrics(trainer, pl_module)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if trainer.sanity_checking:
            self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader)
        elif self.val_progress_bar_id is not None:
            # check to see if we should update the main training progress bar
            if self.main_progress_bar_id is not None:
                # TODO: Use total val_processed here just like TQDM in a follow-up
                self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader)
            self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader)
        self.refresh()

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches_current_dataloader)
        self.refresh()

    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self._update(self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches_current_dataloader)
        self.refresh()

    def _get_train_description(self, current_epoch: int) -> str:
        """Return ``"Epoch N"`` padded with trailing spaces when shorter than the validation label."""
        train_description = f"Epoch {current_epoch}"
        if len(self.validation_description) > len(train_description):
            # Padding is required to avoid flickering due of uneven lengths of "Epoch X"
            # and "Validation" Bar description
            num_digits = len(str(current_epoch))
            required_padding = (len(self.validation_description) - len(train_description) + 1) - num_digits
            for _ in range(required_padding):
                train_description += " "
        return train_description

    def _stop_progress(self) -> None:
        """Stop the live display and mark it so the next ``_init_progress`` builds a new one."""
        if self.progress is not None:
            self.progress.stop()
            # # signals for progress to be re-initialized for next stages
            self._progress_stopped = True

    def _reset_progress_bar_ids(self):
        self.main_progress_bar_id: Optional[int] = None
        self.val_progress_bar_id: Optional[int] = None
        self.test_progress_bar_id: Optional[int] = None
        self.predict_progress_bar_id: Optional[int] = None

    def _update_metrics(self, trainer, pl_module) -> None:
        """Push the current ``get_metrics`` values into the metrics column, if it exists."""
        metrics = self.get_metrics(trainer, pl_module)
        if self._metric_component:
            self._metric_component.update(metrics)

    def teardown(self, trainer, pl_module, stage: Optional[str] = None) -> None:
        self._stop_progress()

    def on_exception(self, trainer, pl_module, exception: BaseException) -> None:
        self._stop_progress()

    @property
    def val_progress_bar(self) -> Task:
        return self.progress.tasks[self.val_progress_bar_id]

    @property
    def val_sanity_check_bar(self) -> Task:
        return self.progress.tasks[self.val_sanity_progress_bar_id]

    @property
    def main_progress_bar(self) -> Task:
        return self.progress.tasks[self.main_progress_bar_id]

    @property
    def test_progress_bar(self) -> Task:
        return self.progress.tasks[self.test_progress_bar_id]

    def configure_columns(self, trainer) -> list:
        """Return the rich columns (description, bar, batch count, time, speed) for the progress."""
        return [
            TextColumn("[progress.description]{task.description}"),
            CustomBarColumn(
                complete_style=self.theme.progress_bar,
                finished_style=self.theme.progress_bar_finished,
                pulse_style=self.theme.progress_bar_pulse,
            ),
            BatchesProcessedColumn(style=self.theme.batch_progress),
            CustomTimeColumn(style=self.theme.time),
            ProcessingSpeedColumn(style=self.theme.processing_speed),
        ]
class KrakenTrainProgressBar(ProgressBarBase):
    """
    Adaptation of the default ptl rich progress bar to fit with kraken (segtrain, train) output.

    Args:
        refresh_rate: Determines at which rate (in number of batches) the progress bars get updated.
            Set it to ``0`` to disable the display.
        leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: True
        console_kwargs: Args for constructing a `Console`
    """
    def __init__(self, refresh_rate: int = 1, leave: bool = True, console_kwargs: Optional[Dict[str, Any]] = None) -> None:
        super().__init__()
        self._refresh_rate: int = refresh_rate
        self._leave: bool = leave
        self._console_kwargs = console_kwargs or {}
        self._enabled: bool = True
        self.progress: Optional[Progress] = None
        self.val_sanity_progress_bar_id: Optional[int] = None
        self._reset_progress_bar_ids()
        self._metric_component = None
        # Set by ``_init_progress`` only when the trainer has an early stopping callback;
        # initialised here so the attribute always exists.
        self._early_stopping_component = None
        self._progress_stopped: bool = False

    @property
    def refresh_rate(self) -> int:
        """Number of batches between two progress bar updates (``0`` disables the display)."""
        return self._refresh_rate

    @property
    def is_enabled(self) -> bool:
        """``True`` when the bar has not been disabled and ``refresh_rate`` is positive."""
        return self._enabled and self.refresh_rate > 0

    @property
    def is_disabled(self) -> bool:
        return not self.is_enabled

    def disable(self) -> None:
        self._enabled = False

    def enable(self) -> None:
        self._enabled = True

    @property
    def sanity_check_description(self) -> str:
        return "Validation Sanity Check"

    @property
    def validation_description(self) -> str:
        return "Validation"

    @property
    def test_description(self) -> str:
        return "Testing"

    def _init_progress(self, trainer):
        """Create and start a new rich ``Progress`` when enabled and none is currently running.

        Columns are ``configure_columns`` plus a metrics column, and an early-stopping column when
        ``trainer.early_stopping_callback`` is set. Stored task ids are reset because they belong
        to the previous progress instance.
        """
        if self.is_enabled and (self.progress is None or self._progress_stopped):
            self._reset_progress_bar_ids()
            self._console = Console(**self._console_kwargs)
            self._console.clear_live()
            columns = self.configure_columns(trainer)
            self._metric_component = MetricsTextColumn(trainer)
            columns.append(self._metric_component)
            if trainer.early_stopping_callback:
                self._early_stopping_component = EarlyStoppingColumn(trainer)
                columns.append(self._early_stopping_component)
            self.progress = Progress(*columns,
                                     auto_refresh=False,
                                     disable=self.is_disabled,
                                     console=self._console)
            self.progress.start()
            # progress has started
            self._progress_stopped = False

    def refresh(self) -> None:
        """Redraw the progress display if one exists (``auto_refresh`` is off)."""
        if self.progress:
            self.progress.refresh()

    def on_train_start(self, trainer, pl_module):
        self._init_progress(trainer)

    def on_test_start(self, trainer, pl_module):
        self._init_progress(trainer)

    def on_validation_start(self, trainer, pl_module):
        self._init_progress(trainer)

    def on_sanity_check_start(self, trainer, pl_module):
        self._init_progress(trainer)

    def on_sanity_check_end(self, trainer, pl_module):
        """Hide the sanity-check task, if one was created."""
        # Guard against a missing task id: ``Progress.update`` fails when looking up ``None``.
        if self.progress is not None and self.val_sanity_progress_bar_id is not None:
            self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False)
        self.refresh()

    def _get_train_description(self, trainer, pl_module) -> str:
        """Return the ``"stage N/M"`` label for the main task, padded to avoid flicker.

        The upper bound is ``trainer.max_epochs`` only when ``pl_module.hparams.quit == 'dumb'``,
        otherwise ``'∞'`` is shown.
        """
        train_description = f"stage {trainer.current_epoch}/{trainer.max_epochs if pl_module.hparams.quit == 'dumb' else '∞'}"
        if len(self.validation_description) > len(train_description):
            # Padding is required to avoid flickering due of uneven lengths of "Epoch X"
            # and "Validation" Bar description
            num_digits = len(str(trainer.current_epoch))
            required_padding = (len(self.validation_description) - len(train_description) + 1) - num_digits
            for _ in range(required_padding):
                train_description += " "
        return train_description

    def on_train_epoch_start(self, trainer, pl_module):
        """Create (or reset) the main task, sized for training plus interleaved validation batches."""
        total_train_batches = self.total_train_batches
        total_val_batches = self.total_val_batches
        if total_train_batches != float("inf"):
            # val can be checked multiple times per epoch
            val_checks_per_epoch = total_train_batches // trainer.val_check_batch
            total_val_batches = total_val_batches * val_checks_per_epoch

        total_batches = total_train_batches + total_val_batches

        train_description = self._get_train_description(trainer, pl_module)

        if self.main_progress_bar_id is not None and self._leave:
            # Keep the finished bar on screen by stopping this progress and starting a new one.
            self._stop_progress()
            self._init_progress(trainer)
        if self.main_progress_bar_id is None:
            self.main_progress_bar_id = self._add_task(total_batches, train_description)
        elif self.progress is not None:
            self.progress.reset(self.main_progress_bar_id,
                                total=total_batches,
                                description=train_description,
                                visible=True)
        self.refresh()

    def on_validation_epoch_start(self, trainer, pl_module):
        """Add a sanity-check task (visible) or a validation task (initially hidden)."""
        if trainer.sanity_checking:
            self.val_sanity_progress_bar_id = self._add_task(self.total_val_batches, self.sanity_check_description)
        else:
            self.val_progress_bar_id = self._add_task(self.total_val_batches,
                                                      self.validation_description,
                                                      visible=False)
        self.refresh()

    def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
        """Add a task to the progress; returns ``None`` when no progress exists."""
        if self.progress is not None:
            return self.progress.add_task(f"{description}", total=total_batches, visible=visible)

    def _update(self, progress_bar_id: int, current: int, total: Union[int, float], visible: bool = True) -> None:
        """Advance a task by ``refresh_rate`` steps (or the remainder on the last batch)."""
        if self.progress is not None and self._should_update(current, total):
            leftover = current % self.refresh_rate
            # On the final batch only the steps not yet counted are added.
            advance = leftover if (current == total and leftover != 0) else self.refresh_rate
            self.progress.update(progress_bar_id, advance=advance, visible=visible)
            self.refresh()

    def _should_update(self, current: int, total: Union[int, float]) -> bool:
        return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

    def on_validation_epoch_end(self, trainer, pl_module):
        if self.val_progress_bar_id is not None and trainer.state.fn == "fit":
            self.progress.update(self.val_progress_bar_id, advance=0, visible=False)
            self.refresh()

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if trainer.state.fn == "fit":
            self._update_metrics(trainer, pl_module)

    def on_test_epoch_start(self, trainer, pl_module):
        self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description)
        self.refresh()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches)
        self._update_metrics(trainer, pl_module)
        self.refresh()

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self._update_metrics(trainer, pl_module)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if trainer.sanity_checking:
            self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches)
        elif self.val_progress_bar_id is not None:
            # check to see if we should update the main training progress bar
            if self.main_progress_bar_id is not None:
                self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches)
            self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches)
        self.refresh()

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches)
        self.refresh()

    def _stop_progress(self) -> None:
        """Stop the live display and mark it so the next ``_init_progress`` builds a new one."""
        if self.progress is not None:
            self.progress.stop()
            # # signals for progress to be re-initialized for next stages
            self._progress_stopped = True

    def _reset_progress_bar_ids(self):
        self.main_progress_bar_id: Optional[int] = None
        self.val_progress_bar_id: Optional[int] = None
        self.test_progress_bar_id: Optional[int] = None

    def _update_metrics(self, trainer, pl_module) -> None:
        """Push current metrics, minus ``loss`` and ``val_metric``, into the metrics column."""
        metrics = self.get_metrics(trainer, pl_module)
        metrics.pop('loss', None)
        metrics.pop('val_metric', None)
        if self._metric_component:
            self._metric_component.update(metrics)

    def teardown(self, trainer, pl_module, stage: Optional[str] = None) -> None:
        self._stop_progress()

    def on_exception(self, trainer, pl_module, exception: BaseException) -> None:
        self._stop_progress()

    @property
    def val_progress_bar(self) -> Task:
        return self.progress.tasks[self.val_progress_bar_id]

    @property
    def val_sanity_check_bar(self) -> Task:
        return self.progress.tasks[self.val_sanity_progress_bar_id]

    @property
    def main_progress_bar(self) -> Task:
        return self.progress.tasks[self.main_progress_bar_id]

    @property
    def test_progress_bar(self) -> Task:
        return self.progress.tasks[self.test_progress_bar_id]

    def configure_columns(self, trainer) -> list:
        """Return the base rich columns: description, bar, batch count, remaining and elapsed time."""
        return [
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            BatchesProcessedColumn(),
            TimeRemainingColumn(),
            TimeElapsedColumn()
        ]