def test_reset() -> None:
    """Check that ``Progress.reset`` overwrites a task's state in place."""
    progress = Progress()
    task_id = progress.add_task("foo")
    # Several advances accumulate into ``completed`` before the reset.
    progress.advance(task_id, 1)
    progress.advance(task_id, 1)
    progress.advance(task_id, 1)
    progress.advance(task_id, 7)
    task = progress.tasks[task_id]
    assert task.completed == 10

    progress.reset(
        task_id,
        total=200,
        completed=20,
        visible=False,
        description="bar",
        example="egg",
    )
    # The same Task object is mutated, so the earlier reference sees the update.
    assert task.total == 200
    assert task.completed == 20
    assert task.visible is False
    assert task.description == "bar"
    # Extra keyword arguments to reset() end up in ``task.fields``.
    assert task.fields == {"example": "egg"}
    # The private ``_progress`` history is expected to be empty after a reset.
    assert not task._progress
class TrafficProgress:
    """Rich dashboard: an info panel, a profile summary and three progress bars.

    ``progressTable`` is a one-row grid of three panels. The last panel
    renders ``progressTotal``, which owns three tasks created in this order:

    * ``taskTotal`` -- ``numStat * numRepos`` steps overall,
    * ``taskRepo``  -- ``numRepos`` steps,
    * ``taskStat``  -- ``numStat`` steps.
    """

    def __init__(
        self,
        numRepos: int,
        follower: int = 0,
        following: int = 0,
        numStat: int = 5,
    ) -> None:
        self.numStat = numStat
        self.numRepos = numRepos
        # Counts are zero-padded to three digits.
        self._profileText = Text(
            f"{follower:03d} Follower\n{following:03d} Following\n{numRepos:03d} Public Repositories"
        )

        self.progressTable = Table.grid(expand=True)
        self.progressTotal = Progress(
            "{task.description}",
            SpinnerColumn(),
            BarColumn(),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        )

        info_panel = Panel(
            Align.center(Text("Placeholder", justify="center")),
            title="[b]Info",
            border_style="red",
            padding=(1, 1),
        )
        profile_panel = Panel(
            Align.center(self._profileText),
            title="[b]Profile Info",
            border_style="yellow",
            padding=(1, 1),
        )
        progress_panel = Panel(
            self.progressTotal,  # type:ignore
            title="[b]Total Progress",
            border_style="green",
            padding=(1, 2),
        )
        self.progressTable.add_row(info_panel, profile_panel, progress_panel)

        # Creation order fixes the task ids: total, then repository, then stat.
        add = self.progressTotal.add_task
        self.taskTotal = add(description="Progress", total=numStat * numRepos)
        self.taskRepo = add(description="Repository [bold yellow]#", total=numRepos)
        self.taskStat = add(description="Stat [bold violet]#", total=numStat)

    def UpdateRepoDescription(self, repo: str):
        """Show *repo* after the repository bar's ``#`` marker."""
        label = f"Repository [bold yellow]#{repo}"
        self.progressTotal.update(self.taskRepo, description=label)

    def UpdateStatDescription(self, stat: str):
        """Show *stat* after the statistic bar's ``#`` marker."""
        label = f"Stat [bold violet]#{stat}"
        self.progressTotal.update(self.taskStat, description=label)

    def _step(self, task_id):
        """Advance *task_id* by rich's default single step."""
        self.progressTotal.advance(task_id)

    def StepTotal(self):
        """Advance the overall bar by one step."""
        self._step(self.taskTotal)

    def StepRepo(self):
        """Advance the repository bar by one step."""
        self._step(self.taskRepo)

    def StepStat(self):
        """Advance the statistic bar by one step."""
        self._step(self.taskStat)

    def ResetStatProgress(self):
        """Reset the statistic bar using rich's ``reset`` defaults."""
        self.progressTotal.reset(self.taskStat)

    def CompleteStat(self):
        """Reset the statistic bar straight to full and label it completed."""
        self.progressTotal.reset(
            self.taskStat,
            completed=self.numStat,
            description="Stat [bold violet]#Completed",
        )
# Demo: an "overall" task plus a "jobs" task that is reset for every job.
from time import sleep  # used below; missing from the original imports

from rich.panel import Panel
from rich.progress import Progress

JOBS = [100, 150, 25, 70, 110, 90]

# auto_refresh=False: the display redraws only when explicitly refreshed
# (e.g. by progress.track between items), not on a background timer.
progress = Progress(auto_refresh=False)
master_task = progress.add_task("overall", total=sum(JOBS))
jobs_task = progress.add_task("jobs")

progress.console.print(
    Panel(
        "[bold blue]A demonstration of progress with a current task and overall progress.",
        padding=1,
    )
)

with progress:
    for job_no, job in enumerate(JOBS):
        progress.log(f"Starting job #{job_no}")
        sleep(0.2)
        # Reuse one task for every job: give it this job's size and label.
        progress.reset(jobs_task, total=job, description=f"job [bold yellow]#{job_no}")
        progress.start_task(jobs_task)
        for _ in progress.track(range(job), task_id=jobs_task):
            sleep(0.01)
        progress.advance(master_task, job)
        progress.log(f"Job #{job_no} is complete")
    progress.log(
        Panel(":sparkle: All done! :sparkle:", border_style="green", padding=1)
    )
class tqdm_rich(std_tqdm):  # pragma: no cover
    """Experimental rich.progress GUI version of tqdm!"""

    # TODO: @classmethod: write()?
    def __init__(self, *args, **kwargs):
        """
        This class accepts the following parameters *in addition* to
        the parameters accepted by `tqdm`.

        Parameters
        ----------
        progress  : tuple, optional
            arguments for `rich.progress.Progress()`.
        """
        kwargs = kwargs.copy()
        kwargs['gui'] = True
        # convert disable = None to False
        kwargs['disable'] = bool(kwargs.get('disable', False))
        progress = kwargs.pop('progress', None)
        super(tqdm_rich, self).__init__(*args, **kwargs)

        # A disabled bar never creates ``self._prog``; the other methods
        # guard on that.
        if self.disable:
            return

        warn("rich is experimental/alpha", TqdmExperimentalWarning, stacklevel=2)
        d = self.format_dict
        if progress is None:
            progress = (
                "[progress.description]{task.description}"
                "[progress.percentage]{task.percentage:>4.0f}%",
                BarColumn(bar_width=None),
                FractionColumn(
                    unit_scale=d['unit_scale'], unit_divisor=d['unit_divisor']),
                "[", TimeElapsedColumn(), "<", TimeRemainingColumn(),
                ",", RateColumn(unit=d['unit'], unit_scale=d['unit_scale'],
                                unit_divisor=d['unit_divisor']), "]"
            )
        self._prog = Progress(*progress, transient=not self.leave)
        self._prog.__enter__()
        self._task_id = self._prog.add_task(self.desc or "", **d)

    def close(self, *args, **kwargs):
        """Finish the bar and tear down the rich live display."""
        if self.disable:
            return
        # Push the final counter to the rich task before closing. In GUI mode
        # the base class may not have displayed the last update, which would
        # leave the bar short of its true count.
        self.display()
        super(tqdm_rich, self).close(*args, **kwargs)
        self._prog.__exit__(None, None, None)

    def clear(self, *_, **__):
        """No-op: rich manages its own screen area."""
        pass

    def display(self, *_, **__):
        """Copy tqdm's counter and description into the rich task."""
        if not hasattr(self, '_prog'):
            return
        self._prog.update(self._task_id, completed=self.n, description=self.desc)

    def reset(self, total=None):
        """
        Resets to 0 iterations for repeated use.

        Parameters
        ----------
        total  : int or float, optional. Total to use for the new bar.
        """
        if hasattr(self, '_prog'):
            # rich's Progress.reset takes the task id as its first, required
            # positional argument; omitting it raised TypeError.
            self._prog.reset(self._task_id, total=total)
        super(tqdm_rich, self).reset(total=total)
class KrakenTrainProgressBar(ProgressBarBase):
    """
    Adaptation of the default ptl rich progress bar to fit with kraken
    (segtrain, train) output.

    Args:
        refresh_rate: Determines at which rate (in number of batches) the
            progress bars get updated. Set it to ``0`` to disable the display.
        leave: Leaves the finished progress bar in the terminal at the end of
            the epoch. Default: True
        console_kwargs: Args for constructing a `Console`
    """

    def __init__(self,
                 refresh_rate: int = 1,
                 leave: bool = True,
                 console_kwargs: Optional[Dict[str, Any]] = None) -> None:
        super().__init__()
        self._refresh_rate: int = refresh_rate
        self._leave: bool = leave
        self._console_kwargs = console_kwargs or {}
        self._enabled: bool = True
        self.progress: Optional[Progress] = None
        self.val_sanity_progress_bar_id: Optional[int] = None
        self._reset_progress_bar_ids()
        self._metric_component = None
        # True after _stop_progress(); makes _init_progress build a fresh Progress.
        self._progress_stopped: bool = False

    @property
    def refresh_rate(self) -> int:
        return self._refresh_rate

    @property
    def is_enabled(self) -> bool:
        return self._enabled and self.refresh_rate > 0

    @property
    def is_disabled(self) -> bool:
        return not self.is_enabled

    def disable(self) -> None:
        self._enabled = False

    def enable(self) -> None:
        self._enabled = True

    @property
    def sanity_check_description(self) -> str:
        return "Validation Sanity Check"

    @property
    def validation_description(self) -> str:
        return "Validation"

    @property
    def test_description(self) -> str:
        return "Testing"

    def _init_progress(self, trainer):
        """Create and start a new rich ``Progress`` if none is running.

        Does nothing when the bar is disabled or a live Progress already
        exists. Otherwise all bar ids are cleared, a new Console is built from
        ``console_kwargs``, and the configured columns are extended with a
        metrics column (plus an early-stopping column if the trainer has an
        early-stopping callback).
        """
        if self.is_enabled and (self.progress is None or self._progress_stopped):
            self._reset_progress_bar_ids()
            self._console = Console(**self._console_kwargs)
            self._console.clear_live()
            columns = self.configure_columns(trainer)
            self._metric_component = MetricsTextColumn(trainer)
            columns.append(self._metric_component)

            if trainer.early_stopping_callback:
                self._early_stopping_component = EarlyStoppingColumn(trainer)
                columns.append(self._early_stopping_component)

            self.progress = Progress(*columns,
                                     auto_refresh=False,
                                     disable=self.is_disabled,
                                     console=self._console)
            self.progress.start()
            # progress has started
            self._progress_stopped = False

    def refresh(self) -> None:
        if self.progress:
            self.progress.refresh()

    def on_train_start(self, trainer, pl_module):
        self._init_progress(trainer)

    def on_test_start(self, trainer, pl_module):
        self._init_progress(trainer)

    def on_validation_start(self, trainer, pl_module):
        self._init_progress(trainer)

    def on_sanity_check_start(self, trainer, pl_module):
        self._init_progress(trainer)

    def on_sanity_check_end(self, trainer, pl_module):
        if self.progress is not None:
            self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False)
        self.refresh()

    def _get_train_description(self, trainer, pl_module) -> str:
        """Return the main bar label ``stage <epoch>/<max_epochs or ∞>``, right-padded.

        ``max_epochs`` is shown only when ``pl_module.hparams.quit == 'dumb'``.
        """
        train_description = f"stage {trainer.current_epoch}/{trainer.max_epochs if pl_module.hparams.quit == 'dumb' else '∞'}"
        if len(self.validation_description) > len(train_description):
            # Padding is required to avoid flickering due to uneven lengths of
            # the "stage X/Y" and "Validation" bar descriptions.
            num_digits = len(str(trainer.current_epoch))
            required_padding = (len(self.validation_description) - len(train_description) + 1) - num_digits
            # A non-positive count yields an empty string (no padding).
            train_description += " " * required_padding
        return train_description

    def on_train_epoch_start(self, trainer, pl_module):
        total_train_batches = self.total_train_batches
        total_val_batches = self.total_val_batches
        if total_train_batches != float("inf"):
            # val can be checked multiple times per epoch
            val_checks_per_epoch = total_train_batches // trainer.val_check_batch
            total_val_batches = total_val_batches * val_checks_per_epoch

        # The main bar counts both training and validation batches.
        total_batches = total_train_batches + total_val_batches

        train_description = self._get_train_description(trainer, pl_module)

        # With leave=True the finished bar of the previous epoch is kept on
        # screen, and a new Progress is started for this epoch.
        if self.main_progress_bar_id is not None and self._leave:
            self._stop_progress()
            self._init_progress(trainer)
        if self.main_progress_bar_id is None:
            self.main_progress_bar_id = self._add_task(total_batches, train_description)
        elif self.progress is not None:
            self.progress.reset(self.main_progress_bar_id,
                                total=total_batches,
                                description=train_description,
                                visible=True)
        self.refresh()

    def on_validation_epoch_start(self, trainer, pl_module):
        if trainer.sanity_checking:
            self.val_sanity_progress_bar_id = self._add_task(
                self.total_val_batches, self.sanity_check_description)
        else:
            self.val_progress_bar_id = self._add_task(
                self.total_val_batches, self.validation_description, visible=False)
        self.refresh()

    def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
        """Add a task to the running Progress; returns None if there is none."""
        if self.progress is not None:
            return self.progress.add_task(f"{description}", total=total_batches, visible=visible)
        return None

    def _update(self, progress_bar_id: int, current: int, total: Union[int, float], visible: bool = True) -> None:
        """Advance a bar every ``refresh_rate`` batches, and on the last batch."""
        if self.progress is not None and self._should_update(current, total):
            # On the final batch, advance only by the remainder not yet counted.
            leftover = current % self.refresh_rate
            advance = leftover if (current == total and leftover != 0) else self.refresh_rate
            self.progress.update(progress_bar_id, advance=advance, visible=visible)
            self.refresh()

    def _should_update(self, current: int, total: Union[int, float]) -> bool:
        # is_enabled is checked first, so refresh_rate == 0 never reaches the modulo.
        return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

    def on_validation_epoch_end(self, trainer, pl_module):
        if (self.progress is not None
                and self.val_progress_bar_id is not None
                and trainer.state.fn == "fit"):
            self.progress.update(self.val_progress_bar_id, advance=0, visible=False)
            self.refresh()

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if trainer.state.fn == "fit":
            self._update_metrics(trainer, pl_module)

    def on_test_epoch_start(self, trainer, pl_module):
        self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description)
        self.refresh()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches)
        self._update_metrics(trainer, pl_module)
        self.refresh()

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self._update_metrics(trainer, pl_module)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if trainer.sanity_checking:
            self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches)
        elif self.val_progress_bar_id is not None:
            # check to see if we should update the main training progress bar
            if self.main_progress_bar_id is not None:
                self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches)
            self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches)
        self.refresh()

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches)
        self.refresh()

    def _stop_progress(self) -> None:
        if self.progress is not None:
            self.progress.stop()
            # signals for progress to be re-initialized for next stages
            self._progress_stopped = True

    def _reset_progress_bar_ids(self):
        # The sanity-check bar id is deliberately not cleared here.
        self.main_progress_bar_id: Optional[int] = None
        self.val_progress_bar_id: Optional[int] = None
        self.test_progress_bar_id: Optional[int] = None

    def _update_metrics(self, trainer, pl_module) -> None:
        """Push current metrics, minus 'loss' and 'val_metric', to the metrics column."""
        metrics = self.get_metrics(trainer, pl_module)
        metrics.pop('loss', None)
        metrics.pop('val_metric', None)
        if self._metric_component:
            self._metric_component.update(metrics)

    def teardown(self, trainer, pl_module, stage: Optional[str] = None) -> None:
        self._stop_progress()

    def on_exception(self, trainer, pl_module, exception: BaseException) -> None:
        self._stop_progress()

    @property
    def val_progress_bar(self) -> Task:
        return self.progress.tasks[self.val_progress_bar_id]

    @property
    def val_sanity_check_bar(self) -> Task:
        return self.progress.tasks[self.val_sanity_progress_bar_id]

    @property
    def main_progress_bar(self) -> Task:
        return self.progress.tasks[self.main_progress_bar_id]

    @property
    def test_progress_bar(self) -> Task:
        return self.progress.tasks[self.test_progress_bar_id]

    def configure_columns(self, trainer) -> list:
        """Return the base rich columns; the caller appends metric columns."""
        return [
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            BatchesProcessedColumn(),
            TimeRemainingColumn(),
            TimeElapsedColumn()
        ]