示例#1
0
def test_clear(ansi_io):
    """Check what a bar writes when it is started, moved to 25/50, then cleared.

    The expected error-stream content is three frames, each preceded by a
    carriage return: the initial 0/50 frame, the 25/50 frame, and a frame
    made only of spaces produced by ``clear()``.
    """
    progress = ProgressBar(ansi_io, 50, 0)
    progress.start()
    progress.set_progress(25)
    progress.clear()
    actual = ansi_io.fetch_error()

    frames = (
        "  0/50 [>---------------------------]   0%",
        " 25/50 [==============>-------------]  50%",
        "                                          ",
    )

    # Each frame is written after a carriage return, including the first one.
    expected = "".join("\x0D" + frame for frame in frames)

    assert expected == actual
示例#2
0
def test_multiline_format(ansi_io):
    """Check the frames written by a bar whose format spans two lines.

    The bar uses the format ``"%bar%\\nfoobar"`` and is started, advanced
    once, cleared and finished. Every expected frame begins with the
    ``\\033[1A`` escape sequence and is preceded by a carriage return.
    """
    progress = ProgressBar(ansi_io, 3, 0)
    progress.set_format("%bar%\nfoobar")

    # One frame is expected per call below, in this order.
    progress.start()
    progress.advance()
    progress.clear()
    progress.finish()
    actual = ansi_io.fetch_error()

    frames = (
        "\033[1A>---------------------------\nfoobar",
        "\033[1A=========>------------------\nfoobar                      ",
        "\033[1A                            \n                            ",
        "\033[1A============================\nfoobar                      ",
    )

    # Each frame is written after a carriage return, including the first one.
    expected = "".join("\x0D" + frame for frame in frames)

    assert expected == actual
示例#3
0
        async def go():
            """Submit the sampling requests, report progress, and build the fit.

            This coroutine is a closure: ``self``, ``payloads``, ``num_chains``,
            ``num_warmup``, ``num_samples``, ``num_thin`` and ``save_warmup``
            are read from the enclosing scope, which is not visible here.

            Steps, in order:

            1. POST every payload to ``/{self.model_name}/fits`` and keep the
               returned operation documents.
            2. Poll each unfinished operation, parse its ``progress`` metadata
               and drive a console progress bar with the summed iteration count.
            3. Download the output of every finished fit (deleting the fit
               afterwards when ``self.random_seed`` is ``None``).
            4. Extract logger messages that are neither empty nor routine
               iteration / elapsed-time lines and print them.
            5. Redraw and finish the progress bar, construct ``stan.fit.Fit``
               and pass it through every plugin's ``on_post_fit`` hook.

            Returns:
                The fit object returned by the last plugin's ``on_post_fit``,
                or the freshly built ``stan.fit.Fit`` when there are no plugins.

            Raises:
                ValueError: a fit request was answered with status 422.
                RuntimeError: a fit request, a fit download or a fit deletion
                    returned an unexpected status, or an operation result
                    carries no ``name`` (i.e. it is an error document).
                AssertionError: one of the inline ``assert`` checks failed.
            """
            io = ConsoleIO()
            io.error_line("<info>Sampling...</info>")
            progress_bar = ProgressBar(io)
            progress_bar.set_format("very_verbose")

            # Captures the current and the maximum iteration from a progress
            # message of the form "Iteration: <current> / <max>".
            current_and_max_iterations_re = re.compile(
                r"Iteration:\s+(\d+)\s+/\s+(\d+)")
            async with stan.common.HttpstanClient() as client:
                # Start one fit operation per payload.
                operations = []
                for payload in payloads:
                    resp = await client.post(f"/{self.model_name}/fits",
                                             json=payload)
                    if resp.status == 422:
                        raise ValueError(str(resp.json()))
                    elif resp.status != 201:
                        raise RuntimeError(resp.json()["message"])
                    # Always true here given the two branches above; kept as a
                    # sanity check.
                    assert resp.status == 201
                    operations.append(resp.json())

                # poll to get progress for each chain until all chains finished
                # Maps operation name -> most recently reported iteration.
                current_iterations = {}
                while not all(operation["done"] for operation in operations):
                    for operation in operations:
                        if operation["done"]:
                            continue
                        resp = await client.get(f"/{operation['name']}")
                        # NOTE(review): this check disappears under `python -O`;
                        # confirm whether a 404 can occur here in practice.
                        assert resp.status != 404
                        # Merge the refreshed document (including "done") into
                        # the stored operation.
                        operation.update(resp.json())
                        progress_message = operation["metadata"].get(
                            "progress")
                        if not progress_message:
                            continue
                        # NOTE(review): `findall` returns an empty list when the
                        # message does not match the pattern, in which case
                        # `.pop(0)` raises IndexError.
                        iteration, iteration_max = map(
                            int,
                            current_and_max_iterations_re.findall(
                                progress_message).pop(0))
                        if not progress_bar.get_max_steps(
                        ):  # i.e., has not started
                            # The bar's maximum is the per-operation maximum
                            # multiplied by the number of chains.
                            progress_bar.start(max=iteration_max * num_chains)
                        current_iterations[operation["name"]] = iteration
                        progress_bar.set_progress(
                            sum(current_iterations.values()))
                    # Short pause between polling rounds.
                    await asyncio.sleep(0.01)
                # Sampling has finished. But we do not call `progress_bar.finish()` right
                # now. First we write informational messages to the screen, then we
                # redraw the (complete) progress bar. Only after that do we call `finish`.

                # Collect the raw output of every finished fit.
                stan_outputs = []
                for operation in operations:
                    fit_name = operation["result"].get("name")
                    if fit_name is None:  # operation["result"] is an error
                        # An error document is expected to carry a code that
                        # does not start with "2".
                        assert not str(operation["result"]["code"]).startswith(
                            "2"), operation
                        raise RuntimeError(operation["result"]["message"])
                    resp = await client.get(f"/{fit_name}")
                    if resp.status != 200:
                        raise RuntimeError((resp.json())["message"])
                    stan_outputs.append(resp.content)

                    # clean up after ourselves when fit is uncacheable (no random seed)
                    if self.random_seed is None:
                        resp = await client.delete(f"/{fit_name}")
                        if resp.status not in {200, 202, 204}:
                            raise RuntimeError((resp.json())["message"])

            stan_outputs = tuple(
                stan_outputs)  # Fit constructor expects a tuple.

            def is_nonempty_logger_message(msg: simdjson.Object):
                """Return whether `msg` has topic "logger" and a first value other than "info:"."""
                return msg["topic"] == "logger" and msg["values"][0] != "info:"

            def is_iteration_or_elapsed_time_logger_message(
                    msg: simdjson.Object):
                """Return whether `msg` is a routine iteration or elapsed-time line."""
                # Assumes `msg` is a message with topic `logger`.
                text = msg["values"][0]
                return (
                    text.startswith("info:Iteration:")
                    or text.startswith("info: Elapsed Time:")
                    # this detects lines following "Elapsed Time:", part of a multi-line Stan message
                    or text.startswith("info:" + " " * 15))

            # Scan the downloaded output line by line for logger messages worth
            # showing to the user.
            parser = simdjson.Parser()
            nonstandard_logger_messages = []
            for stan_output in stan_outputs:
                for line in stan_output.splitlines():
                    # Do not attempt to parse non-logger messages. Draws could contain nan or inf values.
                    # simdjson cannot parse lines containing such values.
                    if b'"logger"' not in line:
                        continue
                    msg = parser.parse(line)
                    if is_nonempty_logger_message(
                            msg
                    ) and not is_iteration_or_elapsed_time_logger_message(msg):
                        # Presumably converted to a plain dict so the message
                        # stays usable after the parser is reused and deleted
                        # -- TODO confirm against pysimdjson's documentation.
                        nonstandard_logger_messages.append(msg.as_dict())
            del parser  # simdjson.Parser is no longer used at this point.

            # Temporarily remove the bar so the messages can be printed, then
            # draw it again below.
            progress_bar.clear()
            # NOTE(review): relies on the private attribute
            # `_last_messages_length` of the progress bar.
            io.error("\x08" * progress_bar._last_messages_length
                     )  # move left to start of line
            if nonstandard_logger_messages:
                io.error_line(
                    "<comment>Messages received during sampling:</comment>")
                for msg in nonstandard_logger_messages:
                    # `str.replace` substitutes every occurrence of "info:" and
                    # "error:" with two spaces, not only a leading one.
                    text = msg["values"][0].replace("info:", "  ").replace(
                        "error:", "  ")
                    # Skip messages that are blank after the substitution.
                    if text.strip():
                        io.error_line(f"{text}")
            progress_bar.display()  # re-draw the (complete) progress bar
            progress_bar.finish()
            io.error_line("\n<info>Done.</info>")

            fit = stan.fit.Fit(
                stan_outputs,
                num_chains,
                self.param_names,
                self.constrained_param_names,
                self.dims,
                num_warmup,
                num_samples,
                num_thin,
                save_warmup,
            )

            # Give every registered plugin a chance to replace the fit; each
            # plugin receives the result of the previous one.
            for entry_point in stan.plugins.get_plugins():
                Plugin = entry_point.load()
                fit = Plugin().on_post_fit(fit)
            return fit