Пример #1
0
class RichTablePrinter(object):
    """
    Logger that displays logged values in a `rich` table.

    Each call to :meth:`log` writes one dict of values into a table row. In a notebook
    the table is re-displayed through an IPython display handle; otherwise it is shown
    with a ``rich.live.Live`` display that is refreshed manually after every log.
    """

    def __init__(self, fields=None, key=None):
        """
        Logger based on `rich` tables

        :param fields: dict of (dict or False) or None
            Field descriptors containing goal ("lower_is_better" or "higher_is_better"), format and display name
            The key is a regex that will be used to match the fields to log
        :param key: str or None
            main key to group results by row
        """
        # Copy so that later mutations never leak back into the caller's dict
        fields = dict(fields) if fields is not None else {}
        self.fields = fields
        self.key = key
        self.key_to_row_idx = {}  # value of info[key] -> row index
        self.name_to_column_idx = {}  # field name -> column index, -1 for hidden fields
        self.table = None
        self.console = None
        self.live = None
        self.best = {}  # field name -> best value seen so far (fields with a goal only)
        if key is not None and key not in self.fields:
            # Make sure the grouping key gets a column, placed first
            self.fields = {key: {}, **fields}

    def _repr_html_(self) -> str:
        """Render the table as HTML (used by the notebook display handle)."""
        if self.console is None:
            return "Empty table"
        segments = list(self.console.render(self.table, self.console.options))  # type: ignore
        html = _render_segments(segments)
        return html

    def log(self, info):
        """
        Write the values of ``info`` into the table and refresh the display.

        :param info: dict
            Mapping from field name to value. If ``key`` was given and ``info[key]``
            has been logged before, that row is updated in place. If ``info`` has no
            ``key`` entry, the row of the most recently registered key is updated.
            Otherwise a new row is appended.
        """
        if self.table is None:
            self._init_display()

        self._add_new_columns(info)
        self._sort_columns()

        idx = self._get_row_idx(info)
        for name, value in info.items():
            column_idx = self.name_to_column_idx[name]
            if column_idx < 0:
                continue  # hidden field (its display name is False)
            self.table.columns[column_idx]._cells[idx] = self._format_value(name, value)

        self.refresh()

    def _init_display(self):
        """Create the table and console, and set up the ``refresh`` callback."""
        is_in_notebook = check_is_in_notebook()
        self.table = Table()
        # In a notebook, use a huge console width (presumably to avoid wrapping the table)
        self.console = Console(width=2 ** 32 - 1 if is_in_notebook else None)

        if is_in_notebook:
            dh = display(None, display_id=True)
            # Re-display this object; the notebook renders it through _repr_html_
            self.refresh = lambda: dh.update(self)
        else:
            self.live = Live(self.table, console=self.console, auto_refresh=False)
            self.live.start()
            self.refresh = lambda: self.live.refresh()

    def _add_new_columns(self, info):
        """Create a column for every name in ``info`` not seen before (hidden names map to -1)."""
        for name in info:
            if name in self.name_to_column_idx:
                continue
            matcher, column_name = get_last_matching_value(self.fields, name, "name", default=name)
            if column_name is False:
                self.name_to_column_idx[name] = -1
                continue
            header = re.sub(matcher, column_name, name) if matcher is not None else name
            self.table.add_column(header, no_wrap=True)
            # Back-fill one empty cell per existing row. Use the row count rather than
            # another column's cell count: the new column may be the only one.
            self.table.columns[-1]._cells = [''] * len(self.table.rows)
            self.name_to_column_idx[name] = len(self.table.columns) - 1

    def _sort_columns(self):
        """Reorder visible columns by ``get_last_matching_index(self.fields, name)``."""

        def get_name_index(name):
            try:
                return get_last_matching_index(self.fields, name)
            except ValueError:
                # Names matching no descriptor go after all the others
                return len(self.name_to_column_idx)

        new_name_to_column_idx = {}
        columns = []
        for name in sorted(self.name_to_column_idx, key=get_name_index):
            old_idx = self.name_to_column_idx[name]
            if old_idx >= 0:
                new_name_to_column_idx[name] = len(columns)
                columns.append(self.table.columns[old_idx])
            else:
                new_name_to_column_idx[name] = -1
        self.table.columns = columns
        self.name_to_column_idx = new_name_to_column_idx

    def _get_row_idx(self, info):
        """Return the index of the row ``info`` is written to, appending a new row if needed."""
        if self.key is not None:
            if self.key in info and info[self.key] in self.key_to_row_idx:
                return self.key_to_row_idx[info[self.key]]
            if self.key not in info and self.key_to_row_idx:
                # No key given: update the row of the most recently registered key
                return list(self.key_to_row_idx.values())[-1]
        self.table.add_row()
        idx = len(self.table.rows) - 1
        # Only register the row when there is a key value to register it under
        if self.key is not None and self.key in info:
            self.key_to_row_idx[info[self.key]] = idx
        return idx

    def _format_value(self, name, value):
        """
        Format ``value`` using the field's "format" descriptor.

        If the field has a "goal", the value is wrapped in green markup when it improves
        on the best value so far (and becomes the new best), or in red markup otherwise.
        Any goal other than "lower_is_better" is treated as higher-is-better.
        """
        formatted_value = get_last_matching_value(self.fields, name, "format", "{}")[1].format(value)
        goal = get_last_matching_value(self.fields, name, "goal", None)[1]
        if goal is None:
            return formatted_value
        if name not in self.best:
            # First value for this field: nothing to compare against yet
            self.best[name] = value
            return formatted_value
        diff = (value - self.best[name]) * (-1 if goal == "lower_is_better" else 1)
        if diff > 0:
            self.best[name] = value
            return "[green]" + formatted_value + "[/green]"
        if diff <= 0:
            return "[red]" + formatted_value + "[/red]"
        # Neither comparison holds (e.g. NaN): leave the value uncoloured
        return formatted_value

    def finalize(self):
        """Stop the live display, if one was started (no-op in notebooks)."""
        if self.live is not None:
            self.live.stop()
Пример #2
0
class ProgressBarRich(ProgressBarBase):
    """
    Progress bar rendered with `rich`.

    A root bar (``parent is None``) owns a ``rich.live.Live`` display showing a
    ``rich.tree.Tree`` of bars. Child bars created via :meth:`add_child` are added
    as nodes of their parent's tree.
    """

    def __init__(self, min_value, max_value, title=None, progress=None, indent=0, parent=None):
        """
        :param min_value: forwarded to ProgressBarBase
        :param max_value: forwarded to ProgressBarBase
        :param title: bar label, forwarded to ProgressBarBase
        :param progress: optional existing ``rich.progress.Progress`` to add this bar's task to
        :param indent: nesting depth, used to shrink the label padding
        :param parent: tree node to attach to, or None to create a root bar with its own Live display
        """
        super(ProgressBarRich, self).__init__(min_value, max_value, title=title)
        import rich.console
        import rich.progress
        import rich.table
        import rich.tree
        from rich.live import Live

        self.console = rich.console.Console(record=True)
        self.parent = parent
        if progress is None:
            self.progress = rich.progress.Progress(
                rich.progress.SpinnerColumn(),
                rich.progress.TextColumn("[progress.description]{task.description}"),
                rich.progress.BarColumn(),
                rich.progress.TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
                TimeElapsedColumn(),
                rich.progress.TextColumn("[red]{task.fields[status]}"),
                console=self.console,
                transient=False,
                expand=False,
            )
        else:
            self.progress = progress
        if parent is None:
            self.node = rich.tree.Tree(self.progress)
            self.live = Live(self.node, refresh_per_second=5, console=self.console)
        else:
            self.node = parent.add(self.progress)
        # Progress is tracked in 1000 discrete steps (the task total below)
        self.steps = 0
        self.indent = indent

        # Guard against a missing title (len(None) would raise)
        title_text = self.title or ""
        # Pad the label so that it shrinks by 4 characters per nesting level
        # (presumably so bars line up across tree depths)
        padding = max(0, 45 - (self.indent * 4) - len(title_text))
        self.passes = None
        self.task = self.progress.add_task(f"[red]{title_text}" + (" " * padding), total=1000, start=False, status=self.status or '', passes=self.passes)
        self.started = False
        self.subtasks = []

    def _get_task(self):
        """Return the rich ``Task`` object that belongs to this bar."""
        for task in self.progress.tasks:
            if task.id == self.task:
                return task
        raise KeyError(self.task)

    def add_child(self, parent, task, title):
        """
        Create a child bar attached to this bar's tree node.

        The ``parent`` and ``task`` arguments are not used.
        """
        return ProgressBarRich(self.min_value, self.max_value, title, indent=self.indent+1, parent=self.node)

    def __call__(self, value):
        """
        Move the bar to ``value``.

        ``value`` is mapped to ``int(value * 1000)`` steps, so it is presumably a
        fraction in [0, 1] (TODO: confirm against ProgressBarBase). A value that is
        not larger than the previous one only updates ``self.value``.
        """
        if not self.started:
            self.progress.start_task(self.task)
            self.started = True
        if value > self.value:
            steps = int(value * 1000)
            delta = steps - self.steps
            if delta > 0:
                self.progress.update(self.task, advance=delta, refresh=False, passes=self.passes)
            else:
                # reset() restarts the task clock; restore this task's original
                # start time so the elapsed time keeps counting
                start_time = self._get_task().start_time
                self.progress.reset(self.task, completed=steps, refresh=False, status=self.status or '')
                self._get_task().start_time = start_time
            self.steps = steps
        self.value = value

    def update(self, value):
        """Alias for calling the bar with ``value``."""
        self(value)

    def finish(self):
        """Move the bar to ``max_value`` and, for a root bar, refresh the live display."""
        self(self.max_value)
        if self.parent is None:
            self.live.refresh()

    def start(self):
        """Start the live display of a root bar (no-op for children or if already started)."""
        if self.parent is None and not self.live.is_started:
            self.live.refresh()
            self.live.start()

    def exit(self):
        """Stop the live display of a root bar."""
        if self.parent is None:
            self.live.stop()

    def set_status(self, status):
        """Set the status text shown in the red status column."""
        self.status = status
        self.progress.update(self.task, status=self.status)

    def set_passes(self, passes):
        """Store ``passes``; it is sent to the rich task on the next advancing call."""
        self.passes = passes