Example #1
0
def test_reset() -> None:
    """Check that ``Progress.reset`` overwrites an existing task's state."""
    progress = Progress()
    task_id = progress.add_task("foo")

    # Advance in several steps so the task has accumulated progress
    # before the reset is applied.
    for step in (1, 1, 1, 7):
        progress.advance(task_id, step)
    task = progress.tasks[task_id]
    assert task.completed == 10

    progress.reset(
        task_id,
        total=200,
        completed=20,
        visible=False,
        description="bar",
        example="egg",  # extra keyword; expected to end up in ``task.fields``
    )

    # ``task`` was fetched before the reset, so these checks also confirm
    # that the reset is visible through the previously obtained object.
    assert task.total == 200
    assert task.completed == 20
    assert task.visible is False
    assert task.description == "bar"
    assert task.fields == {"example": "egg"}
    # NOTE(review): ``_progress`` is a private attribute of the task;
    # presumably its recorded progress samples -- it must be empty here.
    assert not task._progress
Example #2
0
class TrafficProgress:
    """Grid of three panels (info, profile, total progress) plus step helpers.

    The grid is exposed as ``progressTable``. A single ``Progress`` instance
    (``progressTotal``) carries three tasks: the overall total, the current
    repository and the current stat.
    """

    def __init__(
        self,
        numRepos: int,
        follower: int = 0,
        following: int = 0,
        numStat: int = 5,
    ) -> None:
        """Build the panels and register the three progress tasks.

        Args:
            numRepos: Number of repositories; total of the repository task.
            follower: Follower count shown in the profile panel.
            following: Following count shown in the profile panel.
            numStat: Number of stats per repository; total of the stat task.
        """
        self.numStat = numStat
        self.numRepos = numRepos

        # Three zero-padded counters, one per line.
        profile_lines = (
            f"{follower:03d} Follower",
            f"{following:03d} Following",
            f"{numRepos:03d} Public Repositories",
        )
        self._profileText = Text("\n".join(profile_lines))

        self.progressTable = Table.grid(expand=True)
        self.progressTotal = Progress(
            "{task.description}",
            SpinnerColumn(),
            BarColumn(),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        )

        info_panel = Panel(
            Align.center(Text("Placeholder", justify="center")),
            title="[b]Info",
            border_style="red",
            padding=(1, 1),
        )
        profile_panel = Panel(
            Align.center(self._profileText),
            title="[b]Profile Info",
            border_style="yellow",
            padding=(1, 1),
        )
        progress_panel = Panel(
            self.progressTotal,  # type:ignore
            title="[b]Total Progress",
            border_style="green",
            padding=(1, 2),
        )
        self.progressTable.add_row(info_panel, profile_panel, progress_panel)

        # Tasks are registered in display order: total, repository, stat.
        register = self.progressTotal.add_task
        self.taskTotal = register(
            description="Progress",
            total=numStat * numRepos,
        )
        self.taskRepo = register(
            description="Repository [bold yellow]#",
            total=numRepos,
        )
        self.taskStat = register(
            description="Stat [bold violet]#",
            total=numStat,
        )

    def UpdateRepoDescription(self, repo: str):
        """Show ``repo`` in the repository task's description."""
        label = f"Repository [bold yellow]#{repo}"
        self.progressTotal.update(self.taskRepo, description=label)

    def UpdateStatDescription(self, stat: str):
        """Show ``stat`` in the stat task's description."""
        label = f"Stat [bold violet]#{stat}"
        self.progressTotal.update(self.taskStat, description=label)

    def StepTotal(self):
        """Advance the overall task by one step."""
        bar = self.progressTotal
        bar.advance(self.taskTotal)

    def StepRepo(self):
        """Advance the repository task by one step."""
        bar = self.progressTotal
        bar.advance(self.taskRepo)

    def StepStat(self):
        """Advance the stat task by one step."""
        bar = self.progressTotal
        bar.advance(self.taskStat)

    def ResetStatProgress(self):
        """Reset the stat task, keeping its other settings."""
        bar = self.progressTotal
        bar.reset(self.taskStat)

    def CompleteStat(self):
        """Reset the stat task to a fully completed state labelled "Completed"."""
        finished = {
            "description": "Stat [bold violet]#Completed",
            "completed": self.numStat,
        }
        self.progressTotal.reset(self.taskStat, **finished)
Example #3
0
# Demo: one task bar that is reused for each job, plus an overall bar.
from time import sleep

from rich.panel import Panel
from rich.progress import Progress


# Number of steps in each job; the overall bar's total is their sum.
JOBS = [100, 150, 25, 70, 110, 90]

progress = Progress(auto_refresh=False)
master_task = progress.add_task("overall", total=sum(JOBS))
jobs_task = progress.add_task("jobs")

progress.console.print(
    Panel(
        "[bold blue]A demonstration of progress with a current task and overall progress.",
        padding=1,
    )
)

with progress:
    for job_no, job in enumerate(JOBS):
        progress.log(f"Starting job #{job_no}")
        sleep(0.2)
        # Reuse the single "jobs" task: give it the new job's total and label.
        progress.reset(jobs_task, total=job, description=f"job [bold yellow]#{job_no}")
        progress.start_task(jobs_task)
        # The loop value is unused; each iteration only simulates one step.
        for _ in progress.track(range(job), task_id=jobs_task):
            sleep(0.01)
        # Credit the whole finished job to the overall bar in one step.
        progress.advance(master_task, job)
        progress.log(f"Job #{job_no} is complete")
    progress.log(
        Panel(":sparkle: All done! :sparkle:", border_style="green", padding=1)
    )
Example #4
0
class tqdm_rich(std_tqdm):  # pragma: no cover
    """Experimental rich.progress GUI version of tqdm!"""

    # TODO: @classmethod: write()?
    def __init__(self, *args, **kwargs):
        """
        This class accepts the following parameters *in addition* to
        the parameters accepted by `tqdm`.

        Parameters
        ----------
        progress  : tuple, optional
            arguments for `rich.progress.Progress()`.
        """
        kwargs = kwargs.copy()
        kwargs['gui'] = True
        # convert disable = None to False
        kwargs['disable'] = bool(kwargs.get('disable', False))
        # `progress` is consumed here and not forwarded to the parent class.
        progress = kwargs.pop('progress', None)
        super(tqdm_rich, self).__init__(*args, **kwargs)

        # A disabled bar never creates `_prog`; the other methods account
        # for that via `self.disable` / `hasattr(self, '_prog')`.
        if self.disable:
            return

        warn("rich is experimental/alpha",
             TqdmExperimentalWarning,
             stacklevel=2)
        d = self.format_dict
        if progress is None:
            # The first two string literals are adjacent and therefore
            # concatenated into a single column specification.
            progress = ("[progress.description]{task.description}"
                        "[progress.percentage]{task.percentage:>4.0f}%",
                        BarColumn(bar_width=None),
                        FractionColumn(unit_scale=d['unit_scale'],
                                       unit_divisor=d['unit_divisor']), "[",
                        TimeElapsedColumn(), "<", TimeRemainingColumn(), ",",
                        RateColumn(unit=d['unit'],
                                   unit_scale=d['unit_scale'],
                                   unit_divisor=d['unit_divisor']), "]")
        self._prog = Progress(*progress, transient=not self.leave)
        # Entered manually; the matching `__exit__` is issued in `close()`.
        self._prog.__enter__()
        self._task_id = self._prog.add_task(self.desc or "", **d)

    def close(self, *args, **kwargs):
        """Close the tqdm bar and leave the rich progress context."""
        if self.disable:
            return
        super(tqdm_rich, self).close(*args, **kwargs)
        self._prog.__exit__(None, None, None)

    def clear(self, *_, **__):
        """No-op; all arguments are ignored."""
        pass

    def display(self, *_, **__):
        """Push the current count and description to the rich task."""
        if not hasattr(self, '_prog'):
            return
        self._prog.update(self._task_id,
                          completed=self.n,
                          description=self.desc)

    def reset(self, total=None):
        """
        Resets to 0 iterations for repeated use.

        Parameters
        ----------
        total  : int or float, optional. Total to use for the new bar.
        """
        if hasattr(self, '_prog'):
            # `Progress.reset` needs the task id as its first positional
            # argument; omitting it raises a TypeError.
            self._prog.reset(self._task_id, total=total)
        super(tqdm_rich, self).reset(total=total)
Example #5
0
class KrakenTrainProgressBar(ProgressBarBase):
    """
    Adaptation of the default ptl rich progress bar to fit with kraken (segtrain, train) output.

    Args:
        refresh_rate: Determines at which rate (in number of batches) the progress bars get updated.
            Set it to ``0`` to disable the display.
        leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: True
        console_kwargs: Args for constructing a `Console`
    """
    def __init__(self,
                 refresh_rate: int = 1,
                 leave: bool = True,
                 console_kwargs: Optional[Dict[str, Any]] = None) -> None:
        super().__init__()
        self._refresh_rate: int = refresh_rate
        self._leave: bool = leave
        self._console_kwargs = console_kwargs or {}
        self._enabled: bool = True
        # Created lazily by `_init_progress`.
        self.progress: Optional[Progress] = None
        self.val_sanity_progress_bar_id: Optional[int] = None
        self._reset_progress_bar_ids()
        self._metric_component = None
        self._progress_stopped: bool = False

    @property
    def refresh_rate(self) -> int:
        """Number of batches between two display updates."""
        return self._refresh_rate

    @property
    def is_enabled(self) -> bool:
        """True when the bar is enabled and the refresh rate is positive."""
        return self._enabled and self.refresh_rate > 0

    @property
    def is_disabled(self) -> bool:
        """Negation of `is_enabled`."""
        return not self.is_enabled

    def disable(self) -> None:
        """Mark the bar as disabled."""
        self._enabled = False

    def enable(self) -> None:
        """Mark the bar as enabled."""
        self._enabled = True

    @property
    def sanity_check_description(self) -> str:
        """Label of the sanity-check task."""
        return "Validation Sanity Check"

    @property
    def validation_description(self) -> str:
        """Label of the validation task."""
        return "Validation"

    @property
    def test_description(self) -> str:
        """Label of the test task."""
        return "Testing"

    def _init_progress(self, trainer):
        """Create and start a fresh `Progress` if none is currently running.

        Does nothing when the bar is disabled or when a progress object
        exists and has not been stopped.
        """
        if not self.is_enabled:
            return
        if self.progress is not None and not self._progress_stopped:
            return

        # Task ids of a previous `Progress` are meaningless for the new one.
        self._reset_progress_bar_ids()
        self._console = Console(**self._console_kwargs)
        self._console.clear_live()
        columns = self.configure_columns(trainer)
        self._metric_component = MetricsTextColumn(trainer)
        columns.append(self._metric_component)

        if trainer.early_stopping_callback:
            self._early_stopping_component = EarlyStoppingColumn(trainer)
            columns.append(self._early_stopping_component)

        # auto_refresh is off: the display is redrawn via `self.refresh()`.
        self.progress = Progress(*columns,
                                 auto_refresh=False,
                                 disable=self.is_disabled,
                                 console=self._console)
        self.progress.start()
        # progress has started
        self._progress_stopped = False

    def refresh(self) -> None:
        """Redraw the display if a progress object exists."""
        if self.progress:
            self.progress.refresh()

    def on_train_start(self, trainer, pl_module):
        """Ensure the progress display exists when training starts."""
        self._init_progress(trainer)

    def on_test_start(self, trainer, pl_module):
        """Ensure the progress display exists when testing starts."""
        self._init_progress(trainer)

    def on_validation_start(self, trainer, pl_module):
        """Ensure the progress display exists when validation starts."""
        self._init_progress(trainer)

    def on_sanity_check_start(self, trainer, pl_module):
        """Ensure the progress display exists when the sanity check starts."""
        self._init_progress(trainer)

    def on_sanity_check_end(self, trainer, pl_module):
        """Hide the sanity-check task, if one was created, and redraw."""
        # Guard the id as well: no sanity task may have been registered.
        if (self.progress is not None
                and self.val_sanity_progress_bar_id is not None):
            self.progress.update(self.val_sanity_progress_bar_id,
                                 advance=0,
                                 visible=False)
        self.refresh()

    def on_train_epoch_start(self, trainer, pl_module):
        """Create or reset the main task for the epoch that is starting."""
        total_train_batches = self.total_train_batches
        total_val_batches = self.total_val_batches
        if total_train_batches != float("inf"):
            # val can be checked multiple times per epoch
            val_checks_per_epoch = total_train_batches // trainer.val_check_batch
            total_val_batches = total_val_batches * val_checks_per_epoch

        total_batches = total_train_batches + total_val_batches

        epoch_limit = trainer.max_epochs if pl_module.hparams.quit == 'dumb' else '∞'
        train_description = f"stage {trainer.current_epoch}/{epoch_limit}"
        if len(self.validation_description) > len(train_description):
            # Padding is required to avoid flickering due to the uneven
            # lengths of the training and validation bar descriptions.
            num_digits = len(str(trainer.current_epoch))
            required_padding = (len(self.validation_description) -
                                len(train_description) + 1) - num_digits
            # A non-positive count yields an empty string, i.e. no padding.
            train_description += " " * required_padding

        if self.main_progress_bar_id is not None and self._leave:
            # Stop the current display and start a new one; `_init_progress`
            # clears the task ids, so a new main task is added below.
            self._stop_progress()
            self._init_progress(trainer)
        if self.main_progress_bar_id is None:
            self.main_progress_bar_id = self._add_task(total_batches,
                                                       train_description)
        elif self.progress is not None:
            self.progress.reset(self.main_progress_bar_id,
                                total=total_batches,
                                description=train_description,
                                visible=True)
        self.refresh()

    def on_validation_epoch_start(self, trainer, pl_module):
        """Add a sanity-check task or an initially hidden validation task."""
        if trainer.sanity_checking:
            self.val_sanity_progress_bar_id = self._add_task(
                self.total_val_batches, self.sanity_check_description)
        else:
            self.val_progress_bar_id = self._add_task(
                self.total_val_batches,
                self.validation_description,
                visible=False)
        self.refresh()

    def _add_task(self,
                  total_batches: int,
                  description: str,
                  visible: bool = True) -> Optional[int]:
        """Register a task and return its id, or None without a progress object."""
        if self.progress is None:
            return None
        return self.progress.add_task(f"{description}",
                                      total=total_batches,
                                      visible=visible)

    def _update(self,
                progress_bar_id: Optional[int],
                current: int,
                total: Union[int, float],
                visible: bool = True) -> None:
        """Advance a task by one refresh interval and redraw.

        Nothing happens when there is no progress object, no task id, or
        `_should_update` rejects the batch index `current`.
        """
        if self.progress is None or progress_bar_id is None:
            return
        if not self._should_update(current, total):
            return
        # On the final batch only the remainder since the last update is
        # outstanding, unless the total is a multiple of the refresh rate.
        leftover = current % self.refresh_rate
        if current == total and leftover != 0:
            advance = leftover
        else:
            advance = self.refresh_rate
        self.progress.update(progress_bar_id,
                             advance=advance,
                             visible=visible)
        self.refresh()

    def _should_update(self, current: int, total: Union[int, float]) -> bool:
        """True on every `refresh_rate`-th batch and on the last batch."""
        # `is_enabled` implies refresh_rate > 0, so the modulo is safe.
        return self.is_enabled and (current % self.refresh_rate == 0
                                    or current == total)

    def on_validation_epoch_end(self, trainer, pl_module):
        """Hide the validation task again when validating as part of fit."""
        if (self.progress is not None
                and self.val_progress_bar_id is not None
                and trainer.state.fn == "fit"):
            self.progress.update(self.val_progress_bar_id,
                                 advance=0,
                                 visible=False)
            self.refresh()

    def on_validation_end(self, trainer: "pl.Trainer",
                          pl_module: "pl.LightningModule") -> None:
        """Refresh the metrics column after validation during fit."""
        if trainer.state.fn == "fit":
            self._update_metrics(trainer, pl_module)

    def on_test_epoch_start(self, trainer, pl_module):
        """Add the test task and redraw."""
        self.test_progress_bar_id = self._add_task(self.total_test_batches,
                                                   self.test_description)
        self.refresh()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch,
                           batch_idx):
        """Advance the main task, update the metrics and redraw."""
        self._update(self.main_progress_bar_id, self.train_batch_idx,
                     self.total_train_batches)
        self._update_metrics(trainer, pl_module)
        self.refresh()

    def on_train_epoch_end(self, trainer: "pl.Trainer",
                           pl_module: "pl.LightningModule") -> None:
        """Refresh the metrics column at the end of a training epoch."""
        self._update_metrics(trainer, pl_module)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch,
                                batch_idx, dataloader_idx):
        """Advance the sanity-check task, or the main and validation tasks."""
        if trainer.sanity_checking:
            self._update(self.val_sanity_progress_bar_id, self.val_batch_idx,
                         self.total_val_batches)
        elif self.val_progress_bar_id is not None:
            # check to see if we should update the main training progress bar
            if self.main_progress_bar_id is not None:
                self._update(self.main_progress_bar_id, self.val_batch_idx,
                             self.total_val_batches)
            self._update(self.val_progress_bar_id, self.val_batch_idx,
                         self.total_val_batches)
        self.refresh()

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
                          dataloader_idx):
        """Advance the test task and redraw."""
        self._update(self.test_progress_bar_id, self.test_batch_idx,
                     self.total_test_batches)
        self.refresh()

    def _stop_progress(self) -> None:
        """Stop the display and flag it for re-initialisation."""
        if self.progress is not None:
            self.progress.stop()
            # signals for progress to be re-initialized for next stages
            self._progress_stopped = True

    def _reset_progress_bar_ids(self):
        """Forget the main, validation and test task ids.

        The sanity-check task id is not touched here.
        """
        self.main_progress_bar_id: Optional[int] = None
        self.val_progress_bar_id: Optional[int] = None
        self.test_progress_bar_id: Optional[int] = None

    def _update_metrics(self, trainer, pl_module) -> None:
        """Pass the current metrics, minus 'loss' and 'val_metric', to the metrics column."""
        metrics = self.get_metrics(trainer, pl_module)
        metrics.pop('loss', None)
        metrics.pop('val_metric', None)
        if self._metric_component:
            self._metric_component.update(metrics)

    def teardown(self,
                 trainer,
                 pl_module,
                 stage: Optional[str] = None) -> None:
        """Stop the display on teardown."""
        self._stop_progress()

    def on_exception(self, trainer, pl_module,
                     exception: BaseException) -> None:
        """Stop the display when an exception is reported."""
        self._stop_progress()

    # NOTE(review): the four properties below index `progress.tasks` with a
    # task id and require `self.progress` to exist; this presumably relies on
    # ids matching positions in the task list -- confirm.
    @property
    def val_progress_bar(self) -> Task:
        """Task object of the validation bar."""
        return self.progress.tasks[self.val_progress_bar_id]

    @property
    def val_sanity_check_bar(self) -> Task:
        """Task object of the sanity-check bar."""
        return self.progress.tasks[self.val_sanity_progress_bar_id]

    @property
    def main_progress_bar(self) -> Task:
        """Task object of the main training bar."""
        return self.progress.tasks[self.main_progress_bar_id]

    @property
    def test_progress_bar(self) -> Task:
        """Task object of the test bar."""
        return self.progress.tasks[self.test_progress_bar_id]

    def configure_columns(self, trainer) -> list:
        """Return the base columns; `_init_progress` appends further ones."""
        return [
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            BatchesProcessedColumn(),
            TimeRemainingColumn(),
            TimeElapsedColumn()
        ]