class _Progress:
    """Progress display that is only shown for sufficiently large tasks.

    Wraps a ``rich.progress.Progress`` whose live display is started lazily:
    it is entered the first time a task with ``total > show_limit`` is added
    while this object is being used as a context manager. Redraws are
    throttled to roughly one every 0.1 seconds unless forced.
    """

    def __init__(self, show_limit: int = 50):
        """Build the underlying rich progress object without starting it.

        Args:
            show_limit: tasks whose total does not exceed this value are
                not displayed at all.
        """
        # rich is imported locally rather than at module level.
        from rich.progress import Progress, BarColumn, TimeRemainingColumn

        self._progress = Progress(
            "[progress.description]{task.description}",
            BarColumn(style="bar.back",
                      complete_style="bar.complete",
                      finished_style="bar.complete"),
            "[progress.percentage]{task.percentage:>3.0f}%",
            TimeRemainingColumn(),
            auto_refresh=False)

        self._last_update = _datetime.now()
        self._show_limit = show_limit
        # Running "completed" value per displayed task id.
        self._completed = {}
        # True while the wrapped rich display has been entered.
        self._have_entered = False
        # Arguments captured by __enter__; None while outside a ``with`` block.
        self._enter_args = None

    def __enter__(self, *args, **kwargs):
        """Remember the entry arguments; the display itself starts lazily."""
        self._enter_args = (args, kwargs)
        return self

    def __exit__(self, *args, **kwargs):
        """Stop the wrapped display if it was started; never suppress errors."""
        # Forget the entry arguments so that tasks added after the block do
        # not start a display that nothing would ever stop.
        self._enter_args = None

        if self._have_entered:
            # Reset before delegating so that a later ``with`` block starts
            # the display again instead of silently showing nothing.
            self._have_entered = False
            self._progress.__exit__(*args, **kwargs)

        return False

    def add_task(self, description: str, total: int):
        """Register a task and return its id, or None if it is too small.

        Args:
            description: text shown next to the bar.
            total: number of steps; must exceed ``show_limit`` to be shown.

        Returns:
            The task id from the wrapped progress object, or ``None`` when
            ``total`` does not exceed ``show_limit``.
        """
        if total > self._show_limit:
            # Only start the display while inside a ``with`` block; outside
            # one the task is tracked but nothing is drawn.
            if not self._have_entered and self._enter_args is not None:
                args, kwargs = self._enter_args
                self._progress.__enter__(*args, **kwargs)
                self._have_entered = True

            task_id = self._progress.add_task(description=description,
                                              total=total)
            self._completed[task_id] = 0
            return task_id

        return None

    def update(self, task_id: int, completed: float = None,
               advance: float = None, force_update: bool = False):
        """Record progress for a task and redraw if enough time has passed.

        Args:
            task_id: id returned by :meth:`add_task`; ``None`` is ignored.
            completed: absolute completed value; takes precedence over
                ``advance`` when both are given.
            advance: amount to add to the recorded completed value.
            force_update: redraw immediately, bypassing the 0.1 s throttle.
        """
        if task_id is None:
            return
        elif completed is not None:
            self._completed[task_id] = completed
        elif advance is not None:
            self._completed[task_id] += advance
        else:
            return

        now = _datetime.now()

        # The throttle is shared by all tasks: one timestamp for the display.
        if force_update or (now - self._last_update).total_seconds() > 0.1:
            self._progress.update(task_id=task_id,
                                  completed=self._completed[task_id])
            self._progress.refresh()
            self._last_update = now
def test_progress_max_refresh() -> None:
    """Check the captured output when a column sets ``max_refresh``."""
    clock = 0.0

    def get_time() -> float:
        # Hand out the current fake time, then move the clock on one second.
        nonlocal clock
        current = clock
        clock = current + 1.0
        return current

    text_column = TextColumn("{task.description}")
    text_column.max_refresh = 3
    console = Console(
        color_system=None,
        width=80,
        legacy_windows=False,
        force_terminal=True,
    )
    progress = Progress(
        text_column,
        get_time=get_time,
        auto_refresh=False,
        console=console,
    )

    console.begin_capture()
    with progress:
        task_id = progress.add_task("start")
        for tick in range(6):
            progress.update(task_id, description=f"tick {tick}")
            progress.refresh()
    result = console.end_capture()
    print(repr(result))

    expected = "\x1b[?25l\r\x1b[2Kstart\r\x1b[2Kstart\r\x1b[2Ktick 1\r\x1b[2Ktick 1\r\x1b[2Ktick 3\r\x1b[2Ktick 3\r\x1b[2Ktick 5\r\x1b[2Ktick 5\n\x1b[?25h"
    assert result == expected
def test_expand_bar() -> None:
    """Render a ``bar_width=None`` bar on a 10-column console and compare output."""
    console = Console(width=10, force_terminal=True, file=io.StringIO())
    bar_column = BarColumn(bar_width=None)
    progress = Progress(bar_column, console=console)

    progress.add_task("foo")
    progress.refresh()

    rendered = console.file.getvalue()
    assert rendered == "\x1b[38;5;237m━━━━━━━━━\x1b[0m \r\x1b[2K\x1b[38;5;237m━━━━━━━━━\x1b[0m "
class ProgressHandler(BaseProgressHandler):
    """Show one rich progress bar per track being downloaded."""

    def __init__(self):
        """Create the shared progress display and the per-track registry."""
        # Bookkeeping for every track, keyed by the caller-supplied track id.
        self.tracks = {}
        self.progress = Progress(
            TextColumn("[bold blue]{task.fields[title]}", justify="left"),
            BarColumn(bar_width=None),
            "[progress.percentage]{task.percentage:>3.2f}%",
            "•",
            DownloadColumn(),
            "•",
            TransferSpeedColumn(),
            "•",
            TimeRemainingColumn(),
            transient=True)

    def initialize(self, iterable, track_title, track_quality, total_size,
                   chunk_size, **kwargs):
        """Register a track, announce it and start the display.

        Args:
            iterable: stored as-is in the track record.
            track_title: title shown on the bar and in the messages.
            track_quality: stored as-is in the track record.
            total_size: total used for the progress bar.
            chunk_size: stored as-is in the track record.
            **kwargs: must contain ``track_id``.

        Raises:
            KeyError: if ``track_id`` is missing from ``kwargs``.
        """
        track_id = kwargs["track_id"]
        task = self.progress.add_task(track_title, title=track_title,
                                      total=total_size)
        self.progress.console.print(
            f"[bold red]{track_title}[/] has started downloading.")
        self.tracks[track_id] = {
            "id": track_id,
            "iterable": iterable,
            "title": track_title,
            "quality": track_quality,
            "total_size": total_size,
            "chunk_size": chunk_size,
            "task": task,
            "size_downloaded": 0,
            "current_chunk_size": 0
        }
        self.progress.start()

    def update(self, *args, **kwargs):
        """Advance a track's bar by the size of the chunk just received.

        Expects ``track_id`` and ``current_chunk_size`` in ``kwargs``.
        """
        track = self.tracks[kwargs["track_id"]]
        track["current_chunk_size"] = kwargs["current_chunk_size"]
        track["size_downloaded"] += track["current_chunk_size"]
        self.progress.update(track["task"],
                             advance=track["current_chunk_size"])

    def close(self, *args, **kwargs):
        """Announce that the track named by ``kwargs["track_id"]`` is done."""
        track = self.tracks[kwargs["track_id"]]
        track_title = track["title"]
        # Print through the console, exactly as initialize() does, instead of
        # relying on a ``print`` method on the progress object itself.
        self.progress.console.print(
            f"[bold red]{track_title}[/] is done downloading.")

    def close_progress(self):
        """Draw a final frame and stop the display."""
        self.progress.refresh()
        self.progress.stop()
async def advance_progress(self, data: Dict, progress: Progress) -> None:
    """Advance a simulation's progress bar to the reported progress value.

    Args:
        data: message whose ``data["data"]`` mapping carries
            ``simulation_id`` and ``progress``.
        progress: the progress display that owns the simulation's task.

    The bar is only ever moved forward; a reported value that does not
    exceed the task's current percentage is ignored. Messages for a
    simulation that is not in ``self.tasks`` are ignored as well.
    """
    simulation_id = data["data"]["simulation_id"]
    sim_progress = data["data"]["progress"]

    progress_task = self.tasks.get(int(simulation_id))
    if progress_task is None:
        # Unknown simulation: there is no bar to advance.
        return

    task = progress_task.task
    if not task.started:
        progress.start_task(task.id)

    reported = int(sim_progress)
    if reported <= task.percentage:
        return

    progress.update(task.id, advance=reported - int(task.percentage))
    # auto_refresh is False so do this manually
    progress.refresh()
def test_expand_bar() -> None:
    """Render a ``bar_width=None`` bar with a fixed clock and compare output."""
    console = Console(
        width=10,
        color_system="truecolor",
        force_terminal=True,
        file=io.StringIO(),
    )
    bar_column = BarColumn(bar_width=None)
    progress = Progress(
        bar_column,
        auto_refresh=False,
        get_time=lambda: 1.0,
        console=console,
    )

    progress.add_task("foo")
    progress.refresh()

    rendered = console.file.getvalue()
    assert rendered == "\x1b[38;5;237m━━━━━━━━━\x1b[0m \r\x1b[2K\x1b[38;5;237m━━━━━━━━━\x1b[0m "
def test_columns() -> None:
    """Render two tasks through a long list of column types and compare output.

    The display is driven by a ``MockClock`` and refreshed manually, and the
    captured console text is passed through ``replace_link_ids`` before it is
    compared with the expected escape-sequence string.
    """
    console = Console(
        file=io.StringIO(),
        force_terminal=True,
        width=80,
        log_time_format="[TIME]",
        color_system="truecolor",
        legacy_windows=False,
        log_path=False,
        _environ={},
    )
    progress = Progress(
        "test",
        TextColumn("{task.description}"),
        BarColumn(bar_width=None),
        TimeRemainingColumn(),
        TimeElapsedColumn(),
        FileSizeColumn(),
        TotalFileSizeColumn(),
        DownloadColumn(),
        TransferSpeedColumn(),
        MofNCompleteColumn(),
        MofNCompleteColumn(separator=" of "),
        transient=True,
        console=console,
        auto_refresh=False,
        get_time=MockClock(),
    )
    task1 = progress.add_task("foo", total=10)
    task2 = progress.add_task("bar", total=7)
    with progress:
        # Four rounds push both tasks past their totals (12/10 and 16/7).
        for n in range(4):
            progress.advance(task1, 3)
            progress.advance(task2, 4)
        # Output written while the display is live: builtin print, then the
        # console's log and print methods.
        print("foo")
        console.log("hello")
        console.print("world")
        progress.refresh()
    from .render import replace_link_ids

    result = replace_link_ids(console.file.getvalue())
    print(repr(result))
    expected = "\x1b[?25ltest foo \x1b[38;5;237m━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[33m0:00:07\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m10 bytes\x1b[0m \x1b[32m0/10 bytes\x1b[0m \x1b[31m?\x1b[0m \x1b[32m 0/10\x1b[0m \x1b[32m 0 of 10\x1b[0m\ntest bar \x1b[38;5;237m━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[33m0:00:18\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m7 bytes \x1b[0m \x1b[32m0/7 bytes \x1b[0m \x1b[31m?\x1b[0m \x1b[32m0/7 \x1b[0m \x1b[32m0 of 7 \x1b[0m\r\x1b[2K\x1b[1A\x1b[2Kfoo\ntest foo \x1b[38;5;237m━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[33m0:00:07\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m10 bytes\x1b[0m \x1b[32m0/10 bytes\x1b[0m \x1b[31m?\x1b[0m \x1b[32m 0/10\x1b[0m \x1b[32m 0 of 10\x1b[0m\ntest bar \x1b[38;5;237m━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[33m0:00:18\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m7 bytes \x1b[0m \x1b[32m0/7 bytes \x1b[0m \x1b[31m?\x1b[0m \x1b[32m0/7 \x1b[0m \x1b[32m0 of 7 \x1b[0m\r\x1b[2K\x1b[1A\x1b[2K\x1b[2;36m[TIME]\x1b[0m\x1b[2;36m \x1b[0mhello \ntest foo \x1b[38;5;237m━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[33m0:00:07\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m10 bytes\x1b[0m \x1b[32m0/10 bytes\x1b[0m \x1b[31m?\x1b[0m \x1b[32m 0/10\x1b[0m \x1b[32m 0 of 10\x1b[0m\ntest bar \x1b[38;5;237m━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[33m0:00:18\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m7 bytes \x1b[0m \x1b[32m0/7 bytes \x1b[0m \x1b[31m?\x1b[0m \x1b[32m0/7 \x1b[0m \x1b[32m0 of 7 \x1b[0m\r\x1b[2K\x1b[1A\x1b[2Kworld\ntest foo \x1b[38;5;237m━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[33m0:00:07\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m10 bytes\x1b[0m \x1b[32m0/10 bytes\x1b[0m \x1b[31m?\x1b[0m \x1b[32m 0/10\x1b[0m \x1b[32m 0 of 10\x1b[0m\ntest bar \x1b[38;5;237m━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[33m0:00:18\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m7 bytes \x1b[0m \x1b[32m0/7 bytes \x1b[0m \x1b[31m?\x1b[0m \x1b[32m0/7 \x1b[0m \x1b[32m0 of 7 \x1b[0m\r\x1b[2K\x1b[1A\x1b[2Ktest foo \x1b[38;2;114;156;31m━━━━━━━\x1b[0m \x1b[36m0:00:00\x1b[0m \x1b[33m0:00:34\x1b[0m \x1b[32m12 \x1b[0m \x1b[32m10 \x1b[0m \x1b[32m12/10 \x1b[0m \x1b[31m1 \x1b[0m \x1b[32m12/10\x1b[0m \x1b[32m12 of 10\x1b[0m\n \x1b[32mbytes \x1b[0m \x1b[32mbytes \x1b[0m \x1b[32mbytes \x1b[0m \x1b[31mbyte/s \x1b[0m \ntest bar \x1b[38;2;114;156;31m━━━━━━━\x1b[0m \x1b[36m0:00:00\x1b[0m \x1b[33m0:00:29\x1b[0m \x1b[32m16 \x1b[0m \x1b[32m7 bytes\x1b[0m \x1b[32m16/7 \x1b[0m \x1b[31m2 \x1b[0m \x1b[32m16/7 \x1b[0m \x1b[32m16 of 7 \x1b[0m\n \x1b[32mbytes \x1b[0m \x1b[32mbytes \x1b[0m \x1b[31mbytes/s\x1b[0m \r\x1b[2K\x1b[1A\x1b[2K\x1b[1A\x1b[2K\x1b[1A\x1b[2Ktest foo \x1b[38;2;114;156;31m━━━━━━━\x1b[0m \x1b[36m0:00:00\x1b[0m \x1b[33m0:00:34\x1b[0m \x1b[32m12 \x1b[0m \x1b[32m10 \x1b[0m \x1b[32m12/10 \x1b[0m \x1b[31m1 \x1b[0m \x1b[32m12/10\x1b[0m \x1b[32m12 of 10\x1b[0m\n \x1b[32mbytes \x1b[0m \x1b[32mbytes \x1b[0m \x1b[32mbytes \x1b[0m \x1b[31mbyte/s \x1b[0m \ntest bar \x1b[38;2;114;156;31m━━━━━━━\x1b[0m \x1b[36m0:00:00\x1b[0m \x1b[33m0:00:29\x1b[0m \x1b[32m16 \x1b[0m \x1b[32m7 bytes\x1b[0m \x1b[32m16/7 \x1b[0m \x1b[31m2 \x1b[0m \x1b[32m16/7 \x1b[0m \x1b[32m16 of 7 \x1b[0m\n \x1b[32mbytes \x1b[0m \x1b[32mbytes \x1b[0m \x1b[31mbytes/s\x1b[0m \n\x1b[?25h\r\x1b[1A\x1b[2K\x1b[1A\x1b[2K\x1b[1A\x1b[2K\x1b[1A\x1b[2K"
    assert result == expected
async def update_status(self, data: Dict, progress: Progress):
    """Show a simulation's new status on its progress bar.

    Args:
        data: message whose ``data["data"]`` mapping carries ``status`` and
            ``simulation_id``.
        progress: the progress display that owns the simulation's task.

    The status text is wrapped in bold colour markup when
    ``self.STATUS_COLORS`` has an entry for it. For the statuses
    ``finished``, ``ended`` and ``postprocessing`` the bar is advanced by
    100, starting the task first if needed. Messages for a simulation that
    is not in ``self.tasks`` are ignored.
    """
    status = data["data"]["status"]
    simulation_id = data["data"]["simulation_id"]

    progress_task = self.tasks.get(int(simulation_id))
    if progress_task is None:
        # Unknown simulation: there is no bar to update.
        return

    # NOTE(review): is_live_finished is defined elsewhere; when it returns a
    # truthy value the reported status is replaced by "finished".
    if self.is_live_finished(status, progress_task):
        status = "finished"

    status_txt = status
    color = self.STATUS_COLORS.get(status)
    if color:
        status_txt = f"[bold {color}]{status_txt}[/bold {color}]"

    advance_by = 100 if status in ("finished", "ended", "postprocessing") else None
    if advance_by and not progress_task.task.started:
        progress.start_task(progress_task.task.id)

    # ``advance`` is always passed; it is None for non-terminal statuses.
    progress.update(progress_task.task.id, status=status_txt, advance=advance_by)
    # auto_refresh is False so do this manually
    progress.refresh()
    progress_task.status = status
def test_columns() -> None:
    """Render two tasks through several column types and compare output.

    A fake clock that advances one second per call drives the display, which
    is refreshed manually; the raw console text is compared with the expected
    escape-sequence string.
    """
    time = 1.0

    def get_time():
        # Advance the fake clock by one second on every call.
        nonlocal time
        time += 1.0
        return time

    console = Console(
        file=io.StringIO(),
        force_terminal=True,
        width=80,
        log_time_format="[TIME]",
        color_system="truecolor",
    )
    progress = Progress(
        "test",
        TextColumn("{task.description}"),
        BarColumn(bar_width=None),
        TimeRemainingColumn(),
        FileSizeColumn(),
        TotalFileSizeColumn(),
        DownloadColumn(),
        TransferSpeedColumn(),
        console=console,
        auto_refresh=False,
        get_time=get_time,
    )
    task1 = progress.add_task("foo", total=10)
    task2 = progress.add_task("bar", total=7)
    with progress:
        # Four rounds push both tasks past their totals (12/10 and 16/7).
        for n in range(4):
            progress.advance(task1, 3)
            progress.advance(task2, 4)
        progress.log("hello")
        progress.print("world")
        progress.refresh()
    result = console.file.getvalue()
    print(repr(result))
    # NOTE(review): the expected text embeds "test_progress.py:190", the call
    # site of progress.log above, so it is sensitive to this file's layout.
    expected = 'test foo \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m10 bytes\x1b[0m \x1b[32m0/10 bytes\x1b[0m \x1b[31m?\x1b[0m \r\x1b[2Ktest foo \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m10 bytes\x1b[0m \x1b[32m0/10 bytes\x1b[0m \x1b[31m?\x1b[0m \ntest bar \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m7 bytes \x1b[0m \x1b[32m0/7 bytes \x1b[0m \x1b[31m?\x1b[0m \x1b[?25l\r\x1b[1A\x1b[2Ktest foo \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m10 bytes\x1b[0m \x1b[32m0/10 bytes\x1b[0m \x1b[31m?\x1b[0m \ntest bar \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m7 bytes \x1b[0m \x1b[32m0/7 bytes \x1b[0m \x1b[31m?\x1b[0m \r\x1b[1A\x1b[2K\x1b[2;36m[TIME]\x1b[0m\x1b[2;36m \x1b[0mhello \x1b[2mtest_progress.py:190\x1b[0m\x1b[2m \x1b[0m\ntest foo \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m10 bytes\x1b[0m \x1b[32m0/10 bytes\x1b[0m \x1b[31m?\x1b[0m \ntest bar \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m7 bytes \x1b[0m \x1b[32m0/7 bytes \x1b[0m \x1b[31m?\x1b[0m \r\x1b[1A\x1b[2Kworld\ntest foo \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m10 bytes\x1b[0m \x1b[32m0/10 bytes\x1b[0m \x1b[31m?\x1b[0m \ntest bar \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m7 bytes \x1b[0m \x1b[32m0/7 bytes \x1b[0m \x1b[31m?\x1b[0m \r\x1b[1A\x1b[2Ktest foo \x1b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m0:00:00\x1b[0m \x1b[32m12 bytes\x1b[0m \x1b[32m10 bytes\x1b[0m \x1b[32m12/10 bytes\x1b[0m \x1b[31m1 byte/s \x1b[0m \ntest bar \x1b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m0:00:00\x1b[0m \x1b[32m16 bytes\x1b[0m \x1b[32m7 bytes \x1b[0m \x1b[32m16/7 bytes \x1b[0m \x1b[31m2 bytes/s\x1b[0m \r\x1b[1A\x1b[2Ktest foo \x1b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m0:00:00\x1b[0m \x1b[32m12 bytes\x1b[0m \x1b[32m10 bytes\x1b[0m \x1b[32m12/10 bytes\x1b[0m \x1b[31m1 byte/s \x1b[0m \ntest bar \x1b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m0:00:00\x1b[0m \x1b[32m16 bytes\x1b[0m \x1b[32m7 bytes \x1b[0m \x1b[32m16/7 bytes \x1b[0m \x1b[31m2 bytes/s\x1b[0m \n\x1b[?25h'
    assert result == expected
def parsingYa(login):
    """Query Yandex for every login in ``listlogin`` and record the results.

    For each login the collections API is queried; when a display name comes
    back, a table and six service links are printed, written to the report
    file and opened in the web browser. The behaviour depends on the
    module-level ``Ya`` setting:

    * ``'4'``: the report is named after ``hvostfile`` plus a timestamp,
      logins without a result are collected in ``wZ1bad`` and listed in a
      footer at the end of the report.
    * ``'3'``: login and e-mail are replaced by a placeholder in the table
      and the music link is neither real nor opened.
    * anything else: the report is named after ``login``.

    Args:
        login: used only to name the report file when ``Ya`` is not ``'4'``.
    """
    # Write to txt: pick the report file name.
    if Ya == '4':
        report_path = (dirresults + "/results/Yandex_parser/" + str(hvostfile) + '_' +
                       time.strftime("%d_%m_%Y_%H_%M_%S", time_data) + ".txt")
    else:
        report_path = dirresults + "/results/Yandex_parser/" + str(login) + ".txt"

    # The ``with`` block closes the report on every path, including errors.
    with open(report_path, "w", encoding="utf-8") as file_txt:
        progressYa = Progress("[progress.percentage]{task.percentage:>3.0f}%", auto_refresh=False)

        # Parsing
        for account in progressYa.track(listlogin, description=""):
            urlYa = f'https://yandex.ru/collections/api/users/{account}/'
            try:
                r = my_session.get(urlYa, headers = head0, timeout=3)
            except Exception:
                # Best effort: report the failed request and move on. Unlike a
                # bare ``except`` this lets Ctrl+C interrupt the run.
                print(Fore.RED + "\nОшибка\n" + Style.RESET_ALL)
                continue

            try:
                rdict = json.loads(r.text)
            except ValueError:
                # Not JSON: substitute the "no result" marker values.
                rdict = {"public_id": "Увы", "display_name": "-No-"}

            pub = rdict.get("public_id")
            name = rdict.get("display_name")
            email = str(account) + "@yandex.ru"

            if name == "-No-":
                if Ya != '4':
                    print(Style.BRIGHT + Fore.RED + "\nНе сработало")
                else:
                    wZ1bad.append(str(account))
                continue

            table1 = Table(title = "\n" + Style.BRIGHT + Fore.RED + str(account) + Style.RESET_ALL, style="green")
            table1.add_column("Имя", style="magenta")
            table1.add_column("Логин", style="cyan")
            table1.add_column("E-mail", style="cyan")
            if Ya == '3':
                table1.add_row(name, "Пропуск", "Пропуск")
            else:
                table1.add_row(name, account, email)
            console = Console()
            console.print(table1)

            market = f"https://market.yandex.ru/user/{pub}/reviews"
            collections = f"https://yandex.ru/collections/user/{account}/"
            if Ya == '3':
                music = "\033[33;1mПропуск\033[0m"
            else:
                music = f"https://music.yandex.ru/users/{account}/tracks"
            dzen = f"https://zen.yandex.ru/user/{pub}"
            qu = f"https://yandex.ru/q/profile/{pub}/"
            raion = f"https://local.yandex.ru/users/{pub}/"

            print("\033[32;1mЯ.Маркет:\033[0m", market)
            print("\033[32;1mЯ.Картинки:\033[0m", collections)
            print("\033[32;1mЯ.Музыка:\033[0m", music)
            print("\033[32;1mЯ.Дзен:\033[0m", dzen)
            print("\033[32;1mЯ.Кью:\033[0m", qu)
            print("\033[32;1mЯ.Район:\033[0m", raion)

            yalist = [market, collections, music, dzen, qu, raion]

            file_txt.write(f"{account} | {email} | {name}\n{market}\n{collections}\n{music}\n{dzen}\n{qu}\n{raion}\n\n\n",)
            progressYa.refresh()

            for webopen in yalist:
                # In mode '3' the music entry is a placeholder, not a URL.
                if webopen == music and Ya == '3':
                    continue
                webbrowser.open(webopen)

        if Ya == '4':
            # Write the txt footer: the logins that produced no result.
            file_txt.write(f"\nНеобработанные данные из файла '{hvostfile}':\n")
            for badsites in wZ1bad:
                file_txt.write(f"{badsites}\n")
            file_txt.write(f"\nОбновлено: " + time.strftime("%d/%m/%Y_%H:%M:%S", time_data) + ".")
def finish_task(progress: Progress, task_id):
    """Mark a task as fully completed and redraw the display.

    Args:
        progress: the progress display that owns the task.
        task_id: id of the task to complete.

    Raises:
        KeyError: if ``progress`` has no task with the given id.
    """
    # ``progress.tasks`` is a list, so a task id is not a safe index into it:
    # once an earlier task has been removed the positions no longer match the
    # ids. Look the task up by its id instead.
    for task in progress.tasks:
        if task.id == task_id:
            break
    else:
        raise KeyError(task_id)

    progress.update(task_id, completed=task.total)
    progress.refresh()
def refresh(self, *args, **kwargs):
    """Redraw through ``Progress.refresh`` only when ``self._interactive`` is truthy.

    Returns whatever ``Progress.refresh`` returns, or ``None`` when the
    redraw is skipped.
    """
    if not self._interactive:
        return None
    return Progress.refresh(self, *args, **kwargs)
class KrakenTrainProgressBar(ProgressBarBase):
    """
    Adaptation of the default ptl rich progress bar to fit with kraken (segtrain, train) output.

    Args:
        refresh_rate: Determines at which rate (in number of batches) the progress bars get updated.
            Set it to ``0`` to disable the display.
        leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: True
        console_kwargs: Args for constructing a `Console`
    """
    def __init__(self,
                 refresh_rate: int = 1,
                 leave: bool = True,
                 console_kwargs: Optional[Dict[str, Any]] = None) -> None:
        super().__init__()
        self._refresh_rate: int = refresh_rate
        self._leave: bool = leave
        self._console_kwargs = console_kwargs or {}
        self._enabled: bool = True
        # Created lazily by _init_progress().
        self.progress: Optional[Progress] = None
        self.val_sanity_progress_bar_id: Optional[int] = None
        self._reset_progress_bar_ids()
        self._metric_component = None
        # Set by _stop_progress() so the next stage builds a fresh display.
        self._progress_stopped: bool = False

    @property
    def refresh_rate(self) -> float:
        """Number of batches between two updates of the bars."""
        return self._refresh_rate

    @property
    def is_enabled(self) -> bool:
        """True unless disabled explicitly or through a refresh rate of 0."""
        return self._enabled and self.refresh_rate > 0

    @property
    def is_disabled(self) -> bool:
        """Negation of :attr:`is_enabled`."""
        return not self.is_enabled

    def disable(self) -> None:
        """Switch the display off."""
        self._enabled = False

    def enable(self) -> None:
        """Switch the display back on."""
        self._enabled = True

    @property
    def sanity_check_description(self) -> str:
        """Label of the sanity-check bar."""
        return "Validation Sanity Check"

    @property
    def validation_description(self) -> str:
        """Label of the validation bar."""
        return "Validation"

    @property
    def test_description(self) -> str:
        """Label of the test bar."""
        return "Testing"

    def _init_progress(self, trainer):
        """Build and start a new display if enabled and none is running."""
        if self.is_enabled and (self.progress is None or self._progress_stopped):
            self._reset_progress_bar_ids()
            self._console = Console(**self._console_kwargs)
            self._console.clear_live()
            columns = self.configure_columns(trainer)
            self._metric_component = MetricsTextColumn(trainer)
            columns.append(self._metric_component)

            # The early-stopping column is only added when the trainer has
            # an early stopping callback.
            if trainer.early_stopping_callback:
                self._early_stopping_component = EarlyStoppingColumn(trainer)
                columns.append(self._early_stopping_component)

            self.progress = Progress(*columns,
                                     auto_refresh=False,
                                     disable=self.is_disabled,
                                     console=self._console)
            self.progress.start()
            # progress has started
            self._progress_stopped = False

    def refresh(self) -> None:
        """Redraw the display if there is one (auto refresh is off)."""
        if self.progress:
            self.progress.refresh()

    def on_train_start(self, trainer, pl_module):
        """Make sure a display exists when training starts."""
        self._init_progress(trainer)

    def on_test_start(self, trainer, pl_module):
        """Make sure a display exists when testing starts."""
        self._init_progress(trainer)

    def on_validation_start(self, trainer, pl_module):
        """Make sure a display exists when validation starts."""
        self._init_progress(trainer)

    def on_sanity_check_start(self, trainer, pl_module):
        """Make sure a display exists when the sanity check starts."""
        self._init_progress(trainer)

    def on_sanity_check_end(self, trainer, pl_module):
        """Hide the sanity-check bar."""
        if self.progress is not None:
            self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False)
        self.refresh()

    def on_train_epoch_start(self, trainer, pl_module):
        """Create or reset the main bar for the epoch that is starting."""
        # NOTE(review): total_train_batches / total_val_batches are not
        # defined in this class; presumably inherited from ProgressBarBase.
        total_train_batches = self.total_train_batches
        total_val_batches = self.total_val_batches
        if total_train_batches != float("inf"):
            # val can be checked multiple times per epoch
            val_checks_per_epoch = total_train_batches // trainer.val_check_batch
            total_val_batches = total_val_batches * val_checks_per_epoch

        total_batches = total_train_batches + total_val_batches

        # The epoch limit is shown only when hparams.quit is 'dumb'; any other
        # value shows an infinity sign instead.
        train_description = f"stage {trainer.current_epoch}/{trainer.max_epochs if pl_module.hparams.quit == 'dumb' else '∞'}"
        if len(self.validation_description) > len(train_description):
            # Padding is required to avoid flickering due of uneven lengths of "Epoch X"
            # and "Validation" Bar description
            num_digits = len(str(trainer.current_epoch))
            required_padding = (len(self.validation_description) - len(train_description) + 1) - num_digits
            for _ in range(required_padding):
                train_description += " "

        # With ``leave`` set, the previous epoch's bar stays on screen: the
        # old display is stopped and a fresh one is started.
        if self.main_progress_bar_id is not None and self._leave:
            self._stop_progress()
            self._init_progress(trainer)
        if self.main_progress_bar_id is None:
            self.main_progress_bar_id = self._add_task(total_batches, train_description)
        elif self.progress is not None:
            self.progress.reset(self.main_progress_bar_id,
                                total=total_batches,
                                description=train_description,
                                visible=True)
        self.refresh()

    def on_validation_epoch_start(self, trainer, pl_module):
        """Add the sanity-check bar or a hidden validation bar."""
        if trainer.sanity_checking:
            self.val_sanity_progress_bar_id = self._add_task(
                self.total_val_batches, self.sanity_check_description)
        else:
            self.val_progress_bar_id = self._add_task(
                self.total_val_batches, self.validation_description, visible=False)
        self.refresh()

    def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
        """Add a task to the display; returns None when there is no display."""
        if self.progress is not None:
            return self.progress.add_task(f"{description}", total=total_batches, visible=visible)

    def _update(self, progress_bar_id: int, current: int, total: Union[int, float], visible: bool = True) -> None:
        """Advance a bar when ``current`` falls on an update step."""
        if self.progress is not None and self._should_update(current, total):
            # On the final batch only the remainder since the last update is
            # added; otherwise a full refresh_rate step.
            leftover = current % self.refresh_rate
            advance = leftover if (current == total and leftover != 0) else self.refresh_rate
            self.progress.update(progress_bar_id, advance=advance, visible=visible)
            self.refresh()

    def _should_update(self, current: int, total: Union[int, float]) -> bool:
        """True on every ``refresh_rate``-th batch and on the last batch."""
        # is_enabled is checked first, so a refresh rate of 0 never reaches
        # the modulo.
        return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

    def on_validation_epoch_end(self, trainer, pl_module):
        """Hide the validation bar again while fitting."""
        if self.val_progress_bar_id is not None and trainer.state.fn == "fit":
            self.progress.update(self.val_progress_bar_id, advance=0, visible=False)
            self.refresh()

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Refresh the metrics column after validation while fitting."""
        if trainer.state.fn == "fit":
            self._update_metrics(trainer, pl_module)

    def on_test_epoch_start(self, trainer, pl_module):
        """Add the test bar."""
        self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description)
        self.refresh()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Advance the main bar and refresh the metrics."""
        self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches)
        self._update_metrics(trainer, pl_module)
        self.refresh()

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Refresh the metrics column at the end of the epoch."""
        self._update_metrics(trainer, pl_module)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Advance the sanity-check bar, or the main and validation bars."""
        if trainer.sanity_checking:
            self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches)
        elif self.val_progress_bar_id is not None:
            # check to see if we should update the main training progress bar
            if self.main_progress_bar_id is not None:
                self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches)
            self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches)
        self.refresh()

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Advance the test bar."""
        self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches)
        self.refresh()

    def _stop_progress(self) -> None:
        """Stop the display and flag it for re-initialisation."""
        if self.progress is not None:
            self.progress.stop()
            # signals for progress to be re-initialized for next stages
            self._progress_stopped = True

    def _reset_progress_bar_ids(self):
        """Forget the ids of the main, validation and test bars."""
        self.main_progress_bar_id: Optional[int] = None
        self.val_progress_bar_id: Optional[int] = None
        self.test_progress_bar_id: Optional[int] = None

    def _update_metrics(self, trainer, pl_module) -> None:
        """Push the current metrics, minus two entries, to the metrics column."""
        metrics = self.get_metrics(trainer, pl_module)
        # 'loss' and 'val_metric' are dropped before display.
        metrics.pop('loss', None)
        metrics.pop('val_metric', None)
        if self._metric_component:
            self._metric_component.update(metrics)

    def teardown(self, trainer, pl_module, stage: Optional[str] = None) -> None:
        """Stop the display on teardown."""
        self._stop_progress()

    def on_exception(self, trainer, pl_module, exception: BaseException) -> None:
        """Stop the display when an exception is reported."""
        self._stop_progress()

    @property
    def val_progress_bar(self) -> Task:
        """Task of the validation bar, looked up by its id in ``progress.tasks``."""
        return self.progress.tasks[self.val_progress_bar_id]

    @property
    def val_sanity_check_bar(self) -> Task:
        """Task of the sanity-check bar, looked up by its id in ``progress.tasks``."""
        return self.progress.tasks[self.val_sanity_progress_bar_id]

    @property
    def main_progress_bar(self) -> Task:
        """Task of the main bar, looked up by its id in ``progress.tasks``."""
        return self.progress.tasks[self.main_progress_bar_id]

    @property
    def test_progress_bar(self) -> Task:
        """Task of the test bar, looked up by its id in ``progress.tasks``."""
        return self.progress.tasks[self.test_progress_bar_id]

    def configure_columns(self, trainer) -> list:
        """Return the base columns; metric columns are appended by the caller."""
        return [
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            BatchesProcessedColumn(),
            TimeRemainingColumn(),
            TimeElapsedColumn()
        ]