Пример #1
0
def test_download_progress_uses_decimal_units() -> None:
    """A default ``DownloadColumn`` renders the 1000/500 task as ``0.5/1.0 KB``."""
    half_done_task = Task(1, "test", 1000, 500, _get_time=lambda: 1.0)
    rendered = DownloadColumn().render(half_done_task)
    assert str(rendered) == "0.5/1.0 KB"
Пример #2
0
def test_download_progress_uses_binary_units() -> None:
    """With ``binary_units=True`` the 1024/512 task renders as ``0.5/1.0 KiB``."""
    half_done_task = Task(1, "test", 1024, 512, _get_time=lambda: 1.0)
    rendered = DownloadColumn(binary_units=True).render(half_done_task)
    assert str(rendered) == "0.5/1.0 KiB"
Пример #3
0
    def __init__(self) -> None:
        """Set up the console, the crawl and download progress bars, the live view and state flags."""
        self.console = Console(highlight=False)

        # Crawl bar: stretched description, bar and remaining time.
        crawl_columns = (
            TextColumn("{task.description}", table_column=Column(ratio=1)),
            BarColumn(),
            TimeRemainingColumn(),
        )
        self._crawl_progress = Progress(*crawl_columns, expand=True)

        # Download bar: additionally shows transfer speed and byte counts.
        download_columns = (
            TextColumn("{task.description}", table_column=Column(ratio=1)),
            TransferSpeedColumn(),
            DownloadColumn(),
            BarColumn(),
            TimeRemainingColumn(),
        )
        self._download_progress = Progress(*download_columns, expand=True)

        self._live = Live(console=self.console, transient=True)
        self._update_live()

        self._showing_progress = self._progress_suspended = False
        self._lock = asyncio.Lock()
        self._lines: List[str] = []

        # Switches for the individual kinds of output
        self.output_explain = False
        self.output_status = self.output_report = True
Пример #4
0
    def start(self, *, total: Optional[float], at: float,
              description: str) -> None:
        """Print ``description`` and start a transient transfer progress bar.

        :param total: total size of the transfer; asserted to be not ``None``.
        :param at: amount already completed when the task is added.
        :param description: text printed above the bar and used as task label.
        """
        from rich.progress import (
            BarColumn,
            DownloadColumn,
            Progress,
            TimeRemainingColumn,
            TransferSpeedColumn,
        )

        assert total is not None
        self.console.print(f'[progress.description]{description}')

        bar_layout = (
            '[',
            BarColumn(),
            ']',
            '[progress.percentage]{task.percentage:>3.0f}%',
            '(',
            DownloadColumn(),
            ')',
            TimeRemainingColumn(),
            TransferSpeedColumn(),
        )
        self.progress_bar = Progress(
            *bar_layout, console=self.console, transient=True)
        self.progress_bar.start()
        self.transfer_task = self.progress_bar.add_task(
            description, completed=at, total=total)
Пример #5
0
def UploadFile(socket, address, key, size_uncompressed, size_compressed, buffer=2048):
    """Hash the source files, then send ``temp.tar.gz`` over ``socket`` in encrypted chunks.

    Every path in ``address`` is read in ``buffer``-sized blocks and fed into
    one SHA512 object. The archive ``temp.tar.gz`` is then read in
    ``buffer``-sized chunks, each chunk is sent with ``sendEncryptedMessage``
    and folded into a second SHA512 object, while a progress bar advances
    towards ``size_compressed``. ``size_uncompressed`` is not used here.

    Returns the tuple ``(uncompressed_hash, compressed_hash)``.
    """
    with open("temp.tar.gz", "rb") as archive:
        file_hash_uc = SHA512.new()
        file_hash_c = SHA512.new()

        # Checksum over the original (uncompressed) files.
        for source_path in address:
            with open(source_path, "rb") as source_file:
                for block in iter(lambda: source_file.read(buffer), b""):
                    file_hash_uc.update(block)

        columns = (
            TextColumn("[bold blue]{task.description}", justify="right"),
            BarColumn(bar_width=None),
            "[progress.percentage]{task.percentage:>3.1f}%",
            "•",
            DownloadColumn(),
            "•",
            TransferSpeedColumn(),
            "•",
            TimeRemainingColumn(),
        )
        with Progress(*columns) as progress:
            task = progress.add_task("Uploading file(s)", total=size_compressed)
            while not progress.finished:
                chunk = archive.read(buffer)
                if not chunk:
                    break
                # Wait until the socket is writable before sending.
                select.select([], [socket], [])
                sendEncryptedMessage(socket, chunk, key)
                progress.update(task, advance=len(chunk))
                file_hash_c.update(chunk)
    return (file_hash_uc, file_hash_c)
Пример #6
0
def _get_progress_dl(**kwargs) -> Progress:
    """Build a progress bar for tracking the processing of a file in bytes.

    Extra keyword arguments are forwarded to ``Progress``.
    """
    columns = (
        '[progress.description]{task.description}',
        BarColumn(),
        '[progress.percentage]{task.percentage:>3.0f}%',
        TransferSpeedColumn(),
        DownloadColumn(),
        TimeRemainingColumn(),
    )
    return Progress(*columns, **kwargs)
Пример #7
0
 def __init__(self, client_id, client_secret):
     """Create the API client and a progress display with an overall 'All' task."""
     self.api = Api(client_id, client_secret)

     layout = (
         BarColumn(bar_width=None),
         "[progress.percentage]{task.percentage:>3.1f}%",
         DownloadColumn(),
         TransferSpeedColumn(),
         "[bold blue]{task.fields[filename]}",
     )
     self.progress = Progress(*layout)
     # Aggregate task; added with a falsy ``start`` value.
     self.all_t = self.progress.add_task('All', filename='All', start=0)
     self.total_length = 0
Пример #8
0
 def __init__(self, *args, **kwargs):
     """Initialise the parent with a fixed column layout placed before ``args``.

     ``refresh_per_second`` is always forced to 1, overriding any value
     supplied by the caller.
     """
     kwargs.update(refresh_per_second=1)
     fixed_columns = (
         TextColumn("[progress.description]{task.description}"),
         BarColumn(),
         TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
         DownloadColumn(),
         TimeRemainingColumn(),
         TimeElapsedColumn(),
     )
     super().__init__(*fixed_columns, *args, **kwargs)
Пример #9
0
def get_rich() -> Progress:
    """Return a ``Progress`` with description, bar, percentage, size, speed and ETA columns."""
    bullet = "•"
    columns = (
        TextColumn("[bold blue]{task.description}", justify="right"),
        BarColumn(bar_width=None),
        "[progress.percentage]{task.percentage:>3.1f}%",
        bullet,
        DownloadColumn(),
        bullet,
        TransferSpeedColumn(),
        bullet,
        TimeRemainingColumn(),
    )
    return Progress(*columns)
Пример #10
0
 def __init__(self):
     """Create the progress display; tasks are labelled by their ``title`` field."""
     bullet = "•"
     layout = (
         TextColumn("[bold blue]{task.fields[title]}", justify="right"),
         BarColumn(bar_width=None),
         "[progress.percentage]{task.percentage:>3.1f}%",
         bullet,
         DownloadColumn(),
         bullet,
         TransferSpeedColumn(),
         bullet,
         TimeRemainingColumn(),
     )
     self.progress = Progress(*layout)
Пример #11
0
    def run(self, args):
        """Copy the local file ``args.source`` to ``args.destination`` via ``pwncat.victim``.

        Without a destination, ``./<basename of source>`` is used. If the
        destination is a directory the source's basename is appended; if its
        parent does not exist a message is logged and nothing is uploaded.
        File-related errors are reported through ``self.parser.error``.
        """

        # Progress display for the transfer
        bullet = "•"
        progress = Progress(
            TextColumn("[bold cyan]{task.fields[filename]}", justify="right"),
            BarColumn(bar_width=None),
            "[progress.percentage]{task.percentage:>3.1f}%",
            bullet,
            DownloadColumn(),
            bullet,
            TransferSpeedColumn(),
            bullet,
            TimeRemainingColumn(),
        )

        if args.destination:
            access = pwncat.victim.access(args.destination)
            if Access.DIRECTORY in access:
                args.destination = os.path.join(
                    args.destination, os.path.basename(args.source))
            elif Access.PARENT_EXIST not in access:
                console.log(
                    f"[cyan]{args.destination}[/cyan]: no such file or directory"
                )
                return
        else:
            args.destination = f"./{os.path.basename(args.source)}"

        try:
            length = os.path.getsize(args.source)
            started = time.time()
            with progress:
                task_id = progress.add_task(
                    "upload",
                    filename=args.destination,
                    total=length,
                    start=False,
                )

                def report(count):
                    # Callback handed to copyfileobj with the copied byte count.
                    return progress.update(task_id, advance=count)

                with open(args.source, "rb") as source, \
                        pwncat.victim.open(args.destination, "wb",
                                           length=length) as destination:
                    progress.start_task(task_id)
                    copyfileobj(source, destination, report)
            elapsed = time.time() - started
            console.log(f"uploaded [cyan]{human_readable_size(length)}[/cyan] "
                        f"in [green]{human_readable_delta(elapsed)}[/green]")
        except (FileNotFoundError, PermissionError, IsADirectoryError) as exc:
            self.parser.error(str(exc))
Пример #12
0
def create_bar():
    """Return a ``Progress`` laid out as: percentage, bar, size, [speed eta], [description]."""
    percentage = "[progress.percentage]{task.percentage:>3.0f}%"
    # Speed and remaining time, wrapped in a red '[' and a cyan ']'.
    transfer_group = (
        TextColumn("[red]["),
        TransferSpeedColumn(),
        TimeRemainingColumn(),
        TextColumn("[cyan]]"),
    )
    # Task description, wrapped in yellow brackets.
    description_group = (
        TextColumn("[yellow]["),
        TextColumn("[yellow]{task.description}", justify="center"),
        TextColumn("[yellow]]"),
    )
    return Progress(
        percentage,
        BarColumn(),
        DownloadColumn(),
        *transfer_group,
        *description_group,
    )
Пример #13
0
def download_one_progress_bar():
    """
    Build the progress bar used while downloading a single file of known size.
    """
    columns = [
        TextColumn('[bold yellow]Downloading {task.description}'),
        BarColumn(),
        DownloadColumn(),
        TransferSpeedColumn(),
        TimeRemainingColumn(),
    ]
    return Progress(*columns)
Пример #14
0
def _download(url: str, root: str):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    download_target = os.path.join(root, filename)
    if os.path.isfile(download_target):
        return download_target

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise FileExistsError(f'{download_target} exists and is not a regular file')

    from rich.progress import (
        DownloadColumn,
        Progress,
        TextColumn,
        TimeRemainingColumn,
        TransferSpeedColumn,
    )

    progress = Progress(
        TextColumn("[bold blue]{task.fields[filename]}", justify="right"),
        "[progress.percentage]{task.percentage:>3.1f}%",
        "•",
        DownloadColumn(),
        "•",
        TransferSpeedColumn(),
        "•",
        TimeRemainingColumn(),
    )

    with progress:

        task = progress.add_task('download', filename=url, start=False)

        with urllib.request.urlopen(url) as source, open(
            download_target, 'wb'
        ) as output:

            progress.update(task, total=int(source.info().get('Content-Length')))

            progress.start_task(task)
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                progress.update(task, advance=len(buffer))

    return download_target
Пример #15
0
def DataTransformBar(has_size_info: bool = True):
    """
    Create a data transfer progress bar.

    :param has_size_info: whether the total size of the task is known
    :return: rich.progress.Progress
    """
    from rich.progress import (
        BarColumn,
        DownloadColumn,
        Progress,
        TextColumn,
        TimeRemainingColumn,
        TransferSpeedColumn,
    )
    from .. import qs_default_console

    if not has_size_info:
        from .. import user_lang

        # Label shown in place of percentage/ETA, in the user's language.
        size_note = '未知大小' if user_lang == 'zh' else 'Unknow size'
        return Progress(
            TextColumn(
                "[bold blue]{task.fields[filename]} [red]" + size_note,
                justify="right"
            ),
            BarColumn(bar_width=None),
            DownloadColumn(),
            console=qs_default_console
        )

    bullet = "•"
    return Progress(
        TextColumn("[bold blue]{task.fields[filename]}", justify="right"),
        BarColumn(bar_width=None),
        "[progress.percentage]{task.percentage:>3.1f}%",
        bullet,
        DownloadColumn(),
        bullet,
        TransferSpeedColumn(),
        bullet,
        TimeRemainingColumn(),
        console=qs_default_console
    )
Пример #16
0
    def run(self, manager: "pwncat.manager.Manager", args):
        """Copy the local file ``args.source`` to ``args.destination`` on the target.

        When no destination is given, ``./<basename of source>`` is used.
        Progress is shown while copying; afterwards the size and elapsed time
        are logged. File-related errors are reported via ``self.parser.error``.
        """

        # Create a progress bar for the download
        progress = Progress(
            TextColumn("[bold cyan]{task.fields[filename]}", justify="right"),
            BarColumn(bar_width=None),
            "[progress.percentage]{task.percentage:>3.1f}%",
            "•",
            DownloadColumn(),
            "•",
            TransferSpeedColumn(),
            "•",
            TimeRemainingColumn(),
        )

        if not args.destination:
            args.destination = f"./{os.path.basename(args.source)}"

        try:
            length = os.path.getsize(args.source)
            started = time.time()
            with progress:
                # The task is created stopped and started once both files are open.
                task_id = progress.add_task("upload",
                                            filename=args.destination,
                                            total=length,
                                            start=False)

                with open(args.source, "rb") as source:
                    with manager.target.platform.open(args.destination,
                                                      "wb") as destination:
                        progress.start_task(task_id)
                        copyfileobj(
                            source,
                            destination,
                            lambda count: progress.update(task_id,
                                                          advance=count),
                        )
                        # Relabel and stop the task before the destination is
                        # closed on leaving this ``with`` block (presumably the
                        # close is what drains the buffers - confirm).
                        progress.update(task_id,
                                        filename="draining buffers...")
                        progress.stop_task(task_id)

                    # Destination is closed: restart the task and restore its label.
                    progress.start_task(task_id)
                    progress.update(task_id, filename=args.destination)

            elapsed = time.time() - started
            console.log(f"uploaded [cyan]{human_readable_size(length)}[/cyan] "
                        f"in [green]{human_readable_delta(elapsed)}[/green]")
        except (FileNotFoundError, PermissionError, IsADirectoryError) as exc:
            self.parser.error(str(exc))
Пример #17
0
 def __init__(self, args: CmdArgs):
     """Store ``args``, build the progress display and register ``self.stop`` for SIGINT/SIGTERM."""
     self.logger = logging.getLogger('downloader')
     self.args = args
     self.exit = False
     # <--- settings coming from the command line --->
     self.max_concurrent_downloads = 1
     # <--- progress bar --->
     bullet = "•"
     bar_columns = (
         TextColumn("[bold blue]{task.fields[name]}", justify="right"),
         BarColumn(bar_width=None),
         "[progress.percentage]{task.percentage:>3.2f}%",
         bullet,
         DownloadColumn(binary_units=True),
         bullet,
         TransferSpeedColumn(),
         bullet,
         TimeRemainingColumn(),
     )
     self.progress = Progress(*bar_columns)
     self.terminate = False
     for signum in (signal.SIGINT, signal.SIGTERM):
         signal.signal(signum, self.stop)
Пример #18
0
 def __init__(self):
     """Initialise the exit flag, the concurrency limit and the progress display."""
     self.exit = False
     # <--- settings coming from the command line --->
     self.max_concurrent_downloads = 1
     # <--- progress bar --->
     bullet = "•"
     bar_columns = (
         TextColumn("[bold blue]{task.fields[name]}", justify="right"),
         BarColumn(bar_width=None),
         "[progress.percentage]{task.percentage:>3.2f}%",
         bullet,
         DownloadColumn(),
         bullet,
         TransferSpeedColumn(),
         bullet,
         TimeRemainingColumn(),
     )
     self.progress = Progress(*bar_columns)
Пример #19
0
    def run(self, cfg: DictConfig):
        """Create graph nodes for the Aquarium samples found in the graph db.

        Sample ids are read from the graph db (optionally limited by
        ``cfg.task.query.n_items``), the matching Aquarium samples are
        converted to a payload, and - when ``cfg.task.create_nodes`` is set -
        written through a worker pool while a progress bar advances.

        :param cfg: the configuration
        :raises ValueError: if ``cfg.task.query`` is empty.
        """
        if not cfg.task.query:
            raise ValueError("Query must be provided.")
        driver, aq = self.sessions(cfg)
        with Progress(
                "[progress.description]{task.description}",
                BarColumn(),
                "[progress.percentage]{task.percentage:>3.0f}%",
                TimeRemainingColumn(),
                TransferSpeedColumn(),
                DownloadColumn(),
        ) as progress:
            n_cpus = cfg.task.n_jobs or os.cpu_count()

            # TASK 0
            logger.info("Requesting Aquarium inventory...")

            match_query = "MATCH (n:Sample) RETURN n.id"
            if cfg.task.query.n_items > -1:
                match_query += " LIMIT {}".format(cfg.task.query.n_items)
            results = driver.read(match_query)
            sample_ids = [r[0] for r in results]

            models = aq.Sample.find(sample_ids)
            logger.info("Found {} samples in graph db".format(len(models)))

            node_payload = aq_inventory_to_cypher(aq, models)

            # In strict mode every write error is re-raised; otherwise
            # it is handed to the constraint-error handler.
            if cfg.task.strict:

                def error_callback(e: Exception):
                    raise e

            else:
                error_callback = self.catch_constraint_error

            task1 = progress.add_task("adding nodes...")
            # TODO: indicate when creation is skipped
            if cfg.task.create_nodes:
                total_nodes = len(node_payload)
                # Go through the public API rather than mutating the task
                # object found by indexing ``progress.tasks`` with the id.
                progress.update(task1, total=total_nodes)
                driver.pool(n_cpus).write(
                    node_payload,
                    callback=lambda _: progress.update(task1, advance=1),
                    chunksize=cfg.task.chunksize,
                    error_callback=error_callback,
                )
                # Mark the task complete once the pool has finished writing.
                progress.update(task1, completed=total_nodes)
Пример #20
0
def test_columns() -> None:
    """Render all progress columns for two tasks and compare the captured terminal output.

    The console writes to an in-memory buffer with a fixed width and a mock
    clock, so the escape-sequence output can be compared verbatim.
    """

    console = Console(
        file=io.StringIO(),
        force_terminal=True,
        width=80,
        log_time_format="[TIME]",
        color_system="truecolor",
        legacy_windows=False,
        log_path=False,
        _environ={},
    )
    progress = Progress(
        "test",
        TextColumn("{task.description}"),
        BarColumn(bar_width=None),
        TimeRemainingColumn(),
        TimeElapsedColumn(),
        FileSizeColumn(),
        TotalFileSizeColumn(),
        DownloadColumn(),
        TransferSpeedColumn(),
        MofNCompleteColumn(),
        MofNCompleteColumn(separator=" of "),
        transient=True,
        console=console,
        auto_refresh=False,
        get_time=MockClock(),
    )
    task1 = progress.add_task("foo", total=10)
    task2 = progress.add_task("bar", total=7)
    with progress:
        for n in range(4):
            progress.advance(task1, 3)
            progress.advance(task2, 4)
        print("foo")
        console.log("hello")
        console.print("world")
        progress.refresh()
    from .render import replace_link_ids

    result = replace_link_ids(console.file.getvalue())
    print(repr(result))
    # Must stay a single-line literal: a raw line break inside it is a SyntaxError.
    expected = "\x1b[?25ltest foo \x1b[38;5;237m━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[33m0:00:07\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m10 bytes\x1b[0m \x1b[32m0/10 bytes\x1b[0m \x1b[31m?\x1b[0m \x1b[32m 0/10\x1b[0m \x1b[32m 0 of 10\x1b[0m\ntest bar \x1b[38;5;237m━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[33m0:00:18\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m7 bytes \x1b[0m \x1b[32m0/7 bytes \x1b[0m \x1b[31m?\x1b[0m \x1b[32m0/7  \x1b[0m \x1b[32m0 of 7  \x1b[0m\r\x1b[2K\x1b[1A\x1b[2Kfoo\ntest foo \x1b[38;5;237m━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[33m0:00:07\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m10 bytes\x1b[0m \x1b[32m0/10 bytes\x1b[0m \x1b[31m?\x1b[0m \x1b[32m 0/10\x1b[0m \x1b[32m 0 of 10\x1b[0m\ntest bar \x1b[38;5;237m━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[33m0:00:18\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m7 bytes \x1b[0m \x1b[32m0/7 bytes \x1b[0m \x1b[31m?\x1b[0m \x1b[32m0/7  \x1b[0m \x1b[32m0 of 7  \x1b[0m\r\x1b[2K\x1b[1A\x1b[2K\x1b[2;36m[TIME]\x1b[0m\x1b[2;36m \x1b[0mhello                                                                    \ntest foo \x1b[38;5;237m━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[33m0:00:07\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m10 bytes\x1b[0m \x1b[32m0/10 bytes\x1b[0m \x1b[31m?\x1b[0m \x1b[32m 0/10\x1b[0m \x1b[32m 0 of 10\x1b[0m\ntest bar \x1b[38;5;237m━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[33m0:00:18\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m7 bytes \x1b[0m \x1b[32m0/7 bytes \x1b[0m \x1b[31m?\x1b[0m \x1b[32m0/7  \x1b[0m \x1b[32m0 of 7  \x1b[0m\r\x1b[2K\x1b[1A\x1b[2Kworld\ntest foo \x1b[38;5;237m━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[33m0:00:07\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m10 bytes\x1b[0m \x1b[32m0/10 bytes\x1b[0m \x1b[31m?\x1b[0m \x1b[32m 0/10\x1b[0m \x1b[32m 0 of 10\x1b[0m\ntest bar \x1b[38;5;237m━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[33m0:00:18\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m7 bytes \x1b[0m \x1b[32m0/7 bytes \x1b[0m \x1b[31m?\x1b[0m \x1b[32m0/7  \x1b[0m \x1b[32m0 of 7  \x1b[0m\r\x1b[2K\x1b[1A\x1b[2Ktest foo \x1b[38;2;114;156;31m━━━━━━━\x1b[0m \x1b[36m0:00:00\x1b[0m \x1b[33m0:00:34\x1b[0m \x1b[32m12     \x1b[0m \x1b[32m10     \x1b[0m \x1b[32m12/10   \x1b[0m \x1b[31m1      \x1b[0m \x1b[32m12/10\x1b[0m \x1b[32m12 of 10\x1b[0m\n                                 \x1b[32mbytes  \x1b[0m \x1b[32mbytes  \x1b[0m \x1b[32mbytes   \x1b[0m \x1b[31mbyte/s \x1b[0m               \ntest bar \x1b[38;2;114;156;31m━━━━━━━\x1b[0m \x1b[36m0:00:00\x1b[0m \x1b[33m0:00:29\x1b[0m \x1b[32m16     \x1b[0m \x1b[32m7 bytes\x1b[0m \x1b[32m16/7    \x1b[0m \x1b[31m2      \x1b[0m \x1b[32m16/7 \x1b[0m \x1b[32m16 of 7 \x1b[0m\n                                 \x1b[32mbytes  \x1b[0m         \x1b[32mbytes   \x1b[0m \x1b[31mbytes/s\x1b[0m               \r\x1b[2K\x1b[1A\x1b[2K\x1b[1A\x1b[2K\x1b[1A\x1b[2Ktest foo \x1b[38;2;114;156;31m━━━━━━━\x1b[0m \x1b[36m0:00:00\x1b[0m \x1b[33m0:00:34\x1b[0m \x1b[32m12     \x1b[0m \x1b[32m10     \x1b[0m \x1b[32m12/10   \x1b[0m \x1b[31m1      \x1b[0m \x1b[32m12/10\x1b[0m \x1b[32m12 of 10\x1b[0m\n                                 \x1b[32mbytes  \x1b[0m \x1b[32mbytes  \x1b[0m \x1b[32mbytes   \x1b[0m \x1b[31mbyte/s \x1b[0m               \ntest bar \x1b[38;2;114;156;31m━━━━━━━\x1b[0m \x1b[36m0:00:00\x1b[0m \x1b[33m0:00:29\x1b[0m \x1b[32m16     \x1b[0m \x1b[32m7 bytes\x1b[0m \x1b[32m16/7    \x1b[0m \x1b[31m2      \x1b[0m \x1b[32m16/7 \x1b[0m \x1b[32m16 of 7 \x1b[0m\n                                 \x1b[32mbytes  \x1b[0m         \x1b[32mbytes   \x1b[0m \x1b[31mbytes/s\x1b[0m               \n\x1b[?25h\r\x1b[1A\x1b[2K\x1b[1A\x1b[2K\x1b[1A\x1b[2K\x1b[1A\x1b[2K"

    assert result == expected
Пример #21
0
def retrieve_over_http(url, output_file_path):
    """Download file from remote location, with progress bar.

    Parameters
    ----------
    url : str
        Remote URL.
    output_file_path : str or Path
        Full file destination for download.

    Raises
    ------
    requests.exceptions.ConnectionError
        If the connection fails while the body is being streamed. Any
        partially written output file is removed first.

    """
    from pathlib import Path

    # Accept plain strings as documented; ``.name`` and ``.unlink`` need a Path.
    output_file_path = Path(output_file_path)

    # Make Rich progress bar
    progress = Progress(
        TextColumn("[bold]Downloading...", justify="right"),
        BarColumn(bar_width=None),
        "{task.percentage:>3.1f}%",
        "•",
        DownloadColumn(),
        "• speed:",
        TransferSpeedColumn(),
        "• ETA:",
        TimeRemainingColumn(),
    )

    CHUNK_SIZE = 4096
    response = requests.get(url, stream=True)

    try:
        with progress:
            task_id = progress.add_task(
                "download",
                filename=output_file_path.name,
                start=True,
                # A missing header yields a total of 0.
                total=int(response.headers.get("content-length", 0)),
            )

            with open(output_file_path, "wb") as fout:
                for chunk in response.iter_content(chunk_size=CHUNK_SIZE):
                    fout.write(chunk)
                    progress.update(task_id, advance=len(chunk), refresh=True)

    except requests.exceptions.ConnectionError as err:
        # The file may not exist yet if the failure happened before it was opened.
        try:
            output_file_path.unlink()
        except FileNotFoundError:
            pass
        raise requests.exceptions.ConnectionError(
            f"Could not download file from {url}"
        ) from err
    finally:
        # A streamed response holds its connection until explicitly closed.
        response.close()
Пример #22
0
def DownloadFile(socket, name, key, size_uncompressed, size_compressed, buffer=2048):
    """Receive an encrypted archive over ``socket``, unpack it and verify both checksums.

    The incoming chunks are written to ``temp.tar.gz`` and hashed while a
    progress bar advances towards ``size_compressed``. The message after the
    data must start with ``SFTP END``; split on ``CUSTOM_SEPARATOR`` its
    fields 1 and 2 carry the sender's uncompressed and compressed hashes.
    If the compressed hash matches, the archive is unpacked with
    ``FileDecompressor`` and the files listed in ``name`` are hashed in
    ``buffer``-sized blocks for the second comparison.

    The outcome is printed; the function always returns ``None``.
    ``size_uncompressed`` is not used here.
    """
    file_hash = SHA512.new()
    with open("temp.tar.gz", "wb") as f:
        with Progress(TextColumn("[bold blue]{task.description}", justify="right"),
                      BarColumn(bar_width=None),
                      "[progress.percentage]{task.percentage:>3.1f}%",
                      "•",
                      DownloadColumn(),
                      "•",
                      TransferSpeedColumn(),
                      "•",
                      TimeRemainingColumn()) as progress:
            task = progress.add_task("Downloading file(s)", total=size_compressed)
            while not progress.finished:
                # Wait until the socket passed in has data to read.
                select.select([socket], [], [])
                # A message precedes each payload; it is read and discarded
                # (presumably a header frame - confirm against the sender).
                receive_message(socket)
                chunk = recieveEncryptedMessage(socket, key)["data"]
                f.write(chunk)
                progress.update(task, advance=len(chunk))
                file_hash.update(chunk)
    # The archive is closed (and therefore flushed) from here on, so the
    # decompressor below sees the complete file.

    receive_message(socket)
    trailer = recieveEncryptedMessage(socket, key)["data"]
    if trailer[:8] == "SFTP END".encode('utf-8'):
        print(f"{RT.BLUE}SFTP END{RT.RESET}")
    else:
        print(f"{RT.RED}SFTP Did Not End! Retry File Transfer.{RT.RESET}")
        return

    split_data = trailer.split(CUSTOM_SEPARATOR)
    received_file_hash_uc = split_data[1].decode('utf-8')
    received_file_hash_c = split_data[2].decode('utf-8')

    # A corrupt archive is reported right away instead of being unpacked.
    if received_file_hash_c != file_hash.hexdigest():
        print(f"{RT.RED}SFTP Checksum Did Not Match! File Is Corrupt{RT.RESET}")
        return

    FileDecompressor("temp.tar.gz", name)
    ucfilehash = SHA512.new()
    for name_singular in name:
        with open(name_singular, "rb") as filehandle:
            while True:
                block = filehandle.read(buffer)
                if not block:
                    break
                ucfilehash.update(block)

    if received_file_hash_uc == ucfilehash.hexdigest():
        print(f"{RT.GREEN}SFTP Checksum Matched!{RT.RESET}")
    else:
        print(f"{RT.RED}SFTP Checksum Did Not Match! File Is Corrupt{RT.RESET}")
Пример #23
0
def test_columns() -> None:
    """Render the size/speed progress columns for two tasks and compare the captured output.

    A fake clock that advances one second per call and an in-memory console
    make the escape-sequence output reproducible.
    """
    time = 1.0

    def get_time():
        # Each call moves the fake clock forward by one second.
        nonlocal time
        time += 1.0
        return time

    console = Console(
        file=io.StringIO(),
        force_terminal=True,
        width=80,
        log_time_format="[TIME]",
        color_system="truecolor",
    )
    progress = Progress(
        "test",
        TextColumn("{task.description}"),
        BarColumn(bar_width=None),
        TimeRemainingColumn(),
        FileSizeColumn(),
        TotalFileSizeColumn(),
        DownloadColumn(),
        TransferSpeedColumn(),
        console=console,
        auto_refresh=False,
        get_time=get_time,
    )
    task1 = progress.add_task("foo", total=10)
    task2 = progress.add_task("bar", total=7)
    with progress:
        for n in range(4):
            progress.advance(task1, 3)
            progress.advance(task2, 4)
        progress.log("hello")
        progress.print("world")
        progress.refresh()
    result = console.file.getvalue()
    print(repr(result))
    # Must stay a single-line literal: a raw line break inside it is a SyntaxError.
    expected = 'test foo \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m10 bytes\x1b[0m \x1b[32m0/10 bytes\x1b[0m \x1b[31m?\x1b[0m \r\x1b[2Ktest foo \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m10 bytes\x1b[0m \x1b[32m0/10 bytes\x1b[0m \x1b[31m?\x1b[0m \ntest bar \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m7 bytes \x1b[0m \x1b[32m0/7 bytes \x1b[0m \x1b[31m?\x1b[0m \x1b[?25l\r\x1b[1A\x1b[2Ktest foo \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m10 bytes\x1b[0m \x1b[32m0/10 bytes\x1b[0m \x1b[31m?\x1b[0m \ntest bar \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m7 bytes \x1b[0m \x1b[32m0/7 bytes \x1b[0m \x1b[31m?\x1b[0m \r\x1b[1A\x1b[2K\x1b[2;36m[TIME]\x1b[0m\x1b[2;36m \x1b[0mhello                                               \x1b[2mtest_progress.py:190\x1b[0m\x1b[2m \x1b[0m\ntest foo \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m10 bytes\x1b[0m \x1b[32m0/10 bytes\x1b[0m \x1b[31m?\x1b[0m \ntest bar \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m7 bytes \x1b[0m \x1b[32m0/7 bytes \x1b[0m \x1b[31m?\x1b[0m \r\x1b[1A\x1b[2Kworld\ntest foo \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m10 bytes\x1b[0m \x1b[32m0/10 bytes\x1b[0m \x1b[31m?\x1b[0m \ntest bar \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m-:--:--\x1b[0m \x1b[32m0 bytes\x1b[0m \x1b[32m7 bytes \x1b[0m \x1b[32m0/7 bytes \x1b[0m \x1b[31m?\x1b[0m \r\x1b[1A\x1b[2Ktest foo \x1b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m0:00:00\x1b[0m \x1b[32m12 bytes\x1b[0m \x1b[32m10 bytes\x1b[0m \x1b[32m12/10 bytes\x1b[0m \x1b[31m1 byte/s \x1b[0m \ntest bar \x1b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m0:00:00\x1b[0m \x1b[32m16 bytes\x1b[0m \x1b[32m7 bytes \x1b[0m \x1b[32m16/7 bytes \x1b[0m \x1b[31m2 bytes/s\x1b[0m \r\x1b[1A\x1b[2Ktest foo \x1b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m0:00:00\x1b[0m \x1b[32m12 bytes\x1b[0m \x1b[32m10 bytes\x1b[0m \x1b[32m12/10 bytes\x1b[0m \x1b[31m1 byte/s \x1b[0m \ntest bar \x1b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[36m0:00:00\x1b[0m \x1b[32m16 bytes\x1b[0m \x1b[32m7 bytes \x1b[0m \x1b[32m16/7 bytes \x1b[0m \x1b[31m2 bytes/s\x1b[0m \n\x1b[?25h'
    assert result == expected
Пример #24
0
async def main(url=None, file=None, task_count=None):
    """Download ``url`` into ``file`` using ``task_count`` parallel tasks, showing progress.

    Any argument left as ``None`` is asked for interactively; an empty
    answer for the task count means 10. One 'Total' bar plus one bar per
    download task is displayed. Waits for a final key press before returning.
    """
    if url is None:
        url = input('url: ')
    if file is None:
        file = input('file: ')
    if task_count is None:
        answer = input('task count (default 10): ').strip()
        task_count = int(answer or 10)

    bullet = "•"
    columns = (
        TextColumn('[bold blue]{task.description}', justify='right'),
        BarColumn(bar_width=None),
        "[progress.percentage]{task.percentage:>3.1f}%",
        bullet,
        DownloadColumn(),
        bullet,
        TransferSpeedColumn(),
        bullet,
        TimeRemainingColumn(),
    )
    progress = Progress(*columns, refresh_per_second=1)

    with open(file, 'wb') as fp, progress:
        async with Downloader(url, fp, {}, task_count) as downloader:
            # Index 0 is the overall bar; indices 1..task_count are per-task bars.
            t_tasks = [progress.add_task('Total', total=downloader.filesize)]
            # The last task takes whatever remains after the equal-sized shares.
            bytes_last_task = (
                downloader.filesize
                - downloader.bytes_per_task * (downloader.task_count - 1)
            )
            for i in range(1, downloader.task_count + 1):
                if i != downloader.task_count:
                    share = downloader.bytes_per_task
                else:
                    share = bytes_last_task
                t_tasks.append(progress.add_task(f'{i}', total=share))

            def report_hook(task_id, task_downloaded, total_downloaded):
                progress.update(t_tasks[task_id], completed=task_downloaded)
                progress.update(t_tasks[0], completed=total_downloaded)

            await downloader.download(report_hook)
    input('Finished... ')
Пример #25
0
    def _pull_with_progress(self, log_streams, console):
        """Show one progress bar per ``id`` found in the entries of ``log_streams``.

        Entries without a ``status`` key are ignored. Entries lacking ``id``
        or ``progressDetail`` only have their status logged at debug level.
        For the rest, the bar belonging to the id is created on first sight
        and updated from the ``current``/``total`` values when both are present.
        """
        from rich.progress import BarColumn, DownloadColumn, Progress

        layout = (
            "[progress.description]{task.description}",
            BarColumn(),
            DownloadColumn(),
            "[progress.percentage]{task.percentage:>3.0f}%",
        )
        with Progress(*layout, console=console, transient=True) as progress:
            tasks = {}
            for log in log_streams:
                if 'status' not in log:
                    continue

                status = log['status']
                status_id = log.get('id')
                pg_detail = log.get('progressDetail')
                if pg_detail is None or status_id is None:
                    self.logger.debug(status)
                    continue

                # Look up the bar for this id, creating it on first sight.
                try:
                    task_id = tasks[status_id]
                except KeyError:
                    task_id = tasks[status_id] = progress.add_task(
                        status, total=0)

                if 'current' in pg_detail and 'total' in pg_detail:
                    progress.update(
                        task_id,
                        completed=pg_detail['current'],
                        total=pg_detail['total'],
                        description=status,
                    )
                elif not pg_detail:
                    # Empty detail: only the description changes.
                    progress.update(task_id, advance=0, description=status)
Пример #26
0
class RichTransferProgress(RichProgress):
    """Progress display that renders summary tasks and transfer tasks separately.

    Tasks whose ``progress_type`` field equals ``"summary"`` are drawn with
    ``SUMMARY_COLS``; every other task is drawn with ``TRANSFER_COLS``.
    """

    # Column layout for tasks marked as summaries.
    SUMMARY_COLS = (
        TextColumn("[magenta]{task.description}[bold green]"),
        MofNCompleteColumnWithUnit(),
        TimeElapsedColumn(),
    )
    # Column layout for all remaining tasks.
    TRANSFER_COLS = (
        TextColumn("  [blue]{task.description}"),
        BarColumn(),
        DownloadColumn(),
        TransferSpeedColumn(),
        TextColumn("eta"),
        TimeRemainingColumn(),
    )

    def get_renderables(self):
        """Yield the summary table first, then the table of all other tasks."""

        def is_summary(task):
            return task.fields.get("progress_type") == "summary"

        summary_tasks, other_tasks = split(is_summary, self.tasks)
        layouts = (
            (self.SUMMARY_COLS, summary_tasks),
            (self.TRANSFER_COLS, other_tasks),
        )
        for columns, tasks in layouts:
            # self.columns is switched before each table is built; presumably
            # make_tasks_table reads it to decide which columns to render.
            self.columns = columns
            yield self.make_tasks_table(tasks)
Пример #27
0
from rich.progress import (
    BarColumn,
    DownloadColumn,
    TextColumn,
    TransferSpeedColumn,
    TimeRemainingColumn,
    Progress,
    TaskID,
)

# Module-level progress display shared by the download helpers below.
# Columns, left to right: right-justified file name (read from the task's
# ``filename`` field), a bar with no fixed width, the percentage with one
# decimal, the download size column, the transfer speed and the time remaining,
# separated by bullet characters.
progress = Progress(
    TextColumn("[bold blue]{task.fields[filename]}", justify="right"),
    BarColumn(bar_width=None),
    "[progress.percentage]{task.percentage:>3.1f}%",
    "•",
    DownloadColumn(),
    "•",
    TransferSpeedColumn(),
    "•",
    TimeRemainingColumn(),
)


def copy_url(task_id: TaskID, url: str, path: str) -> None:
    """Copy data from a url to a local file."""
    response = urlopen(url)
    # This will break if the response doesn't contain content length
    progress.update(task_id, total=int(response.info()["Content-length"]))
    with open(path, "wb") as dest_file:
        progress.start_task(task_id)
        for data in iter(partial(response.read, 32768), b""):
Пример #28
0
    def run(self, cfg: Optional[DictConfig] = None):
        """Update the graphdb using some Aquarium sample query.

        The last ``cfg.task.query.n_samples`` samples (optionally restricted to
        the user named in ``cfg.task.query.user``) are requested from Aquarium
        and converted to a node payload. When ``cfg.task.create_nodes`` is set
        the payload is written through the driver's worker pool.

        :param cfg: the configuration
        :return: None
        :raises ValueError: if ``cfg.task.query`` is empty
        """
        driver, aq = self.sessions(cfg)

        logger.info("Requesting Aquarium inventory...")
        if not cfg.task.query:
            raise ValueError("Query must be provided.")
        query = {}
        if cfg.task.query.user:
            user = aq.User.where({"login": cfg.task.query.user})[0]
            query["user_id"] = user.id
        n_samples = cfg.task.query.n_samples
        models = aq.Sample.last(n_samples, query)

        with Progress(
            "[progress.description]{task.description}",
            BarColumn(),
            "[progress.percentage]{task.percentage:>3.0f}%",
            TimeRemainingColumn(),
            TransferSpeedColumn(),
            DownloadColumn(),
        ) as progress:
            n_cpus = cfg.task.n_jobs or os.cpu_count()

            task0 = progress.add_task(
                "[blue]collecting Aquarium samples[/blue]", total=100
            )

            if cfg.task.strict:

                def error_callback(e: Exception):
                    # Strict mode: propagate every write error to the caller.
                    raise e
            else:
                error_callback = self.catch_constraint_error

            if cfg.task.create_nodes:
                task1 = progress.add_task(
                    "[red]writing nodes to [bold]neo4j[/bold]...[/red] ([green]cpus: {cpus}[/green])".format(
                        cpus=n_cpus
                    )
                )
            else:
                task1 = None

            with infinite_task_context(progress, task0) as callback:
                node_payload = aq_samples_to_cypher(
                    aq,
                    models,
                    new_node_callback=callback
                )

            if cfg.task.create_nodes:
                n_nodes = len(node_payload)
                # Set the total through the Progress API rather than indexing
                # progress.tasks by task id and mutating the Task directly:
                # a task id is not guaranteed to be a list position, and
                # update() also refreshes the display.
                progress.update(task1, total=n_nodes)
                driver.pool(n_cpus).write(
                    node_payload,
                    callback=lambda x: progress.update(task1, advance=len(x)),
                    chunksize=cfg.task.chunksize,
                    error_callback=error_callback,
                )
                # Force the bar to 100% even if some callbacks were skipped.
                progress.update(task1, completed=n_nodes)
Пример #29
0
class YataiClient:
    """Client that pushes and pulls bentos and models to and from Yatai.

    Progress is reported through three class-level progress displays that are
    combined into ``progress_group`` and shown inside a ``Live`` context by the
    public push/pull methods.
    """

    # One text row per logged message (each message is added as a task).
    log_progress = ProgressWrapper(Progress(
        TextColumn("{task.description}"), ))

    # Rows for in-flight steps: indent, elapsed time, the task's ``action``
    # field and a spinner. Driven by ``spin``.
    spinner_progress = ProgressWrapper(
        Progress(
            TextColumn("  "),
            TimeElapsedColumn(),
            TextColumn("[bold purple]{task.fields[action]}"),
            SpinnerColumn("simpleDots"),
        ))

    # Rows for uploads/downloads: description, bar, percentage, transferred
    # size, transfer speed and time remaining.
    transmission_progress = ProgressWrapper(
        Progress(
            TextColumn("[bold blue]{task.description}", justify="right"),
            BarColumn(bar_width=None),
            "[progress.percentage]{task.percentage:>3.1f}%",
            "•",
            DownloadColumn(),
            "•",
            TransferSpeedColumn(),
            "•",
            TimeRemainingColumn(),
        ))

    # Renderable handed to ``Live``: log and spinner rows inside a panel,
    # followed by the transfer rows.
    progress_group = Group(Panel(Group(log_progress, spinner_progress)),
                           transmission_progress)

    @contextmanager
    def spin(self, *, text: str):
        """Show a spinner row labelled ``text`` while the ``with`` body runs.

        A task carrying ``text`` in its ``action`` field is added to
        ``spinner_progress``. When the body finishes, whether normally or by
        raising, the task is stopped and hidden.
        """
        spinner = self.spinner_progress
        spinner_task = spinner.add_task("", action=text)
        try:
            yield
        finally:
            spinner.stop_task(spinner_task)
            spinner.update(spinner_task, visible=False)

    @inject
    def push_bento(
        self,
        bento: "Bento",
        *,
        force: bool = False,
        model_store: "ModelStore" = Provide[BentoMLContainer.model_store],
    ):
        """Push ``bento`` to Yatai while showing the live progress display.

        A hidden, not-yet-started transfer task is registered for the upload
        and the actual work is delegated to ``_do_push_bento``.

        :param bento: the bento to push
        :param force: forwarded to ``_do_push_bento``
        :param model_store: store forwarded to ``_do_push_bento``
        """
        with Live(self.progress_group):
            task_id = self.transmission_progress.add_task(
                f'Pushing Bento "{bento.tag}"',
                start=False,
                visible=False,
            )
            self._do_push_bento(
                bento,
                task_id,
                force=force,
                model_store=model_store,
            )

    @inject
    def _do_push_bento(
        self,
        bento: "Bento",
        upload_task_id: TaskID,
        *,
        force: bool = False,
        model_store: "ModelStore" = Provide[BentoMLContainer.model_store],
    ):
        """Push ``bento`` and the models it lists to Yatai.

        Steps, in order: push every model in ``bento.info.models``; fetch or
        create the bento repository; register the bento version if Yatai does
        not know it; build a gzipped tar of the bento directory (without the
        ``./models`` subtree) in memory; upload it to a presigned URL; report
        the upload status back to Yatai.

        Returns early, after logging a message, when the bento is already
        uploaded successfully on Yatai and ``force`` is false.

        :param bento: the bento to push
        :param upload_task_id: task in ``transmission_progress`` used to show
            the upload
        :param force: push even if Yatai already holds a successful upload
        :param model_store: store the bento's models are looked up in
        :raises BentoMLException: if the bento has no version or no path
        """
        yatai_rest_client = get_current_yatai_rest_api_client()
        name = bento.tag.name
        version = bento.tag.version
        if version is None:
            raise BentoMLException(f"Bento {bento.tag} version cannot be None")
        info = bento.info
        model_names = info.models
        # Push all models first, one worker thread per model (at least one
        # worker so the executor is valid when there are no models).
        with ThreadPoolExecutor(
                max_workers=max(len(model_names), 1)) as executor:

            def push_model(model: "Model"):
                model_upload_task_id = self.transmission_progress.add_task(
                    f'Pushing model "{model.tag}"', start=False, visible=False)
                self._do_push_model(model, model_upload_task_id, force=force)

            futures = executor.map(push_model, (model_store.get(name)
                                                for name in model_names))
            # Consume the iterator so exceptions raised in workers surface here.
            list(futures)
        with self.spin(text=f'Fetching Bento repository "{name}"'):
            bento_repository = yatai_rest_client.get_bento_repository(
                bento_repository_name=name)
        if not bento_repository:
            with self.spin(
                    text=f'Bento repository "{name}" not found, creating now..'
            ):
                bento_repository = yatai_rest_client.create_bento_repository(
                    req=CreateBentoRepositorySchema(name=name, description=""))
        with self.spin(text=f'Try fetching Bento "{bento.tag}" from Yatai..'):
            remote_bento = yatai_rest_client.get_bento(
                bento_repository_name=name, version=version)
        # Nothing to do when the bento is already fully uploaded, unless forced.
        if (not force and remote_bento
                and remote_bento.upload_status == BentoUploadStatus.SUCCESS):
            self.log_progress.add_task(
                f'[bold blue]Push failed: Bento "{bento.tag}" already exists in Yatai'
            )
            return
        # Register the bento version on Yatai if it is not known there yet.
        if not remote_bento:
            labels: t.List[LabelItemSchema] = [
                LabelItemSchema(key=key, value=value)
                for key, value in info.labels.items()
            ]
            apis: t.Dict[str, BentoApiSchema] = {}
            models = [str(m) for m in info.models]
            with self.spin(
                    text=f'Registering Bento "{bento.tag}" with Yatai..'):
                yatai_rest_client.create_bento(
                    bento_repository_name=bento_repository.name,
                    req=CreateBentoSchema(
                        description="",
                        version=version,
                        build_at=info.creation_time,
                        manifest=BentoManifestSchema(
                            service=info.service,
                            bentoml_version=info.bentoml_version,
                            apis=apis,
                            models=models,
                            size_bytes=calc_dir_size(bento.path),
                        ),
                        labels=labels,
                    ),
                )
        with self.spin(
                text=f'Getting a presigned upload url for "{bento.tag}" ..'):
            remote_bento = yatai_rest_client.presign_bento_upload_url(
                bento_repository_name=bento_repository.name, version=version)
        # The whole archive is built in memory before it is uploaded.
        with io.BytesIO() as tar_io:
            bento_dir_path = bento.path
            if bento_dir_path is None:
                raise BentoMLException(f'Bento "{bento}" path cannot be None')
            with self.spin(
                    text=f'Creating tar archive for Bento "{bento.tag}"..'):
                with tarfile.open(fileobj=tar_io, mode="w:gz") as tar:

                    def filter_(
                        tar_info: tarfile.TarInfo,
                    ) -> t.Optional[tarfile.TarInfo]:
                        # Leave the models directory out of the archive; the
                        # models were pushed separately above.
                        if tar_info.path == "./models" or tar_info.path.startswith(
                                "./models/"):
                            return None
                        return tar_info

                    tar.add(bento_dir_path, arcname="./", filter=filter_)
            # Rewind so the upload reads the archive from its first byte.
            tar_io.seek(0, 0)
            with self.spin(text=f'Start uploading Bento "{bento.tag}"..'):
                yatai_rest_client.start_upload_bento(
                    bento_repository_name=bento_repository.name,
                    version=version)

            file_size = tar_io.getbuffer().nbytes

            self.transmission_progress.update(upload_task_id,
                                              completed=0,
                                              total=file_size,
                                              visible=True)
            self.transmission_progress.start_task(upload_task_id)

            def io_cb(x: int):
                self.transmission_progress.update(upload_task_id, advance=x)

            # NOTE(review): CallbackIOWrapper presumably calls io_cb with the
            # size of each chunk read from tar_io -- confirm against its source.
            wrapped_file = CallbackIOWrapper(
                io_cb,
                tar_io,
                "read",
            )
            # Assume success; replaced below on a non-200 response or an error.
            finish_req = FinishUploadBentoSchema(
                status=BentoUploadStatus.SUCCESS,
                reason="",
            )
            try:
                resp = requests.put(remote_bento.presigned_upload_url,
                                    data=wrapped_file)
                if resp.status_code != 200:
                    finish_req = FinishUploadBentoSchema(
                        status=BentoUploadStatus.FAILED,
                        reason=resp.text,
                    )
            except Exception as e:  # pylint: disable=broad-except
                # Any upload error is recorded and reported to Yatai instead
                # of being raised.
                finish_req = FinishUploadBentoSchema(
                    status=BentoUploadStatus.FAILED,
                    reason=str(e),
                )
            if finish_req.status is BentoUploadStatus.FAILED:
                self.log_progress.add_task(
                    f'[bold red]Failed to upload Bento "{bento.tag}"')
            # The final status is sent to Yatai for both success and failure.
            with self.spin(text="Submitting upload status to Yatai"):
                yatai_rest_client.finish_upload_bento(
                    bento_repository_name=bento_repository.name,
                    version=version,
                    req=finish_req,
                )
            if finish_req.status != BentoUploadStatus.SUCCESS:
                self.log_progress.add_task(
                    f'[bold red]Failed pushing Bento "{bento.tag}": {finish_req.reason}'
                )
            else:
                self.log_progress.add_task(
                    f'[bold green]Successfully pushed Bento "{bento.tag}"')

    @inject
    def pull_bento(
        self,
        tag: t.Union[str, Tag],
        *,
        force: bool = False,
        bento_store: "BentoStore" = Provide[BentoMLContainer.bento_store],
        model_store: "ModelStore" = Provide[BentoMLContainer.model_store],
    ) -> "Bento":
        """Pull the bento identified by ``tag`` while showing live progress.

        A hidden, not-yet-started transfer task is registered for the download
        and the actual work is delegated to ``_do_pull_bento``.

        :param tag: tag (or tag-like string) of the bento to pull
        :param force: forwarded to ``_do_pull_bento``
        :param bento_store: store forwarded to ``_do_pull_bento``
        :param model_store: store forwarded to ``_do_pull_bento``
        :return: whatever ``_do_pull_bento`` returns
        """
        with Live(self.progress_group):
            task_id = self.transmission_progress.add_task(
                f'Pulling bento "{tag}"',
                start=False,
                visible=False,
            )
            pulled = self._do_pull_bento(
                tag,
                task_id,
                force=force,
                bento_store=bento_store,
                model_store=model_store,
            )
            return pulled

    @inject
    def _do_pull_bento(
        self,
        tag: t.Union[str, Tag],
        download_task_id: TaskID,
        *,
        force: bool = False,
        bento_store: "BentoStore" = Provide[BentoMLContainer.bento_store],
        model_store: "ModelStore" = Provide[BentoMLContainer.model_store],
    ) -> "Bento":
        """Download the bento identified by ``tag`` from Yatai and save it locally.

        If the bento is already in ``bento_store`` it is returned as is, unless
        ``force`` is set, in which case the local copy is deleted and the bento
        is downloaded again. The models listed in the remote manifest are
        pulled first, then the bento archive is downloaded to a temporary
        file, unpacked into a temporary filesystem, combined with its models
        and saved to ``bento_store``.

        :param tag: tag (or tag-like string) of the bento to pull
        :param download_task_id: task in ``transmission_progress`` used to show
            the download
        :param force: replace an existing local copy
        :param bento_store: store the bento is read from and saved to
        :param model_store: store the bento's models are pulled into
        :return: the local or freshly saved bento
        :raises BentoMLException: if the tag has no version, the bento is not
            found on Yatai, or the download does not answer with status 200
        """
        try:
            bento = bento_store.get(tag)
            if not force:
                self.log_progress.add_task(
                    f'[bold blue]Bento "{tag}" exists in local model store')
                return bento
            bento_store.delete(tag)
        except NotFound:
            pass
        _tag = Tag.from_taglike(tag)
        name = _tag.name
        version = _tag.version
        if version is None:
            raise BentoMLException(f'Bento "{_tag}" version can not be None')
        yatai_rest_client = get_current_yatai_rest_api_client()
        with self.spin(text=f'Fetching bento "{_tag}"'):
            remote_bento = yatai_rest_client.get_bento(
                bento_repository_name=name, version=version)
        if not remote_bento:
            raise BentoMLException(f'Bento "{_tag}" not found on Yatai')
        # Pull all models of the bento first, one worker thread per model (at
        # least one worker so the executor is valid when there are no models).
        with ThreadPoolExecutor(max_workers=max(
                len(remote_bento.manifest.models), 1)) as executor:

            def pull_model(model_tag: Tag):
                model_download_task_id = self.transmission_progress.add_task(
                    f'Pulling model "{model_tag}"', start=False, visible=False)
                self._do_pull_model(
                    model_tag,
                    model_download_task_id,
                    force=force,
                    model_store=model_store,
                )

            futures = executor.map(pull_model, remote_bento.manifest.models)
            # Consume the iterator so exceptions raised in workers surface here.
            list(futures)
        with self.spin(
                text=f'Getting a presigned download url for bento "{_tag}"'):
            remote_bento = yatai_rest_client.presign_bento_download_url(
                name, version)
        url = remote_bento.presigned_download_url
        response = requests.get(url, stream=True)
        if response.status_code != 200:
            raise BentoMLException(
                f'Failed to download bento "{_tag}": {response.text}')
        # A missing content-length header yields a total of 0.
        total_size_in_bytes = int(response.headers.get("content-length", 0))
        block_size = 1024  # 1 Kibibyte
        with NamedTemporaryFile() as tar_file:
            self.transmission_progress.update(download_task_id,
                                              completed=0,
                                              total=total_size_in_bytes,
                                              visible=True)
            self.transmission_progress.start_task(download_task_id)
            for data in response.iter_content(block_size):
                self.transmission_progress.update(download_task_id,
                                                  advance=len(data))
                tar_file.write(data)
            self.log_progress.add_task(
                f'[bold green]Finished downloading all bento "{_tag}" files')
            # Rewind so the archive is read from its first byte.
            tar_file.seek(0, 0)
            # Open the archive in a ``with`` so the TarFile is closed on every
            # exit path, including the return and any exception below.
            with tarfile.open(fileobj=tar_file, mode="r:gz") as tar:
                with fs.open_fs("temp://") as temp_fs:
                    for member in tar.getmembers():
                        f = tar.extractfile(member)
                        # extractfile returns None for members without file
                        # content (e.g. directories); nothing to write then.
                        if f is None:
                            continue
                        p = Path(member.name)
                        if p.parent != Path("."):
                            temp_fs.makedirs(str(p.parent), recreate=True)
                        temp_fs.writebytes(member.name, f.read())
                    bento = Bento.from_fs(temp_fs)
                    for model_tag in remote_bento.manifest.models:
                        with self.spin(
                                text=f'Copying model "{model_tag}" to bento'):
                            copy_model(
                                model_tag,
                                src_model_store=model_store,
                                target_model_store=bento.
                                _model_store,  # type: ignore
                            )
                    bento = bento.save(bento_store)
                    self.log_progress.add_task(
                        f'[bold green]Successfully pulled bento "{_tag}"')
                    return bento

    def push_model(self, model: "Model", *, force: bool = False):
        """Push ``model`` to Yatai while showing the live progress display.

        A hidden, not-yet-started transfer task is registered for the upload
        and the actual work is delegated to ``_do_push_model``.

        :param model: the model to push
        :param force: forwarded to ``_do_push_model``
        """
        with Live(self.progress_group):
            task_id = self.transmission_progress.add_task(
                f'Pushing model "{model.tag}"',
                start=False,
                visible=False,
            )
            self._do_push_model(model, task_id, force=force)

    def _do_push_model(self,
                       model: "Model",
                       upload_task_id: TaskID,
                       *,
                       force: bool = False):
        """Push ``model`` to Yatai.

        Steps, in order: fetch or create the model repository; register the
        model version if Yatai does not know it; build a gzipped tar of the
        model directory in memory; upload it to a presigned URL; report the
        upload status back to Yatai.

        Returns early, after logging a message, when the model is already
        uploaded successfully on Yatai and ``force`` is false.

        :param model: the model to push
        :param upload_task_id: task in ``transmission_progress`` used to show
            the upload
        :param force: push even if Yatai already holds a successful upload
        :raises BentoMLException: if the model has no version
        """
        yatai_rest_client = get_current_yatai_rest_api_client()
        name = model.tag.name
        version = model.tag.version
        if version is None:
            raise BentoMLException(
                f'Model "{model.tag}" version cannot be None')
        info = model.info
        with self.spin(text=f'Fetching model repository "{name}"'):
            model_repository = yatai_rest_client.get_model_repository(
                model_repository_name=name)
        if not model_repository:
            with self.spin(
                    text=f'Model repository "{name}" not found, creating now..'
            ):
                model_repository = yatai_rest_client.create_model_repository(
                    req=CreateModelRepositorySchema(name=name, description=""))
        with self.spin(text=f'Try fetching model "{model.tag}" from Yatai..'):
            remote_model = yatai_rest_client.get_model(
                model_repository_name=name, version=version)
        # Nothing to do when the model is already fully uploaded, unless forced.
        if (not force and remote_model
                and remote_model.upload_status == ModelUploadStatus.SUCCESS):
            self.log_progress.add_task(
                f'[bold blue]Model "{model.tag}" already exists in Yatai, skipping'
            )
            return
        # Register the model version on Yatai if it is not known there yet.
        if not remote_model:
            labels: t.List[LabelItemSchema] = [
                LabelItemSchema(key=key, value=value)
                for key, value in info.labels.items()
            ]
            with self.spin(
                    text=f'Registering model "{model.tag}" with Yatai..'):
                yatai_rest_client.create_model(
                    model_repository_name=model_repository.name,
                    req=CreateModelSchema(
                        description="",
                        version=version,
                        build_at=info.creation_time,
                        manifest=ModelManifestSchema(
                            module=info.module,
                            metadata=info.metadata,
                            context=info.context,
                            options=info.options,
                            api_version=info.api_version,
                            bentoml_version=info.bentoml_version,
                            size_bytes=calc_dir_size(model.path),
                        ),
                        labels=labels,
                    ),
                )
        with self.spin(
                text=f'Getting a presigned upload url for model "{model.tag}"..'
        ):
            remote_model = yatai_rest_client.presign_model_upload_url(
                model_repository_name=model_repository.name, version=version)
        # The whole archive is built in memory before it is uploaded.
        with io.BytesIO() as tar_io:
            # Despite its name, this is the model's directory path.
            bento_dir_path = model.path
            with self.spin(
                    text=f'Creating tar archive for model "{model.tag}"..'):
                with tarfile.open(fileobj=tar_io, mode="w:gz") as tar:
                    tar.add(bento_dir_path, arcname="./")
            # Rewind so the upload reads the archive from its first byte.
            tar_io.seek(0, 0)
            with self.spin(text=f'Start uploading model "{model.tag}"..'):
                yatai_rest_client.start_upload_model(
                    model_repository_name=model_repository.name,
                    version=version)
            file_size = tar_io.getbuffer().nbytes
            self.transmission_progress.update(
                upload_task_id,
                description=f'Uploading model "{model.tag}"',
                total=file_size,
                visible=True,
            )
            self.transmission_progress.start_task(upload_task_id)

            def io_cb(x: int):
                self.transmission_progress.update(upload_task_id, advance=x)

            # NOTE(review): CallbackIOWrapper presumably calls io_cb with the
            # size of each chunk read from tar_io -- confirm against its source.
            wrapped_file = CallbackIOWrapper(
                io_cb,
                tar_io,
                "read",
            )
            # Assume success; replaced below on a non-200 response or an error.
            finish_req = FinishUploadModelSchema(
                status=ModelUploadStatus.SUCCESS,
                reason="",
            )
            try:
                resp = requests.put(remote_model.presigned_upload_url,
                                    data=wrapped_file)
                if resp.status_code != 200:
                    finish_req = FinishUploadModelSchema(
                        status=ModelUploadStatus.FAILED,
                        reason=resp.text,
                    )
            except Exception as e:  # pylint: disable=broad-except
                # Any upload error is recorded and reported to Yatai instead
                # of being raised.
                finish_req = FinishUploadModelSchema(
                    status=ModelUploadStatus.FAILED,
                    reason=str(e),
                )
            if finish_req.status is ModelUploadStatus.FAILED:
                self.log_progress.add_task(
                    f'[bold red]Failed to upload model "{model.tag}"')
            # The final status is sent to Yatai for both success and failure.
            with self.spin(text="Submitting upload status to Yatai"):
                yatai_rest_client.finish_upload_model(
                    model_repository_name=model_repository.name,
                    version=version,
                    req=finish_req,
                )
            if finish_req.status != ModelUploadStatus.SUCCESS:
                self.log_progress.add_task(
                    f'[bold red]Failed pushing model "{model.tag}" : {finish_req.reason}'
                )
            else:
                self.log_progress.add_task(
                    f'[bold green]Successfully pushed model "{model.tag}"')

    @inject
    def pull_model(
        self,
        tag: t.Union[str, Tag],
        *,
        force: bool = False,
        model_store: "ModelStore" = Provide[BentoMLContainer.model_store],
    ) -> "Model":
        """Pull the model identified by ``tag`` while showing live progress.

        A hidden, not-yet-started transfer task is registered for the download
        and the actual work is delegated to ``_do_pull_model``.

        :param tag: tag (or tag-like string) of the model to pull
        :param force: forwarded to ``_do_pull_model``
        :param model_store: store forwarded to ``_do_pull_model``
        :return: whatever ``_do_pull_model`` returns
        """
        with Live(self.progress_group):
            task_id = self.transmission_progress.add_task(
                f'Pulling model "{tag}"',
                start=False,
                visible=False,
            )
            pulled = self._do_pull_model(
                tag,
                task_id,
                force=force,
                model_store=model_store,
            )
            return pulled

    @inject
    def _do_pull_model(
        self,
        tag: t.Union[str, Tag],
        download_task_id: TaskID,
        *,
        force: bool = False,
        model_store: "ModelStore" = Provide[BentoMLContainer.model_store],
    ) -> "Model":
        """Download the model identified by ``tag`` from Yatai and save it locally.

        If the model is already in ``model_store`` it is returned as is, unless
        ``force`` is set, in which case the local copy is deleted and the model
        is downloaded again. The archive is downloaded to a temporary file,
        unpacked into a temporary filesystem and saved to ``model_store``.

        :param tag: tag (or tag-like string) of the model to pull
        :param download_task_id: task in ``transmission_progress`` used to show
            the download
        :param force: replace an existing local copy
        :param model_store: store the model is read from and saved to
        :return: the local or freshly saved model
        :raises BentoMLException: if the tag has no version, the model is not
            found on Yatai, or the download does not answer with status 200
        """
        try:
            model = model_store.get(tag)
            if not force:
                self.log_progress.add_task(
                    f'[bold blue]Model "{tag}" already exists locally, skipping'
                )
                return model
            model_store.delete(tag)
        except NotFound:
            pass
        yatai_rest_client = get_current_yatai_rest_api_client()
        _tag = Tag.from_taglike(tag)
        name = _tag.name
        version = _tag.version
        if version is None:
            raise BentoMLException(f'Model "{_tag}" version cannot be None')
        with self.spin(
                text=f'Getting a presigned download url for model "{_tag}"..'):
            remote_model = yatai_rest_client.presign_model_download_url(
                name, version)
        if not remote_model:
            raise BentoMLException(f'Model "{_tag}" not found on Yatai')
        url = remote_model.presigned_download_url
        response = requests.get(url, stream=True)
        if response.status_code != 200:
            raise BentoMLException(
                f'Failed to download model "{_tag}": {response.text}')
        # A missing content-length header yields a total of 0.
        total_size_in_bytes = int(response.headers.get("content-length", 0))
        block_size = 1024  # 1 Kibibyte
        with NamedTemporaryFile() as tar_file:
            self.transmission_progress.update(
                download_task_id,
                description=f'Downloading model "{_tag}"',
                total=total_size_in_bytes,
                visible=True,
            )
            self.transmission_progress.start_task(download_task_id)
            for data in response.iter_content(block_size):
                self.transmission_progress.update(download_task_id,
                                                  advance=len(data))
                tar_file.write(data)
            self.log_progress.add_task(
                f'[bold green]Finished downloading model "{_tag}" files')
            # Rewind so the archive is read from its first byte.
            tar_file.seek(0, 0)
            # Open the archive in a ``with`` so the TarFile is closed on every
            # exit path, including the return and any exception below.
            with tarfile.open(fileobj=tar_file, mode="r:gz") as tar:
                with fs.open_fs("temp://") as temp_fs:
                    for member in tar.getmembers():
                        f = tar.extractfile(member)
                        # extractfile returns None for members without file
                        # content (e.g. directories); nothing to write then.
                        if f is None:
                            continue
                        p = Path(member.name)
                        if p.parent != Path("."):
                            temp_fs.makedirs(str(p.parent), recreate=True)
                        temp_fs.writebytes(member.name, f.read())
                    model = Model.from_fs(temp_fs).save(model_store)
                    self.log_progress.add_task(
                        f'[bold green]Successfully pulled model "{_tag}"')
                    return model
Пример #30
0
def upload_data(
    dataset_identifier: str,
    files: Optional[List[Union[PathLike, LocalFile]]],
    files_to_exclude: Optional[List[PathLike]],
    fps: int,
    path: Optional[str],
    frames: bool,
    preserve_folders: bool = False,
    verbose: bool = False,
) -> None:
    """
    Uploads the provided files to the remote dataset and prints a summary of the outcome.
    Exits the application if no dataset with the given name is found, the files in the given path
    have unsupported formats, or if there are no files found in the given Path.

    Parameters
    ----------
    dataset_identifier : str
        Slug of the dataset to retrieve.
    files : List[Union[PathLike, LocalFile]]
        List of files to upload. Can be None.
    files_to_exclude : List[PathLike]
        List of files to exclude from the file scan (which is done only if files is None).
    fps : int
        Frame rate to split videos in.
    path : Optional[str]
        If provided, files will be placed under this path in the v7 platform. If `preserve_folders`
        is `True` then it must be possible to draw a relative path from this folder to the one the
        files are in, otherwise an error will be raised.
    frames : bool
        Specify whether the files will be uploaded as a list of frames or not.
    preserve_folders : bool
        Specify whether or not to preserve folder paths when uploading.
    verbose : bool
        Specify whether to have full traces print when uploading files or not.
    """
    client: Client = _load_client()
    try:
        # Only the executor's default worker count is needed, so the executor
        # is closed right away instead of being left dangling.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            max_workers: int = executor._max_workers  # type: ignore

        dataset: RemoteDataset = client.get_remote_dataset(
            dataset_identifier=dataset_identifier)

        sync_metadata: Progress = Progress(
            SpinnerColumn(), TextColumn("[bold blue]Syncing metadata"))

        overall_progress = Progress(
            TextColumn("[bold blue]{task.fields[filename]}"), BarColumn(),
            "{task.completed} of {task.total}")

        file_progress = Progress(
            TextColumn("[bold green]{task.fields[filename]}", justify="right"),
            BarColumn(),
            "[progress.percentage]{task.percentage:>3.1f}%",
            DownloadColumn(),
            "•",
            TransferSpeedColumn(),
            "•",
            TimeRemainingColumn(),
        )

        # The three progress displays are stacked in one grid so that a single
        # Live context renders all of them.
        progress_table: Table = Table.grid()
        progress_table.add_row(sync_metadata)
        progress_table.add_row(file_progress)
        progress_table.add_row(overall_progress)
        with Live(progress_table):
            sync_task: TaskID = sync_metadata.add_task("")
            file_tasks: Dict[str, TaskID] = {}
            # Hidden until the first progress callback reports a file count.
            overall_task = overall_progress.add_task("[green]Total progress",
                                                     filename="Total progress",
                                                     total=0,
                                                     visible=False)

            def progress_callback(total_file_count, file_advancement) -> None:
                # The first report replaces the "Syncing metadata" spinner
                # with the overall progress bar.
                sync_metadata.update(sync_task, visible=False)
                overall_progress.update(overall_task,
                                        total=total_file_count,
                                        advance=file_advancement,
                                        visible=True)

            def file_upload_callback(file_name, file_total_bytes,
                                     file_bytes_sent) -> None:
                # Lazily create one progress task per file on its first report.
                if file_name not in file_tasks:
                    file_tasks[file_name] = file_progress.add_task(
                        f"[blue]{file_name}",
                        filename=file_name,
                        total=file_total_bytes)

                # Rich has a concurrency issue, so sometimes updating progress
                # or removing a task fails. Wrapping this logic around a try/catch block
                # is a workaround, we should consider solving this properly (e.g.: using locks)
                try:
                    file_progress.update(file_tasks[file_name],
                                         completed=file_bytes_sent)

                    # Drop finished bars once at least `max_workers` of them
                    # are displayed; the length is re-read on every iteration
                    # because removals shrink the task list.
                    for task in file_progress.tasks:
                        if task.finished and len(
                                file_progress.tasks) >= max_workers:
                            file_progress.remove_task(task.id)
                except Exception:
                    # Deliberately best-effort: a failed progress refresh must
                    # not abort the upload itself.
                    pass

            upload_manager = dataset.push(
                files,
                files_to_exclude=files_to_exclude,
                fps=fps,
                as_frames=frames,
                path=path,
                preserve_folders=preserve_folders,
                progress_callback=progress_callback,
                file_upload_callback=file_upload_callback,
            )
        console = Console(theme=_console_theme())

        console.print()

        if not upload_manager.blocked_count and not upload_manager.error_count:
            console.print(
                f"All {upload_manager.total_count} files have been successfully uploaded.\n",
                style="success")
            return

        # Blocked items that already exist are reported as a warning; any
        # other blocked item is counted together with the upload errors.
        already_existing_items = []
        other_skipped_items = []
        for item in upload_manager.blocked_items:
            if item.reason == "ALREADY_EXISTS":
                already_existing_items.append(item)
            else:
                other_skipped_items.append(item)

        if already_existing_items:
            console.print(
                f"Skipped {len(already_existing_items)} files already in the dataset.\n",
                style="warning",
            )

        if upload_manager.error_count or other_skipped_items:
            error_count = upload_manager.error_count + len(other_skipped_items)
            console.print(
                f"{error_count} files couldn't be uploaded because an error occurred.\n",
                style="error",
            )

        if not verbose and upload_manager.error_count:
            console.print('Re-run with "--verbose" for further details')
            return

        error_table: Table = Table("Dataset Item ID",
                                   "Filename",
                                   "Remote Path",
                                   "Stage",
                                   "Reason",
                                   show_header=True,
                                   header_style="bold cyan")

        for item in upload_manager.blocked_items:
            if item.reason != "ALREADY_EXISTS":
                error_table.add_row(str(item.dataset_item_id), item.filename,
                                    item.path, "UPLOAD_REQUEST", item.reason)

        # Each error is matched to its local file by path, and that file to
        # its pending item by filename, to recover the remote identifiers
        # shown in the table.
        for error in upload_manager.errors:
            for local_file in upload_manager.local_files:
                if local_file.local_path != error.file_path:
                    continue

                for pending_item in upload_manager.pending_items:
                    if pending_item.filename != local_file.data["filename"]:
                        continue

                    error_table.add_row(
                        str(pending_item.dataset_item_id),
                        pending_item.filename,
                        pending_item.path,
                        error.stage.name,
                        str(error.error),
                    )
                    # Only the first pending item with this filename is listed.
                    break

        if error_table.row_count:
            console.print(error_table)
        print_new_version_info(client)
    except NotFound as e:
        _error(f"No dataset with name '{e.name}'")
    except UnsupportedFileType as e:
        _error(f"Unsupported file type {e.path.suffix} ({e.path.name})")
    except ValueError:
        _error("No files found")