Example #1
        async def go(num_chains):
            host, port = 'localhost', 8080
            path = f'/v1/models/{self.model_id}/actions'
            stan_outputs = []
            for chain in range(1, num_chains + 1):
                payload = {
                    'type': 'stan::services::sample::hmc_nuts_diag_e',
                }
                payload.update(kwargs)
                payload['chain'] = chain
                payload['data'] = self.data

                # fit needs to know num_samples, num_warmup, num_thin, save_warmup
                # progress bar needs to know some of these
                num_warmup = payload.get('num_warmup',
                                         arguments.lookup_default(arguments.Method['SAMPLE'], 'num_warmup'))
                num_samples = payload.get('num_samples',
                                          arguments.lookup_default(arguments.Method['SAMPLE'], 'num_samples'))
                num_thin = payload.get('num_thin',
                                       arguments.lookup_default(arguments.Method['SAMPLE'], 'num_thin'))
                save_warmup = payload.get('save_warmup',
                                          arguments.lookup_default(arguments.Method['SAMPLE'], 'save_warmup'))
                pbar_total = num_samples + num_warmup * int(save_warmup)
                stan_output = []
                with tqdm.tqdm(total=pbar_total) as pbar:
                    async for payload_response in httpstan_helpers.post_aiter(host, port, path, payload):
                        stan_output.append(payload_response)
                        if payload_response['topic'] == 'SAMPLE':
                            pbar.update()
                stan_outputs.append(stan_output)
            return pystan.fit.Fit(stan_outputs, num_chains, self.param_names, self.constrained_param_names, self.dims,
                                  num_warmup, num_samples, num_thin, save_warmup)
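
This nested coroutine would then be driven by an event loop from the enclosing method; a minimal sketch of how it might be invoked, assuming `go` closes over `self` and `kwargs` as above (the later examples on this page use the same idiom):

# hedged sketch: drive the nested coroutine defined above (Python 3.7+)
import asyncio

fit = asyncio.run(go(num_chains=4))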
Example #2
def test_lookup_default():
    """Test argument default value lookup."""
    expected = 1000
    assert expected == arguments.lookup_default(arguments.Method.SAMPLE,
                                                "num_samples")
    expected = 0.05
    assert expected == arguments.lookup_default(arguments.Method.SAMPLE,
                                                "gamma")
Example #3
        def go(num_chains):
            host, port, path = 'localhost', 8080, f'/v1/models/{self.model_id}/actions'
            stan_outputs = []
            for chain in range(1, num_chains + 1):
                payload = {
                    'type': 'stan::services::sample::hmc_nuts_diag_e',
                }
                payload.update(kwargs)
                payload['chain'] = chain
                payload['data'] = self.data

                # fit needs to know num_samples, num_warmup, num_thin, save_warmup
                # progress bar needs to know some of these
                num_warmup = payload.get(
                    'num_warmup',
                    arguments.lookup_default(arguments.Method['SAMPLE'],
                                             'num_warmup'))
                num_samples = payload.get(
                    'num_samples',
                    arguments.lookup_default(arguments.Method['SAMPLE'],
                                             'num_samples'))
                num_thin = payload.get(
                    'num_thin',
                    arguments.lookup_default(arguments.Method['SAMPLE'],
                                             'num_thin'))
                save_warmup = payload.get(
                    'save_warmup',
                    arguments.lookup_default(arguments.Method['SAMPLE'],
                                             'save_warmup'))
                pbar_total = num_samples + num_warmup * int(save_warmup)
                stan_output = []

                with tqdm.tqdm(total=pbar_total) as pbar:
                    with requests.post(f'http://{host}:{port}{path}',
                                       json=payload,
                                       stream=True) as r:
                        for line in r.iter_lines():
                            if not line:  # skip keep-alive blank lines
                                continue
                            payload_response = json.loads(line)
                            stan_output.append(payload_response)
                            if payload_response['topic'] == 'SAMPLE':
                                pbar.update()
                stan_outputs.append(stan_output)
            return pystan.fit.Fit(stan_outputs, num_chains, self.param_names,
                                  self.constrained_param_names, self.dims,
                                  num_warmup, num_samples, num_thin,
                                  save_warmup)
Example #4
async def call(function_name: str, program_module, data: dict, **kwargs):
    """Call stan::services function.

    Yields (asynchronously) messages from the stan::callbacks writers which are
    written to by the stan::services function.

    This is a coroutine function.

    Arguments:
        function_name: full name of function in stan::services
        program_module (module): Stan Program extension module
        data: dictionary with data with which to populate array_var_context
        kwargs: named stan::services function arguments, see CmdStan documentation.
    """
    method, function_basename = function_name.replace('stan::services::',
                                                      '').split('::', 1)
    queue_wrapper = httpstan.spsc_queue.SPSCQueue(capacity=10000)
    array_var_context_capsule = httpstan.stan.make_array_var_context(data)
    function_wrapper = getattr(program_module, function_basename + '_wrapper')

    # fetch defaults for missing arguments
    function_arguments = arguments.function_arguments(function_basename,
                                                      program_module)
    # This is clumsy due to the way default values are available. There is no
    # way to directly lookup the default value for an argument (e.g., `delta`)
    # given both the argument name and the (full) function name (e.g.,
    # `stan::services::hmc_nuts_diag_e_adapt`).
    for arg in function_arguments:
        if arg not in kwargs:
            kwargs[arg] = arguments.lookup_default(
                arguments.Method[method.upper()], arg)
    function_wrapper_partial = functools.partial(function_wrapper,
                                                 array_var_context_capsule,
                                                 queue_wrapper.to_capsule(),
                                                 **kwargs)

    # WISHLIST: can one use ProcessPoolExecutor somehow on Linux and OSX?
    loop = asyncio.get_event_loop()
    future = asyncio.ensure_future(
        loop.run_in_executor(None, function_wrapper_partial))  # type: ignore
    parser = httpstan.callbacks_writer_parser.WriterParser()
    while True:
        try:
            message = queue_wrapper.get_nowait()
        except queue.Empty:
            if future.done():
                break
            await asyncio.sleep(0.1)
            continue
        parsed = parser.parse(message.decode())
        # parsed is None if the message was a blank line or a header with param names
        if parsed:
            yield parsed
    future.result()  # raises exceptions from task, if any
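
Because `call` is an async generator, a caller consumes it with `async for`; a minimal sketch, under the assumption that a compiled `program_module` and a `data` dict are already in hand:

# hedged sketch: drain parsed writer messages from call()
async def collect_messages(program_module, data, **kwargs):
    messages = []
    async for parsed in call('stan::services::sample::hmc_nuts_diag_e',
                             program_module, data, **kwargs):
        messages.append(parsed)
    return messages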
Example #5
    def _create_fit(self, payload: dict) -> stan.fit.Fit:
        """Make a request to httpstan's ``create_fit`` endpoint and process results.

        Users should not use this function.

        Arguments:
            payload: dict whose JSON-encoded contents will be sent as the request body.

        Returns:
            Fit: instance of Fit allowing access to draws.

        """
        assert "chain" not in payload, "`chain` id is set automatically."
        assert "data" not in payload, "`data` is set in `build`."
        assert "random_seed" not in payload, "`random_seed` is set in `build`."
        assert "function" in payload

        num_chains = payload.pop("num_chains", 4)

        init = payload.pop("init", [dict() for _ in range(num_chains)])
        if len(init) != num_chains:
            raise ValueError("Initial values must be provided for each chain.")

        payloads = []
        for chain in range(1, num_chains + 1):
            payload["chain"] = chain  # type: ignore
            payload["data"] = self.data  # type: ignore
            payload["init"] = init.pop(0)
            if self.random_seed is not None:
                payload["random_seed"] = self.random_seed  # type: ignore

            # fit needs to know num_samples, num_warmup, num_thin, save_warmup
            # progress bar needs to know some of these
            num_warmup = payload.get(
                "num_warmup",
                arguments.lookup_default(arguments.Method["SAMPLE"],
                                         "num_warmup"))
            num_samples = payload.get(
                "num_samples",
                arguments.lookup_default(arguments.Method["SAMPLE"],
                                         "num_samples"),
            )
            num_thin = payload.get(
                "num_thin",
                arguments.lookup_default(arguments.Method["SAMPLE"],
                                         "num_thin"))
            save_warmup = payload.get(
                "save_warmup",
                arguments.lookup_default(arguments.Method["SAMPLE"],
                                         "save_warmup"),
            )
            # append a copy: `payload` is mutated on every iteration of this loop
            payloads.append(dict(payload))

        async def go():
            io = ConsoleIO()
            io.error_line("<info>Sampling...</info>")
            progress_bar = ProgressBar(io)
            progress_bar.set_format("very_verbose")

            current_and_max_iterations_re = re.compile(
                r"Iteration:\s+(\d+)\s+/\s+(\d+)")
            async with stan.common.HttpstanClient() as client:
                operations = []
                for payload in payloads:
                    resp = await client.post(f"/{self.model_name}/fits",
                                             json=payload)
                    if resp.status == 422:
                        raise ValueError(str(resp.json()))
                    elif resp.status != 201:
                        raise RuntimeError(resp.json()["message"])
                    assert resp.status == 201
                    operations.append(resp.json())

                # poll to get progress for each chain until all chains finished
                current_iterations = {}
                while not all(operation["done"] for operation in operations):
                    for operation in operations:
                        if operation["done"]:
                            continue
                        resp = await client.get(f"/{operation['name']}")
                        assert resp.status != 404
                        operation.update(resp.json())
                        progress_message = operation["metadata"].get(
                            "progress")
                        if not progress_message:
                            continue
                        iteration, iteration_max = map(
                            int,
                            current_and_max_iterations_re.findall(
                                progress_message).pop(0))
                        if not progress_bar.get_max_steps(
                        ):  # i.e., has not started
                            progress_bar.start(max=iteration_max * num_chains)
                        current_iterations[operation["name"]] = iteration
                        progress_bar.set_progress(
                            sum(current_iterations.values()))
                    await asyncio.sleep(0.01)
                # Sampling has finished. But we do not call `progress_bar.finish()` right
                # now. First we write informational messages to the screen, then we
                # redraw the (complete) progress bar. Only after that do we call `finish`.

                stan_outputs = []
                for operation in operations:
                    fit_name = operation["result"].get("name")
                    if fit_name is None:  # operation["result"] is an error
                        assert not str(operation["result"]["code"]).startswith(
                            "2"), operation
                        raise RuntimeError(operation["result"]["message"])
                    resp = await client.get(f"/{fit_name}")
                    if resp.status != 200:
                        raise RuntimeError((resp.json())["message"])
                    stan_outputs.append(resp.content)

                    # clean up after ourselves when fit is uncacheable (no random seed)
                    if self.random_seed is None:
                        resp = await client.delete(f"/{fit_name}")
                        if resp.status not in {200, 202, 204}:
                            raise RuntimeError((resp.json())["message"])

            stan_outputs = tuple(
                stan_outputs)  # Fit constructor expects a tuple.

            def is_nonempty_logger_message(msg: simdjson.Object):
                return msg["topic"] == "logger" and msg["values"][0] != "info:"

            def is_iteration_or_elapsed_time_logger_message(
                    msg: simdjson.Object):
                # Assumes `msg` is a message with topic `logger`.
                text = msg["values"][0]
                return (
                    text.startswith("info:Iteration:")
                    or text.startswith("info: Elapsed Time:")
                    # this detects lines following "Elapsed Time:", part of a multi-line Stan message
                    or text.startswith("info:" + " " * 15))

            parser = simdjson.Parser()
            nonstandard_logger_messages = []
            for stan_output in stan_outputs:
                for line in stan_output.splitlines():
                    # Do not attempt to parse non-logger messages. Draws could contain nan or inf values.
                    # simdjson cannot parse lines containing such values.
                    if b'"logger"' not in line:
                        continue
                    msg = parser.parse(line)
                    if is_nonempty_logger_message(
                            msg
                    ) and not is_iteration_or_elapsed_time_logger_message(msg):
                        nonstandard_logger_messages.append(msg.as_dict())
            del parser  # simdjson.Parser is no longer used at this point.

            progress_bar.clear()
            io.error("\x08" * progress_bar._last_messages_length
                     )  # move left to start of line
            if nonstandard_logger_messages:
                io.error_line(
                    "<comment>Messages received during sampling:</comment>")
                for msg in nonstandard_logger_messages:
                    text = msg["values"][0].replace("info:", "  ").replace(
                        "error:", "  ")
                    if text.strip():
                        io.error_line(f"{text}")
            progress_bar.display()  # re-draw the (complete) progress bar
            progress_bar.finish()
            io.error_line("\n<info>Done.</info>")

            fit = stan.fit.Fit(
                stan_outputs,
                num_chains,
                self.param_names,
                self.constrained_param_names,
                self.dims,
                num_warmup,
                num_samples,
                num_thin,
                save_warmup,
            )

            for entry_point in stan.plugins.get_plugins():
                Plugin = entry_point.load()
                fit = Plugin().on_post_fit(fit)
            return fit

        try:
            return asyncio.run(go())
        except KeyboardInterrupt:
            return  # type: ignore
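
A hedged sketch of how this private helper might be invoked; `posterior` is a hypothetical built model and the payload keys mirror the ones the method reads:

# hedged sketch: request a 4-chain NUTS fit through the private helper
fit = posterior._create_fit({
    "function": "stan::services::sample::hmc_nuts_diag_e_adapt",
    "num_chains": 4,
    "num_samples": 1000,
})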
Example #6
async def call(
    function_name: str,
    model_name: str,
    fit_name: str,
    logger_callback: typing.Optional[typing.Callable] = None,
    **kwargs: dict,
) -> None:
    """Call stan::services function.

    Messages from the stan::callbacks writers which are written to by the
    stan::services function are collected, compressed, and saved under
    ``fit_name``; they are not yielded.

    This is a coroutine function.

    Arguments:
        function_name: full name of function in stan::services
        model_name: name of the model whose services extension module will be used
        fit_name: Name of fit, used for saving length-prefixed messages
        logger_callback: Callback function for logger messages, including sampling progress messages
        kwargs: named stan::services function arguments, see CmdStan documentation.
    """
    method, function_basename = function_name.replace("stan::services::",
                                                      "").split("::", 1)

    # Fetch defaults for missing arguments. This is an important step!
    # For example, `random_seed`, if not in `kwargs`, will be set.
    # temporarily load the module to lookup function arguments
    services_module = httpstan.models.import_services_extension_module(
        model_name)
    function_arguments = arguments.function_arguments(function_basename,
                                                      services_module)
    del services_module
    # This is clumsy due to the way default values are available. There is no
    # way to directly lookup the default value for an argument (e.g., `delta`)
    # given both the argument name and the (full) function name (e.g.,
    # `stan::services::hmc_nuts_diag_e_adapt`).
    for arg in function_arguments:
        if arg not in kwargs:
            kwargs[arg] = typing.cast(
                typing.Any,
                arguments.lookup_default(arguments.Method[method.upper()],
                                         arg))

    with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as socket_:
        _, socket_filename = tempfile.mkstemp(prefix="httpstan_",
                                              suffix=".sock")
        os.unlink(socket_filename)
        socket_.bind(socket_filename)
        socket_.listen(
            4)  # three stan callback writers, one stan callback logger

        lazy_function_wrapper = _make_lazy_function_wrapper(
            function_basename, model_name)
        lazy_function_wrapper_partial = functools.partial(
            lazy_function_wrapper, socket_filename, **kwargs)

        # If HTTPSTAN_DEBUG is set block until sampling is complete. Do not use an executor.
        if HTTPSTAN_DEBUG:  # pragma: no cover
            future: asyncio.Future = asyncio.Future()
            logger.debug("Calling stan::services function with debug mode on.")
            print(
                "Warning: httpstan debug mode is on! `num_samples` must be set to a small number (e.g., 10)."
            )
            future.set_result(lazy_function_wrapper_partial())
        else:
            future = asyncio.get_running_loop().run_in_executor(
                executor, lazy_function_wrapper_partial)  # type: ignore

        messages_files: typing.Mapping[socket.socket,
                                       io.BytesIO] = collections.defaultdict(
                                           io.BytesIO)
        # using a wbits value which makes things compatible with gzip
        messages_compressobjs: typing.Mapping[
            socket.socket, zlib._Compress] = collections.defaultdict(
                functools.partial(zlib.compressobj,
                                  level=zlib.Z_BEST_SPEED,
                                  wbits=zlib.MAX_WBITS | 16))
        potential_readers = [socket_]

        while True:
            # note: timeout of 0.01 seems to work well based on measurements
            readable, writeable, errored = select.select(
                potential_readers, [], [], 0.01)
            for s in readable:
                if s is socket_:
                    conn, _ = s.accept()
                    logger.debug(
                        "Opened socket connection to a socket_logger or socket_writer."
                    )
                    potential_readers.append(conn)
                    continue
                message = s.recv(8192)
                if not len(message):
                    # `close` called on other end
                    s.close()
                    logger.debug(
                        "Closed socket connection to a socket_logger or socket_writer."
                    )
                    potential_readers.remove(s)
                    continue
                # Only trigger callback if message has topic `logger`.
                if logger_callback and b'"logger"' in message:
                    logger_callback(message)
                messages_files[s].write(
                    messages_compressobjs[s].compress(message))
            # if `potential_readers == [socket_]` then either (1) no connections
            # have been opened or (2) all connections have been closed.
            if not readable:
                if potential_readers == [socket_] and future.done():
                    logger.debug(
                        f"Stan services function `{function_basename}` returned without problems or raised a C++ exception."
                    )
                    break
                # no messages right now and not done. Sleep briefly so other pending tasks get a chance to run.
                await asyncio.sleep(0.001)

    compressed_parts = []
    for s, fh in messages_files.items():
        fh.write(messages_compressobjs[s].flush())
        fh.flush()
        compressed_parts.append(fh.getvalue())
        fh.close()
    httpstan.cache.dump_fit(b"".join(compressed_parts), fit_name)

    # if an exception has already occurred, grab relevant info messages, add as context
    exception = future.exception()
    if exception and len(exception.args) == 1:
        import gzip
        import json

        original_exception_message = exception.args[
            0]  # e.g., from ValueError("Initialization failed.")
        info_messages_for_context = []
        num_context_messages = 4

        jsonlines = gzip.decompress(b"".join(compressed_parts)).decode()
        for line in jsonlines.split("\n")[:num_context_messages]:
            try:
                message = json.loads(line)
                info_message = message["values"].pop().replace("info:", "")
                info_messages_for_context.append(info_message.strip())
            except json.JSONDecodeError:
                pass
        # add the info messages to the original exception message. For example,
        # ValueError("Initialization failed.") -> ValueError("Initialization failed. Rejecting initial value: Log probability ...")
        if info_messages_for_context:
            new_exception_message = f"{original_exception_message} {' '.join(info_messages_for_context)} ..."
            exception.args = (new_exception_message, )

    # `result()` method will raise exceptions, if any
    future.result()
Example #7
def test_lookup_invalid() -> None:
    """Test argument default value lookup with invalid argument."""
    with pytest.raises(ValueError, match=r"No argument `.*` is associated with `.*`\."):
        arguments.lookup_default(arguments.Method.SAMPLE, "invalid_argument")
Example #8
def test_lookup_default(argument_value: Tuple[str, Any]) -> None:
    """Test argument default value lookup."""
    arg, value = argument_value
    assert value == arguments.lookup_default(arguments.Method.SAMPLE, arg)
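
The `argument_value` fixture is not shown; a minimal sketch of a parametrized fixture that would feed this test, using the same defaults exercised in Example #2 (assumed, not exhaustive):

# hedged sketch: supply (argument name, expected default) pairs to the test
from typing import Any, Tuple

import pytest

@pytest.fixture(params=[("num_samples", 1000), ("gamma", 0.05)])
def argument_value(request) -> Tuple[str, Any]:
    return request.param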
Example #9
    def sample(self, **kwargs):
        """Draw samples from the model.

        Parameters in ``kwargs`` will be passed to the default sample function in
        stan::services. Parameter names are identical to those used in CmdStan.
        See the CmdStan documentation for parameter descriptions and default
        values.

        Returns:
            Fit: instance of Fit allowing access to draws.

        """
        assert isinstance(self.data, dict)
        assert "chain" not in kwargs, "`chain` id is set automatically."
        assert "data" not in kwargs, "`data` is set in `build`."
        assert "random_seed" not in kwargs, "`random_seed` is set in `build`."
        num_chains = kwargs.pop("num_chains", 1)

        with stan.common.httpstan_server() as server:
            host, port = server.host, server.port
            payloads = []
            for chain in range(1, num_chains + 1):
                payload = {"function": "stan::services::sample::hmc_nuts_diag_e_adapt"}
                payload.update(kwargs)
                payload["chain"] = chain
                payload["data"] = self.data
                if self.random_seed is not None:
                    payload["random_seed"] = self.random_seed

                # fit needs to know num_samples, num_warmup, num_thin, save_warmup
                # progress bar needs to know some of these
                num_warmup = payload.get(
                    "num_warmup", arguments.lookup_default(arguments.Method["SAMPLE"], "num_warmup")
                )
                num_samples = payload.get(
                    "num_samples",
                    arguments.lookup_default(arguments.Method["SAMPLE"], "num_samples"),
                )
                num_thin = payload.get(
                    "num_thin", arguments.lookup_default(arguments.Method["SAMPLE"], "num_thin")
                )
                save_warmup = payload.get(
                    "save_warmup",
                    arguments.lookup_default(arguments.Method["SAMPLE"], "save_warmup"),
                )
                payloads.append(payload)

            # WISHLIST(AR): rewriting this function in Cython (in httpstan) might speed things up
            def extract_protobuf_messages(fit_bytes):
                varint_decoder = google.protobuf.internal.decoder._DecodeVarint32
                next_pos, pos = 0, 0
                while pos < len(fit_bytes):
                    msg = callbacks_writer_pb2.WriterMessage()
                    next_pos, pos = varint_decoder(fit_bytes, pos)
                    msg.ParseFromString(fit_bytes[pos : pos + next_pos])
                    yield msg
                    pos += next_pos

            fits_url = f"http://{host}:{port}/v1/{self.model_name}/fits"
            operations = []
            for payload in payloads:
                r = requests.post(fits_url, json=payload)
                if r.status_code != 201:
                    raise RuntimeError(r.json()["message"])
                assert r.status_code == 201, r.status_code
                operations.append(r.json())

            while not all(operation["done"] for operation in operations):
                for operation in operations:
                    operation_name = operation["name"]
                    operation.update(
                        requests.get(f"http://{host}:{port}/v1/{operation_name}").json()
                    )
                time.sleep(0.1)

            stan_outputs = []
            for operation in operations:
                fit_name = operation["result"].get("name")
                if fit_name is None:  # operation["result"] is an error
                    assert not str(operation["result"]["code"]).startswith("2"), operation
                    raise RuntimeError(operation["result"]["message"])
                r = requests.get(f"http://{host}:{port}/v1/{fit_name}", json=payload)
                if r.status_code != 200:
                    response_payload = r.json()
                    assert "message" in response_payload, response_payload
                    raise RuntimeError(response_payload["message"])
                stan_outputs.append(tuple(extract_protobuf_messages(r.content)))
            for stan_output in stan_outputs:
                assert isinstance(stan_output, tuple), stan_output
        return stan.fit.Fit(
            stan_outputs,
            num_chains,
            self.param_names,
            self.constrained_param_names,
            self.dims,
            num_warmup,
            num_samples,
            num_thin,
            save_warmup,
        )
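
For orientation, a hedged sketch of the user-facing flow that ends up in this `sample` method; the program code and data are purely illustrative:

# hedged sketch: build a model and draw samples with the method above
import stan

program_code = "parameters { real y; } model { y ~ normal(0, 1); }"
posterior = stan.build(program_code, data={})
fit = posterior.sample(num_chains=2, num_samples=500)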
Example #10
    def _create_fit(self, *, function, num_chains, **kwargs) -> stan.fit.Fit:
        """Make a request to httpstan's ``create_fit`` endpoint and process results.

        Users should not use this function.

        Parameters in ``kwargs`` will be passed to the (Python wrapper of)
        `function`. Parameter names are identical to those used in CmdStan.
        See the CmdStan documentation for parameter descriptions and default
        values.

        Returns:
            Fit: instance of Fit allowing access to draws.

        """
        assert "chain" not in kwargs, "`chain` id is set automatically."
        assert "data" not in kwargs, "`data` is set in `build`."
        assert "random_seed" not in kwargs, "`random_seed` is set in `build`."

        # copy kwargs and verify everything is JSON-encodable
        kwargs = json.loads(DataJSONEncoder().encode(kwargs))

        # FIXME: special handling here for `init`, consistent with PyStan 2 but needs docs
        init: List[Data] = kwargs.pop("init",
                                      [dict() for _ in range(num_chains)])
        if len(init) != num_chains:
            raise ValueError("Initial values must be provided for each chain.")

        payloads = []
        for chain in range(1, num_chains + 1):
            payload = kwargs.copy()
            payload["function"] = function
            payload["chain"] = chain  # type: ignore
            payload["data"] = self.data  # type: ignore
            payload["init"] = init.pop(0)
            if self.random_seed is not None:
                payload["random_seed"] = self.random_seed  # type: ignore

            # fit needs to know num_samples, num_warmup, num_thin, save_warmup
            # progress reporting needs to know some of these
            num_warmup = payload.get(
                "num_warmup",
                arguments.lookup_default(arguments.Method["SAMPLE"],
                                         "num_warmup"))
            num_samples = payload.get(
                "num_samples",
                arguments.lookup_default(arguments.Method["SAMPLE"],
                                         "num_samples"),
            )
            num_thin = payload.get(
                "num_thin",
                arguments.lookup_default(arguments.Method["SAMPLE"],
                                         "num_thin"))
            save_warmup = payload.get(
                "save_warmup",
                arguments.lookup_default(arguments.Method["SAMPLE"],
                                         "save_warmup"),
            )
            payloads.append(payload)

        async def go():
            io = ConsoleIO()
            sampling_output = io.section().error_output
            percent_complete = 0
            sampling_output.write_line(
                f"<comment>Sampling:</comment> {percent_complete:3.0f}%")

            current_and_max_iterations_re = re.compile(
                r"Iteration:\s+(\d+)\s+/\s+(\d+)")
            async with stan.common.HttpstanClient() as client:
                operations = []
                for payload in payloads:
                    resp = await client.post(f"/{self.model_name}/fits",
                                             json=payload)
                    if resp.status == 422:
                        raise ValueError(str(resp.json()))
                    elif resp.status != 201:
                        raise RuntimeError(resp.json()["message"])
                    assert resp.status == 201
                    operations.append(resp.json())

                # poll to get progress for each chain until all chains finished
                current_iterations = {}
                while not all(operation["done"] for operation in operations):
                    for operation in operations:
                        if operation["done"]:
                            continue
                        resp = await client.get(f"/{operation['name']}")
                        assert resp.status != 404
                        operation.update(resp.json())
                        progress_message = operation["metadata"].get(
                            "progress")
                        if not progress_message:
                            continue
                        iteration, iteration_max = map(
                            int,
                            current_and_max_iterations_re.findall(
                                progress_message).pop(0))
                        current_iterations[operation["name"]] = iteration
                        iterations_count = sum(current_iterations.values())
                        total_iterations = iteration_max * num_chains
                        percent_complete = 100 * iterations_count / total_iterations
                        sampling_output.clear() if io.supports_ansi(
                        ) else sampling_output.write("\n")
                        sampling_output.write_line(
                            f"<comment>Sampling:</comment> {round(percent_complete):3.0f}% ({iterations_count}/{total_iterations})"
                        )
                    await asyncio.sleep(0.01)

                fit_in_cache = len(current_iterations) < num_chains

                stan_outputs = []
                for operation in operations:
                    fit_name = operation["result"].get("name")
                    if fit_name is None:  # operation["result"] is an error
                        assert not str(operation["result"]["code"]).startswith(
                            "2"), operation
                        message = operation["result"]["message"]
                        if """ValueError('Initialization failed.')""" in message:
                            sampling_output.clear()
                            sampling_output.write_line(
                                "<info>Sampling:</info> <error>Initialization failed.</error>"
                            )
                            raise RuntimeError("Initialization failed.")
                        raise RuntimeError(message)

                    resp = await client.get(f"/{fit_name}")
                    if resp.status != 200:
                        raise RuntimeError((resp.json())["message"])
                    stan_outputs.append(resp.content)

                    # clean up after ourselves when fit is uncacheable (no random seed)
                    if self.random_seed is None:
                        resp = await client.delete(f"/{fit_name}")
                        if resp.status not in {200, 202, 204}:
                            raise RuntimeError((resp.json())["message"])

                sampling_output.clear() if io.supports_ansi(
                ) else sampling_output.write("\n")
                sampling_output.write_line(
                    "<info>Sampling:</info> 100%, done." if fit_in_cache else
                    f"<info>Sampling:</info> {percent_complete:3.0f}% ({iterations_count}/{total_iterations}), done."
                )
                if not io.supports_ansi():
                    sampling_output.write("\n")

            stan_outputs = tuple(
                stan_outputs)  # Fit constructor expects a tuple.

            def is_nonempty_logger_message(msg: simdjson.Object):
                return msg["topic"] == "logger" and msg["values"][0] != "info:"

            def is_iteration_or_elapsed_time_logger_message(
                    msg: simdjson.Object):
                # Assumes `msg` is a message with topic `logger`.
                text = msg["values"][0]
                return (
                    text.startswith("info:Iteration:")
                    or text.startswith("info: Elapsed Time:")
                    # this detects lines following "Elapsed Time:", part of a multi-line Stan message
                    or text.startswith("info:" + " " * 15))

            parser = simdjson.Parser()
            nonstandard_logger_messages = []
            for stan_output in stan_outputs:
                for line in stan_output.splitlines():
                    # Do not attempt to parse non-logger messages. Draws could contain nan or inf values.
                    # simdjson cannot parse lines containing such values.
                    if b'"logger"' not in line:
                        continue
                    msg = parser.parse(line)
                    if is_nonempty_logger_message(
                            msg
                    ) and not is_iteration_or_elapsed_time_logger_message(msg):
                        nonstandard_logger_messages.append(msg.as_dict())
            del parser  # simdjson.Parser is no longer used at this point.

            if nonstandard_logger_messages:
                io.error_line(
                    "<comment>Messages received during sampling:</comment>")
                for msg in nonstandard_logger_messages:
                    text = msg["values"][0].replace("info:", "  ").replace(
                        "error:", "  ")
                    if text.strip():
                        io.error_line(f"{text}")

            fit = stan.fit.Fit(
                stan_outputs,
                num_chains,
                self.param_names,
                self.constrained_param_names,
                self.dims,
                num_warmup,
                num_samples,
                num_thin,
                save_warmup,
            )

            for entry_point in stan.plugins.get_plugins():
                Plugin = entry_point.load()
                fit = Plugin().on_post_sample(fit)
            return fit

        try:
            return asyncio.run(go())
        except KeyboardInterrupt:
            return  # type: ignore
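
The loop at the end hands the finished fit to any installed plugins; a hedged sketch of such a plugin, assuming PyStan's `stan.plugins.PluginBase` interface with the `on_post_sample` hook used above:

# hedged sketch: a plugin whose on_post_sample hook receives and returns the Fit
import stan.plugins

class AnnounceDonePlugin(stan.plugins.PluginBase):
    def on_post_sample(self, fit):
        # purely illustrative: report which parameters were sampled
        print("sampling finished for", sorted(fit.param_names))
        return fit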
Example #11
    def sample(self, **kwargs) -> stan.fit.Fit:
        """Draw samples from the model.

        Parameters in ``kwargs`` will be passed to the default sample function in
        stan::services. Parameter names are identical to those used in CmdStan.
        See the CmdStan documentation for parameter descriptions and default
        values.

        `num_chains` is the lone PyStan-specific keyword argument. It indicates
        the number of independent processes to use when drawing samples.
        The default value is 1.

        Returns:
            Fit: instance of Fit allowing access to draws.

        """
        assert "chain" not in kwargs, "`chain` id is set automatically."
        assert "data" not in kwargs, "`data` is set in `build`."
        assert "random_seed" not in kwargs, "`random_seed` is set in `build`."
        num_chains = kwargs.pop("num_chains", 1)

        init = kwargs.pop("init", [dict() for _ in range(num_chains)])
        if len(init) != num_chains:
            raise ValueError("Initial values must be provided for each chain.")

        payloads = []
        for chain in range(1, num_chains + 1):
            payload = {
                "function": "stan::services::sample::hmc_nuts_diag_e_adapt"
            }
            payload.update(kwargs)
            payload["chain"] = chain  # type: ignore
            payload["data"] = self.data  # type: ignore
            payload["init"] = init.pop(0)
            if self.random_seed is not None:
                payload["random_seed"] = self.random_seed  # type: ignore

            # fit needs to know num_samples, num_warmup, num_thin, save_warmup
            # progress bar needs to know some of these
            num_warmup = payload.get(
                "num_warmup",
                arguments.lookup_default(arguments.Method["SAMPLE"],
                                         "num_warmup"))
            num_samples = payload.get(
                "num_samples",
                arguments.lookup_default(arguments.Method["SAMPLE"],
                                         "num_samples"),
            )
            num_thin = payload.get(
                "num_thin",
                arguments.lookup_default(arguments.Method["SAMPLE"],
                                         "num_thin"))
            save_warmup = payload.get(
                "save_warmup",
                arguments.lookup_default(arguments.Method["SAMPLE"],
                                         "save_warmup"),
            )
            payloads.append(payload)

        async def go():
            io = ConsoleIO()
            io.error_line("<info>Sampling...</info>")
            progress_bar = ProgressBar(io)
            progress_bar.set_format("very_verbose")

            current_and_max_iterations_re = re.compile(
                r"Iteration:\s+(\d+)\s+/\s+(\d+)")
            async with stan.common.httpstan_server() as (host, port):
                fits_url = f"http://{host}:{port}/v1/{self.model_name}/fits"
                operations = []
                for payload in payloads:
                    async with aiohttp.request("POST", fits_url,
                                               json=payload) as resp:
                        if resp.status == 422:
                            raise ValueError(str(await resp.json()))
                        elif resp.status != 201:
                            raise RuntimeError((await resp.json())["message"])
                        assert resp.status == 201
                        operations.append(await resp.json())

                # poll to get progress for each chain until all chains finished
                current_iterations = {}
                while not all(operation["done"] for operation in operations):
                    for operation in operations:
                        if operation["done"]:
                            continue
                        operation_name = operation["name"]
                        async with aiohttp.request(
                                "GET",
                                f"http://{host}:{port}/v1/{operation_name}"
                        ) as resp:
                            operation.update(await resp.json())
                            progress_message = operation["metadata"].get(
                                "progress")
                            if not progress_message:
                                continue
                            iteration, iteration_max = map(
                                int,
                                current_and_max_iterations_re.findall(
                                    progress_message).pop(0))
                            if not progress_bar.get_max_steps(
                            ):  # i.e., has not started
                                progress_bar.start(max=iteration_max *
                                                   num_chains)
                            current_iterations[operation["name"]] = iteration
                            progress_bar.set_progress(
                                sum(current_iterations.values()))
                    await asyncio.sleep(0.01)
                # Sampling has finished. But we do not call `progress_bar.finish()` right
                # now. First we write informational messages to the screen, then we
                # redraw the (complete) progress bar. Only after that do we call `finish`.

                stan_outputs = []
                for operation in operations:
                    fit_name = operation["result"].get("name")
                    if fit_name is None:  # operation["result"] is an error
                        assert not str(operation["result"]["code"]).startswith(
                            "2"), operation
                        raise RuntimeError(operation["result"]["message"])
                    async with aiohttp.request(
                            "GET",
                            f"http://{host}:{port}/v1/{fit_name}") as resp:
                        if resp.status != 200:
                            raise RuntimeError((await resp.json())["message"])
                        stan_outputs.append(await resp.read())
                stan_outputs = tuple(
                    stan_outputs)  # Fit constructor expects a tuple.

                def is_nonempty_logger_message(msg: simdjson.Object):
                    return msg[
                        "topic"] == "logger" and msg["values"][0] != "info:"

                def is_iteration_or_elapsed_time_logger_message(
                        msg: simdjson.Object):
                    # Assumes `msg` is a message with topic `logger`.
                    text = msg["values"][0]
                    return (
                        text.startswith("info:Iteration:")
                        or text.startswith("info: Elapsed Time:")
                        # this detects lines following "Elapsed Time:", part of a multi-line Stan message
                        or text.startswith("info:" + " " * 15))

                parser = simdjson.Parser()
                nonstandard_logger_messages = []
                for stan_output in stan_outputs:
                    for line in stan_output.splitlines():
                        # Do not attempt to parse non-logger messages. Draws could contain nan or inf values.
                        # simdjson cannot parse lines containing such values.
                        if b'"logger"' not in line:
                            continue
                        msg = parser.parse(line)
                        if is_nonempty_logger_message(
                                msg
                        ) and not is_iteration_or_elapsed_time_logger_message(
                                msg):
                            nonstandard_logger_messages.append(msg.as_dict())
                del parser  # simdjson.Parser is no longer used at this point.

                progress_bar.clear()
                io.error("\x08" * progress_bar._last_messages_length
                         )  # move left to start of line
                if nonstandard_logger_messages:
                    io.error_line(
                        "<comment>Messages received during sampling:</comment>"
                    )
                    for msg in nonstandard_logger_messages:
                        text = msg["values"][0].replace("info:", "  ").replace(
                            "error:", "  ")
                        if text.strip():
                            io.error_line(f"{text}")
                progress_bar.display()  # re-draw the (complete) progress bar
                progress_bar.finish()
                io.error_line("\n<info>Done.</info>")

                # clean up after ourselves when fits are uncacheable (no random seed)
                if self.random_seed is None:
                    for operation in operations:
                        fit_name = operation["result"]["name"]
                        async with aiohttp.request(
                                "DELETE",
                                f"http://{host}:{port}/v1/{fit_name}") as resp:
                            if resp.status not in {200, 202, 204}:
                                raise RuntimeError((await resp.json())["message"])

            return stan.fit.Fit(
                stan_outputs,
                num_chains,
                self.param_names,
                self.constrained_param_names,
                self.dims,
                num_warmup,
                num_samples,
                num_thin,
                save_warmup,
            )

        try:
            return asyncio.run(go())
        except KeyboardInterrupt:
            return  # type: ignore
Example #12
async def call(
    function_name: str,
    model_name: str,
    fit_name: str,
    logger_callback: typing.Optional[typing.Callable] = None,
    **kwargs: dict,
) -> None:
    """Call stan::services function.

    Messages from the stan::callbacks writers which are written to by the
    stan::services function are collected, compressed, and saved under
    ``fit_name``; they are not yielded.

    This is a coroutine function.

    Arguments:
        function_name: full name of function in stan::services
        model_name: name of the model whose services extension module will be used
        fit_name: Name of fit, used for saving length-prefixed messages
        logger_callback: Callback function for logger messages, including sampling progress messages
        kwargs: named stan::services function arguments, see CmdStan documentation.
    """
    method, function_basename = function_name.replace("stan::services::",
                                                      "").split("::", 1)

    # Fetch defaults for missing arguments. This is an important step!
    # For example, `random_seed`, if not in `kwargs`, will be set.
    # temporarily load the module to lookup function arguments
    services_module = httpstan.models.import_services_extension_module(
        model_name)
    function_arguments = arguments.function_arguments(function_basename,
                                                      services_module)
    del services_module
    # This is clumsy due to the way default values are available. There is no
    # way to directly lookup the default value for an argument (e.g., `delta`)
    # given both the argument name and the (full) function name (e.g.,
    # `stan::services::hmc_nuts_diag_e_adapt`).
    for arg in function_arguments:
        if arg not in kwargs:
            kwargs[arg] = typing.cast(
                typing.Any,
                arguments.lookup_default(arguments.Method[method.upper()],
                                         arg))

    with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as socket_:
        _, socket_filename = tempfile.mkstemp(prefix="httpstan_",
                                              suffix=".sock")
        os.unlink(socket_filename)
        socket_.bind(socket_filename)
        socket_.listen(
            4)  # three stan callback writers, one stan callback logger

        lazy_function_wrapper = _make_lazy_function_wrapper(
            function_basename, model_name)
        lazy_function_wrapper_partial = functools.partial(
            lazy_function_wrapper, socket_filename, **kwargs)

        # If HTTPSTAN_DEBUG is set block until sampling is complete. Do not use an executor.
        if HTTPSTAN_DEBUG:  # pragma: no cover
            future: asyncio.Future = asyncio.Future()
            logger.debug("Calling stan::services function with debug mode on.")
            print(
                "Warning: httpstan debug mode is on! `num_samples` must be set to a small number (e.g., 10)."
            )
            future.set_result(lazy_function_wrapper_partial())
        else:
            future = asyncio.get_running_loop().run_in_executor(
                executor, lazy_function_wrapper_partial)  # type: ignore

        messages_files: typing.Mapping[socket.socket,
                                       io.BytesIO] = collections.defaultdict(
                                           io.BytesIO)
        potential_readers = [socket_]
        while True:
            # note: timeout of 0.01 seems to work well based on measurements
            readable, writeable, errored = select.select(
                potential_readers, [], [], 0.01)
            for s in readable:
                if s is socket_:
                    conn, _ = s.accept()
                    logger.debug(
                        "Opened socket connection to a socket_logger or socket_writer."
                    )
                    potential_readers.append(conn)
                    continue
                message = s.recv(8192)
                if not len(message):
                    # `close` called on other end
                    s.close()
                    logger.debug(
                        "Closed socket connection to a socket_logger or socket_writer."
                    )
                    potential_readers.remove(s)
                    continue
                # Only trigger callback if message has topic `logger`.
                if logger_callback and b'"logger"' in message:
                    logger_callback(message)
                messages_files[s].write(message)
            # if `potential_readers == [socket_]` then either (1) no connections
            # have been opened or (2) all connections have been closed.
            if not readable:
                if potential_readers == [socket_] and future.done():
                    logger.debug(
                        f"Stan services function `{function_basename}` returned without problems or raised a C++ exception."
                    )
                    break
                # no messages right now and not done. Sleep briefly so other pending tasks get a chance to run.
                await asyncio.sleep(0.001)

    # WISHLIST: Here we compress messages after they all have arrived. Find a way to compress
    # messages as they arrive.  Compressing messages as they arrive would use much less memory.
    with lz4.frame.LZ4FrameCompressor() as compressor:
        compressed = compressor.begin()
        for fh in messages_files.values():
            fh.flush()
            compressed += compressor.compress(fh.getvalue())
            fh.close()
        compressed += compressor.flush()
    httpstan.cache.dump_fit(compressed, fit_name)

    # `result()` method will raise exceptions, if any
    future.result()