Beispiel #1
0
def test_overwrite_multiple_progress_bars_with_section_outputs(ansi_io):
    """Two bars drawn into two sections of one stream redraw with ANSI escapes.

    Each redraw is recorded as a cursor-up / erase-below escape prefix
    followed by the new bar line; the exact sequence is pinned below.
    """
    top_section, bottom_section = ansi_io.section(), ansi_io.section()

    top_bar = ProgressBar(top_section, 50, 0)
    bottom_bar = ProgressBar(bottom_section, 50, 0)

    for started in (top_bar, bottom_bar):
        started.start()

    # Advance the lower bar first, then the upper one.
    bottom_bar.advance()
    top_bar.advance()

    at_zero = "  0/50 [>---------------------------]   0%"
    at_one = "  1/50 [>---------------------------]   2%"
    lines = [
        at_zero,
        at_zero,
        "\x1b[1A\x1b[0J" + at_one,
        "\x1b[2A\x1b[0J" + at_one,
        "\x1b[1A\x1b[0J" + at_one,
        at_one,
    ]
    # Every line, including the last, is newline-terminated.
    expected = "".join(line + "\n" for line in lines)

    assert expected == ansi_io.fetch_error()
Beispiel #2
0
def test_clear(ansi_io):
    """``clear()`` overwrites the current bar with a blank line."""
    progress = ProgressBar(ansi_io, 50, 0)
    progress.start()
    progress.set_progress(25)
    progress.clear()

    frames = (
        "  0/50 [>---------------------------]   0%",
        " 25/50 [==============>-------------]  50%",
        "                                          ",
    )
    # Each frame is drawn after a carriage return.
    expected = "".join("\x0D" + frame for frame in frames)

    assert expected == ansi_io.fetch_error()
Beispiel #3
0
def test_customizations(ansi_io):
    """Custom width, characters and format show up in the rendered bar."""
    bar = ProgressBar(ansi_io, 10, 0)
    for configure, value in (
        (bar.set_bar_width, 10),
        (bar.set_bar_character, "_"),
        (bar.set_empty_bar_character, " "),
        (bar.set_progress_character, "/"),
        (bar.set_format, " %current%/%max% [%bar%] %percent:3s%%"),
    ):
        configure(value)

    bar.start()
    bar.advance()

    frames = ("  0/10 [/         ]   0%", "  1/10 [_/        ]  10%")
    expected = "".join("\x0D" + frame for frame in frames)

    assert expected == ansi_io.fetch_error()
Beispiel #4
0
def test_overwrite_with_section_output(ansi_io):
    """A bar in a section rewrites its line via cursor-up + erase escapes."""
    section = ansi_io.section()
    bar = ProgressBar(section, 50, 0)
    bar.start()
    bar.display()
    for _ in range(2):
        bar.advance()

    # Newline ends the previous frame; ESC[1A moves up one line, ESC[0J erases below.
    rewrite = "\n\x1b[1A\x1b[0J"
    idle = "  0/50 [>---------------------------]   0%"
    frames = [
        idle,
        idle,
        "  1/50 [>---------------------------]   2%",
        "  2/50 [=>--------------------------]   4%",
    ]
    expected = rewrite.join(frames) + "\n"

    assert expected == ansi_io.fetch_error()
Beispiel #5
0
def test_percent_not_hundred_before_complete(ansi_io):
    """199 of 200 steps renders as 99%, not 100%; only completion shows 100%."""
    bar = ProgressBar(ansi_io, 200, 0)
    bar.start()
    bar.display()
    bar.advance(199)
    bar.advance()

    idle = "   0/200 [>---------------------------]   0%"
    frames = [
        idle,
        idle,
        " 199/200 [===========================>]  99%",
        " 200/200 [============================] 100%",
    ]
    expected = "".join(f"\x0D{frame}" for frame in frames)

    assert expected == ansi_io.fetch_error()
Beispiel #6
0
def test_percent(ansi_io):
    """The percentage column tracks single-step advances on a 50-step bar."""
    progress = ProgressBar(ansi_io, 50, 0)
    progress.start()
    progress.display()
    for _ in range(2):
        progress.advance()

    idle = "  0/50 [>---------------------------]   0%"
    frames = [
        idle,
        idle,
        "  1/50 [>---------------------------]   2%",
        "  2/50 [=>--------------------------]   4%",
    ]
    expected = "".join(f"\x0D{frame}" for frame in frames)

    assert expected == ansi_io.fetch_error()
Beispiel #7
0
def test_multiline_format(ansi_io):
    """A two-line format is redrawn, cleared and finished as one unit."""
    bar = ProgressBar(ansi_io, 3, 0)
    bar.set_format("%bar%\nfoobar")

    bar.start()
    bar.advance()
    bar.clear()
    bar.finish()

    # Every frame starts by moving the cursor up one line (ESC[1A).
    cursor_up = "\033[1A"
    frames = [
        cursor_up + ">---------------------------\nfoobar",
        cursor_up + "=========>------------------\nfoobar                      ",
        cursor_up + "                            \n                            ",
        cursor_up + "============================\nfoobar                      ",
    ]
    expected = "".join("\x0D" + frame for frame in frames)

    assert expected == ansi_io.fetch_error()
Beispiel #8
0
def test_set_current_progress(ansi_io):
    """``set_progress`` jumps the bar directly to the given step."""
    bar = ProgressBar(ansi_io, 50, 0)
    bar.start()
    bar.display()
    bar.advance()
    for step in (15, 25):
        bar.set_progress(step)

    idle = "  0/50 [>---------------------------]   0%"
    frames = [
        idle,
        idle,
        "  1/50 [>---------------------------]   2%",
        " 15/50 [========>-------------------]  30%",
        " 25/50 [==============>-------------]  50%",
    ]
    expected = "".join("\x0D" + frame for frame in frames)

    assert expected == ansi_io.fetch_error()
Beispiel #9
0
def test_overwrite_with_shorter_line(ansi_io):
    """Switching to a shorter format pads the redraw with trailing spaces."""
    long_format = " %current%/%max% [%bar%] %percent:3s%%"
    short_format = " %current%/%max% [%bar%]"

    bar = ProgressBar(ansi_io, 50, 0)
    bar.set_format(long_format)
    bar.start()
    bar.display()
    bar.advance()

    # Drop the percentage column before the next redraw.
    bar.set_format(short_format)
    bar.advance()

    idle = "  0/50 [>---------------------------]   0%"
    frames = [
        idle,
        idle,
        "  1/50 [>---------------------------]   2%",
        "  2/50 [=>--------------------------]     ",
    ]
    expected = "".join("\x0D" + frame for frame in frames)

    assert expected == ansi_io.fetch_error()
Beispiel #10
0
def test_format(ansi_io):
    """Default and explicit "normal" formats render identically.

    This holds whether the max is given to the constructor or to ``start``.
    """
    frames = [
        "  0/10 [>---------------------------]   0%",
        " 10/10 [============================] 100%",
        " 10/10 [============================] 100%",
    ]
    expected = "".join("\x0D" + frame for frame in frames)

    # Scenario order: (constructor, no format), (start, no format),
    # (constructor, "normal"), (start, "normal").
    for explicit_format in (None, "normal"):
        for max_in_constructor in (True, False):
            ansi_io.clear_error()
            if max_in_constructor:
                bar = ProgressBar(ansi_io, 10)
            else:
                bar = ProgressBar(ansi_io)

            if explicit_format is not None:
                bar.set_format(explicit_format)

            if max_in_constructor:
                bar.start()
            else:
                bar.start(10)
            bar.advance(10)
            bar.finish()

            assert expected == ansi_io.fetch_error()
Beispiel #11
0
def test_non_decorated_output(io):
    """Without ANSI decoration the bar is emitted as separate lines.

    The expected output contains one line per 10% of progress rather than
    one per ``advance()`` call. ``io`` is a fixture defined elsewhere,
    presumably a non-decorated IO (TODO confirm).
    """
    bar = ProgressBar(io, 200, 0)
    bar.start()

    # The loop index is not needed; only the number of advances matters.
    for _ in range(200):
        bar.advance()

    bar.finish()

    expected = "\n".join([
        "   0/200 [>---------------------------]   0%",
        "  20/200 [==>-------------------------]  10%",
        "  40/200 [=====>----------------------]  20%",
        "  60/200 [========>-------------------]  30%",
        "  80/200 [===========>----------------]  40%",
        " 100/200 [==============>-------------]  50%",
        " 120/200 [================>-----------]  60%",
        " 140/200 [===================>--------]  70%",
        " 160/200 [======================>-----]  80%",
        " 180/200 [=========================>--]  90%",
        " 200/200 [============================] 100%",
    ])

    assert expected == io.fetch_error()
Beispiel #12
0
        async def go():
            """Submit sampling requests, track their progress, and build the fit.

            Steps:
            1. Inside an ``HttpstanClient`` session, POST one fit request per
               entry in ``payloads``.
            2. Poll every operation until all report ``done``, driving a console
               progress bar from their ``progress`` metadata.
            3. Download each fit's output and print any non-routine logger
               messages.
            4. Return a ``stan.fit.Fit`` after passing it through every
               plugin's ``on_post_fit``.

            ``payloads``, ``num_chains``, ``num_warmup``, ``num_samples``,
            ``num_thin``, ``save_warmup`` and ``self`` come from the enclosing
            scope, which is not visible here.

            Raises:
                ValueError: a fit request is rejected with HTTP 422.
                RuntimeError: a fit request returns another non-201 status, an
                    operation finishes with an error result, or fetching or
                    deleting a fit fails.
            """
            io = ConsoleIO()
            io.error_line("<info>Sampling...</info>")
            progress_bar = ProgressBar(io)
            progress_bar.set_format("very_verbose")

            # Captures (current, max) from progress text such as "Iteration: 100 / 2000".
            current_and_max_iterations_re = re.compile(
                r"Iteration:\s+(\d+)\s+/\s+(\d+)")
            async with stan.common.HttpstanClient() as client:
                # Start one fit operation per payload; each 201 response body is an
                # operation record that is polled below.
                operations = []
                for payload in payloads:
                    resp = await client.post(f"/{self.model_name}/fits",
                                             json=payload)
                    if resp.status == 422:
                        raise ValueError(str(resp.json()))
                    elif resp.status != 201:
                        raise RuntimeError(resp.json()["message"])
                    assert resp.status == 201
                    operations.append(resp.json())

                # poll to get progress for each chain until all chains finished
                # Maps operation name -> latest iteration reported for that chain.
                current_iterations = {}
                while not all(operation["done"] for operation in operations):
                    for operation in operations:
                        if operation["done"]:
                            continue
                        resp = await client.get(f"/{operation['name']}")
                        # NOTE(review): `assert` is stripped under `python -O`, so this
                        # check disappears in optimized runs.
                        assert resp.status != 404
                        # Refresh the local operation record in place (including "done").
                        operation.update(resp.json())
                        progress_message = operation["metadata"].get(
                            "progress")
                        if not progress_message:
                            continue
                        # NOTE(review): `.pop(0)` raises IndexError if the message does
                        # not contain an "Iteration: N / M" match.
                        iteration, iteration_max = map(
                            int,
                            current_and_max_iterations_re.findall(
                                progress_message).pop(0))
                        if not progress_bar.get_max_steps(
                        ):  # i.e., has not started
                            # Total = first-seen per-chain max times the number of chains
                            # (assumes every chain reports the same max).
                            progress_bar.start(max=iteration_max * num_chains)
                        current_iterations[operation["name"]] = iteration
                        progress_bar.set_progress(
                            sum(current_iterations.values()))
                    # Suspend briefly between polling rounds so other tasks can run.
                    await asyncio.sleep(0.01)
                # Sampling has finished. But we do not call `progress_bar.finish()` right
                # now. First we write informational messages to the screen, then we
                # redraw the (complete) progress bar. Only after that do we call `finish`.

                stan_outputs = []
                for operation in operations:
                    fit_name = operation["result"].get("name")
                    if fit_name is None:  # operation["result"] is an error
                        assert not str(operation["result"]["code"]).startswith(
                            "2"), operation
                        raise RuntimeError(operation["result"]["message"])
                    # Download the raw fit output. It is presumably bytes, since it is
                    # later matched against a bytes literal.
                    resp = await client.get(f"/{fit_name}")
                    if resp.status != 200:
                        raise RuntimeError((resp.json())["message"])
                    stan_outputs.append(resp.content)

                    # clean up after ourselves when fit is uncacheable (no random seed)
                    if self.random_seed is None:
                        resp = await client.delete(f"/{fit_name}")
                        if resp.status not in {200, 202, 204}:
                            raise RuntimeError((resp.json())["message"])

            stan_outputs = tuple(
                stan_outputs)  # Fit constructor expects a tuple.

            def is_nonempty_logger_message(msg: simdjson.Object) -> bool:
                """Return True for a "logger" message whose text is not just "info:"."""
                return msg["topic"] == "logger" and msg["values"][0] != "info:"

            def is_iteration_or_elapsed_time_logger_message(
                    msg: simdjson.Object) -> bool:
                """Return True for routine iteration/elapsed-time info lines."""
                # Assumes `msg` is a message with topic `logger`.
                text = msg["values"][0]
                return (
                    text.startswith("info:Iteration:")
                    or text.startswith("info: Elapsed Time:")
                    # this detects lines following "Elapsed Time:", part of a multi-line Stan message
                    or text.startswith("info:" + " " * 15))

            # Collect logger messages that are neither empty nor routine progress/timing.
            parser = simdjson.Parser()
            nonstandard_logger_messages = []
            for stan_output in stan_outputs:
                for line in stan_output.splitlines():
                    # Do not attempt to parse non-logger messages. Draws could contain nan or inf values.
                    # simdjson cannot parse lines containing such values.
                    if b'"logger"' not in line:
                        continue
                    msg = parser.parse(line)
                    # Copy to a plain dict. Presumably needed because the parser's
                    # objects do not outlive later parse() calls (TODO confirm).
                    if is_nonempty_logger_message(
                            msg
                    ) and not is_iteration_or_elapsed_time_logger_message(msg):
                        nonstandard_logger_messages.append(msg.as_dict())
            del parser  # simdjson.Parser is no longer used at this point.

            # NOTE(review): if no progress message was ever received, the bar was never
            # started. How clear()/display()/finish() behave then is not visible
            # here; TODO confirm.
            progress_bar.clear()
            # NOTE(review): relies on ProgressBar's private `_last_messages_length`.
            io.error("\x08" * progress_bar._last_messages_length
                     )  # move left to start of line
            if nonstandard_logger_messages:
                io.error_line(
                    "<comment>Messages received during sampling:</comment>")
                for msg in nonstandard_logger_messages:
                    # Replace every "info:"/"error:" occurrence (not only the prefix)
                    # with two spaces of indentation.
                    text = msg["values"][0].replace("info:", "  ").replace(
                        "error:", "  ")
                    if text.strip():
                        io.error_line(f"{text}")
            progress_bar.display()  # re-draw the (complete) progress bar
            progress_bar.finish()
            io.error_line("\n<info>Done.</info>")

            fit = stan.fit.Fit(
                stan_outputs,
                num_chains,
                self.param_names,
                self.constrained_param_names,
                self.dims,
                num_warmup,
                num_samples,
                num_thin,
                save_warmup,
            )

            # Each registered plugin may post-process or replace the fit, in turn.
            for entry_point in stan.plugins.get_plugins():
                Plugin = entry_point.load()
                fit = Plugin().on_post_fit(fit)
            return fit