Пример #1
0
async def test_function_arguments(api_url: str) -> None:
    """Check the argument names reported for ``hmc_nuts_diag_e_adapt``.

    ``arguments.function_arguments`` inspects a compiled services extension
    module, so a model is obtained through the API first.
    """
    compiled_model_name = await helpers.get_model_name(api_url, program_code)

    # Import the services extension module belonging to the compiled model.
    services_module = httpstan.models.import_services_extension_module(compiled_model_name)

    # Expected argument names, in the order the function declares them.
    expected = (
        "data init random_seed chain init_radius "
        "num_warmup num_samples num_thin save_warmup refresh "
        "stepsize stepsize_jitter max_depth "
        "delta gamma kappa t0 init_buffer term_buffer window"
    ).split()

    assert expected == arguments.function_arguments("hmc_nuts_diag_e_adapt", services_module)
Пример #2
0
async def call(function_name: str, program_module, data: dict, **kwargs):
    """Call stan::services function.

    Yields (asynchronously) messages from the stan::callbacks writers which are
    written to by the stan::services function. Messages for which the writer
    parser returns a falsy value (blank lines, headers with parameter names)
    are not yielded.

    This is an asynchronous generator function.

    Arguments:
        function_name: full name of function in stan::services
        program_module (module): Stan Program extension module
        data: dictionary with data with which to populate array_var_context
        kwargs: named stan::services function arguments, see CmdStan documentation.
            Arguments that are not supplied are filled in with their defaults.

    Raises:
        Any exception raised by the stan::services function. It is re-raised
        after every message it queued has been yielded.
    """
    method, function_basename = function_name.replace('stan::services::', '').split('::', 1)
    queue_wrapper = httpstan.spsc_queue.SPSCQueue(capacity=10000)
    array_var_context_capsule = httpstan.stan.make_array_var_context(data)
    function_wrapper = getattr(program_module, function_basename + '_wrapper')

    # fetch defaults for missing arguments
    function_arguments = arguments.function_arguments(function_basename, program_module)
    # This is clumsy due to the way default values are available. There is no
    # way to directly lookup the default value for an argument (e.g., `delta`)
    # given both the argument name and the (full) function name (e.g.,
    # `stan::services::hmc_nuts_diag_e_adapt`).
    for arg in function_arguments:
        if arg not in kwargs:
            kwargs[arg] = arguments.lookup_default(arguments.Method[method.upper()], arg)
    function_wrapper_partial = functools.partial(function_wrapper,
                                                 array_var_context_capsule,
                                                 queue_wrapper.to_capsule(),
                                                 **kwargs)

    # WISHLIST: can one use ProcessPoolExecutor somehow on Linux and OSX?
    # run_in_executor already returns an asyncio future; no ensure_future needed.
    future = asyncio.get_running_loop().run_in_executor(
        None, function_wrapper_partial)  # type: ignore
    parser = httpstan.callbacks_writer_parser.WriterParser()
    while True:
        # Sample `done()` BEFORE polling the queue. If the worker had already
        # finished at that moment it cannot enqueue anything more, so an empty
        # queue is truly drained. Checking `done()` only after seeing an empty
        # queue would drop messages enqueued between the two checks.
        finished = future.done()
        try:
            message = queue_wrapper.get_nowait()
        except queue.Empty:
            if finished:
                break
            await asyncio.sleep(0.1)
            continue
        parsed = parser.parse(message.decode())
        # parsed is None if the message was a blank line or a header with param names
        if parsed:
            yield parsed
    future.result()  # raises exceptions from task, if any
Пример #3
0
    async def main():
        # Create the program through the HTTP API and keep the id it returns.
        endpoint = 'http://{}:{}/v1/programs'.format(host, port)
        request_body = json.dumps({'program_code': program_code})
        async with aiohttp.ClientSession() as session:
            async with session.post(endpoint, data=request_body, headers=headers) as response:
                assert response.status == 200
                payload = await response.json()
                program_id = payload['id']

        # A plain dict stands in for aiohttp.web.Application; init_cache sets up
        # the database and stores it under 'db'.
        app = {}
        await httpstan.cache.init_cache(app)
        cached_bytes = await httpstan.cache.load_program_extension_module(program_id, app['db'])
        assert cached_bytes is not None
        program_module = httpstan.program.load_program_extension_module(program_id, cached_bytes)

        expected = (
            'random_seed chain init_radius num_warmup num_samples num_thin '
            'save_warmup refresh stepsize stepsize_jitter max_depth'
        ).split()
        assert expected == arguments.function_arguments('hmc_nuts_diag_e', program_module)
Пример #4
0
    async def main():
        # Create the model through the HTTP API and keep the id it returns.
        endpoint = "http://{}:{}/v1/models".format(host, port)
        request_body = json.dumps({"program_code": program_code})
        async with aiohttp.ClientSession() as session:
            async with session.post(endpoint, data=request_body, headers=headers) as response:
                assert response.status == 200
                payload = await response.json()
                model_id = payload["id"]

        # A plain dict stands in for aiohttp.web.Application; init_cache sets up
        # the database and stores it under "db".
        app = {}
        await httpstan.cache.init_cache(app)
        cached_bytes = await httpstan.cache.load_model_extension_module(model_id, app["db"])
        assert cached_bytes is not None
        model_module = httpstan.models.load_model_extension_module(model_id, cached_bytes)

        expected = (
            "random_seed chain init_radius num_warmup num_samples num_thin "
            "save_warmup refresh stepsize stepsize_jitter max_depth"
        ).split()
        assert expected == arguments.function_arguments("hmc_nuts_diag_e", model_module)
Пример #5
0
async def call(
    function_name: str,
    model_name: str,
    fit_name: str,
    logger_callback: typing.Optional[typing.Callable] = None,
    **kwargs: typing.Any,
) -> None:
    """Call a stan::services function and save everything it writes as a fit.

    The stan::callbacks writers and logger used by the services function send
    their output over a Unix domain socket. Data from each connection is
    gzip-compressed as it arrives. When the function has finished and every
    connection is closed, the concatenated gzip members are stored with
    ``httpstan.cache.dump_fit``. Nothing is yielded or returned.

    This is a coroutine function.

    Arguments:
        function_name: full name of function in stan::services
        model_name: name of the model whose services extension module provides
            the function
        fit_name: name under which the compressed output is saved
        logger_callback: Callback function for logger messages, including sampling
            progress messages. It is called with the raw bytes of each chunk read
            from a connection whose data contains ``"logger"``.
        kwargs: named stan::services function arguments, see CmdStan documentation.
            Arguments that are not supplied are filled in with their defaults.

    Raises:
        Whatever the stan::services function raised. If that exception has a
        single argument, info messages found in the first lines of the output are
        appended to its message to give context (e.g., why initialization failed).
    """
    method, function_basename = function_name.replace("stan::services::", "").split("::", 1)

    # Fetch defaults for missing arguments. This is an important step!
    # For example, `random_seed`, if not in `kwargs`, will be set.
    # temporarily load the module to lookup function arguments
    services_module = httpstan.models.import_services_extension_module(model_name)
    function_arguments = arguments.function_arguments(function_basename, services_module)
    del services_module
    # This is clumsy due to the way default values are available. There is no
    # way to directly lookup the default value for an argument (e.g., `delta`)
    # given both the argument name and the (full) function name (e.g.,
    # `stan::services::hmc_nuts_diag_e_adapt`).
    for arg in function_arguments:
        if arg not in kwargs:
            kwargs[arg] = arguments.lookup_default(arguments.Method[method.upper()], arg)

    # mkstemp is used only to reserve a unique path. It also opens the file, so
    # close that descriptor (previously leaked) and remove the regular file so
    # that bind() can create the socket at this path.
    fd, socket_filename = tempfile.mkstemp(prefix="httpstan_", suffix=".sock")
    os.close(fd)
    os.unlink(socket_filename)

    messages_files: typing.Mapping[socket.socket, io.BytesIO] = collections.defaultdict(io.BytesIO)
    # using a wbits value which makes things compatible with gzip
    messages_compressobjs: typing.Mapping[socket.socket, zlib._Compress] = collections.defaultdict(
        functools.partial(zlib.compressobj, level=zlib.Z_BEST_SPEED, wbits=zlib.MAX_WBITS | 16)
    )
    # The listening socket, followed by every accepted connection still open.
    potential_readers: typing.List[socket.socket] = []

    try:
        with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as socket_:
            potential_readers.append(socket_)
            socket_.bind(socket_filename)
            socket_.listen(4)  # three stan callback writers, one stan callback logger

            lazy_function_wrapper = _make_lazy_function_wrapper(function_basename, model_name)
            lazy_function_wrapper_partial = functools.partial(
                lazy_function_wrapper, socket_filename, **kwargs)

            # If HTTPSTAN_DEBUG is set block until sampling is complete. Do not use an executor.
            if HTTPSTAN_DEBUG:  # pragma: no cover
                future: asyncio.Future = asyncio.Future()
                logger.debug("Calling stan::services function with debug mode on.")
                print(
                    "Warning: httpstan debug mode is on! `num_samples` must be set to a small number (e.g., 10)."
                )
                future.set_result(lazy_function_wrapper_partial())
            else:
                future = asyncio.get_running_loop().run_in_executor(
                    executor, lazy_function_wrapper_partial)  # type: ignore

            while True:
                # note: timeout of 0.01 seems to work well based on measurements
                readable, _, _ = select.select(potential_readers, [], [], 0.01)
                for s in readable:
                    if s is socket_:
                        conn, _ = s.accept()
                        logger.debug("Opened socket connection to a socket_logger or socket_writer.")
                        potential_readers.append(conn)
                        continue
                    message = s.recv(8192)
                    if not len(message):
                        # `close` called on other end
                        s.close()
                        logger.debug("Closed socket connection to a socket_logger or socket_writer.")
                        potential_readers.remove(s)
                        continue
                    # Only trigger callback if message has topic `logger`.
                    if logger_callback and b'"logger"' in message:
                        logger_callback(message)
                    messages_files[s].write(messages_compressobjs[s].compress(message))
                # if `potential_readers == [socket_]` then either (1) no connections
                # have been opened or (2) all connections have been closed.
                if not readable:
                    if potential_readers == [socket_] and future.done():
                        logger.debug(
                            f"Stan services function `{function_basename}` returned without problems or raised a C++ exception."
                        )
                        break
                    # no messages right now and not done. Sleep briefly so other pending tasks get a chance to run.
                    await asyncio.sleep(0.001)
    finally:
        # On the normal path every accepted connection has already been closed and
        # removed; this only matters when the loop is left via an exception or
        # cancellation. Closing an already-closed socket is harmless.
        for conn in potential_readers[1:]:
            conn.close()
        # bind() created a filesystem entry for the socket. Remove it so socket
        # files do not accumulate in the temporary directory. A missing entry is
        # fine (e.g., bind() never ran).
        try:
            os.unlink(socket_filename)
        except FileNotFoundError:
            pass

    compressed_parts = []
    for s, fh in messages_files.items():
        fh.write(messages_compressobjs[s].flush())
        compressed_parts.append(fh.getvalue())
        fh.close()
    compressed = b"".join(compressed_parts)
    httpstan.cache.dump_fit(compressed, fit_name)

    # if an exception has already occurred, grab relevant info messages, add as context
    exception = future.exception()
    if exception is not None and len(exception.args) == 1:
        import gzip
        import json

        original_exception_message = exception.args[0]  # e.g., from ValueError("Initialization failed.")
        info_messages_for_context = []
        num_context_messages = 4

        jsonlines = gzip.decompress(compressed).decode()
        for line in jsonlines.split("\n")[:num_context_messages]:
            try:
                message = json.loads(line)
                info_message = message["values"].pop().replace("info:", "")
            except (json.JSONDecodeError, KeyError, IndexError, TypeError, AttributeError):
                # Blank or differently-shaped line: skip it. Failing here would
                # replace the exception being reported with an unrelated one.
                continue
            info_messages_for_context.append(info_message.strip())
        # add the info messages to the original exception message. For example,
        # ValueError("Initialization failed.") -> ValueError("Initialization failed. Rejecting initial value: Log probability ...")
        if info_messages_for_context:
            new_exception_message = f"{original_exception_message} {' '.join(info_messages_for_context)} ..."
            exception.args = (new_exception_message, )

    # `result()` method will raise exceptions, if any
    future.result()
Пример #6
0
async def call(
    function_name: str,
    model_name: str,
    fit_name: str,
    logger_callback: typing.Optional[typing.Callable] = None,
    **kwargs: typing.Any,
) -> None:
    """Call a stan::services function and save everything it writes as a fit.

    The stan::callbacks writers and logger used by the services function send
    their output over a Unix domain socket. Data from each connection is
    buffered in memory. When the function has finished and every connection is
    closed, the buffers are compressed into a single LZ4 frame. That frame is
    stored with ``httpstan.cache.dump_fit``. Nothing is yielded or returned.

    This is a coroutine function.

    Arguments:
        function_name: full name of function in stan::services
        model_name: name of the model whose services extension module provides
            the function
        fit_name: name under which the compressed output is saved
        logger_callback: Callback function for logger messages, including sampling
            progress messages. It is called with the raw bytes of each chunk read
            from a connection whose data contains ``"logger"``.
        kwargs: named stan::services function arguments, see CmdStan documentation.
            Arguments that are not supplied are filled in with their defaults.

    Raises:
        Whatever the stan::services function raised.
    """
    method, function_basename = function_name.replace("stan::services::", "").split("::", 1)

    # Fetch defaults for missing arguments. This is an important step!
    # For example, `random_seed`, if not in `kwargs`, will be set.
    # temporarily load the module to lookup function arguments
    services_module = httpstan.models.import_services_extension_module(model_name)
    function_arguments = arguments.function_arguments(function_basename, services_module)
    del services_module
    # This is clumsy due to the way default values are available. There is no
    # way to directly lookup the default value for an argument (e.g., `delta`)
    # given both the argument name and the (full) function name (e.g.,
    # `stan::services::hmc_nuts_diag_e_adapt`).
    for arg in function_arguments:
        if arg not in kwargs:
            kwargs[arg] = arguments.lookup_default(arguments.Method[method.upper()], arg)

    # mkstemp is used only to reserve a unique path. It also opens the file, so
    # close that descriptor (previously leaked) and remove the regular file so
    # that bind() can create the socket at this path.
    fd, socket_filename = tempfile.mkstemp(prefix="httpstan_", suffix=".sock")
    os.close(fd)
    os.unlink(socket_filename)

    messages_files: typing.Mapping[socket.socket, io.BytesIO] = collections.defaultdict(io.BytesIO)
    # The listening socket, followed by every accepted connection still open.
    potential_readers: typing.List[socket.socket] = []

    try:
        with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as socket_:
            potential_readers.append(socket_)
            socket_.bind(socket_filename)
            socket_.listen(4)  # three stan callback writers, one stan callback logger

            lazy_function_wrapper = _make_lazy_function_wrapper(function_basename, model_name)
            lazy_function_wrapper_partial = functools.partial(
                lazy_function_wrapper, socket_filename, **kwargs)

            # If HTTPSTAN_DEBUG is set block until sampling is complete. Do not use an executor.
            if HTTPSTAN_DEBUG:  # pragma: no cover
                future: asyncio.Future = asyncio.Future()
                logger.debug("Calling stan::services function with debug mode on.")
                print(
                    "Warning: httpstan debug mode is on! `num_samples` must be set to a small number (e.g., 10)."
                )
                future.set_result(lazy_function_wrapper_partial())
            else:
                future = asyncio.get_running_loop().run_in_executor(
                    executor, lazy_function_wrapper_partial)  # type: ignore

            while True:
                # note: timeout of 0.01 seems to work well based on measurements
                readable, _, _ = select.select(potential_readers, [], [], 0.01)
                for s in readable:
                    if s is socket_:
                        conn, _ = s.accept()
                        logger.debug("Opened socket connection to a socket_logger or socket_writer.")
                        potential_readers.append(conn)
                        continue
                    message = s.recv(8192)
                    if not len(message):
                        # `close` called on other end
                        s.close()
                        logger.debug("Closed socket connection to a socket_logger or socket_writer.")
                        potential_readers.remove(s)
                        continue
                    # Only trigger callback if message has topic `logger`.
                    if logger_callback and b'"logger"' in message:
                        logger_callback(message)
                    messages_files[s].write(message)
                # if `potential_readers == [socket_]` then either (1) no connections
                # have been opened or (2) all connections have been closed.
                if not readable:
                    if potential_readers == [socket_] and future.done():
                        logger.debug(
                            f"Stan services function `{function_basename}` returned without problems or raised a C++ exception."
                        )
                        break
                    # no messages right now and not done. Sleep briefly so other pending tasks get a chance to run.
                    await asyncio.sleep(0.001)
    finally:
        # On the normal path every accepted connection has already been closed and
        # removed; this only matters when the loop is left via an exception or
        # cancellation. Closing an already-closed socket is harmless.
        for conn in potential_readers[1:]:
            conn.close()
        # bind() created a filesystem entry for the socket. Remove it so socket
        # files do not accumulate in the temporary directory. A missing entry is
        # fine (e.g., bind() never ran).
        try:
            os.unlink(socket_filename)
        except FileNotFoundError:
            pass

    # WISHLIST: Here we compress messages after they all have arrived. Find a way to compress
    # messages as they arrive.  Compressing messages as they arrive would use much less memory.
    with lz4.frame.LZ4FrameCompressor() as compressor:
        compressed = compressor.begin()
        for fh in messages_files.values():
            compressed += compressor.compress(fh.getvalue())
            fh.close()
        compressed += compressor.flush()
    httpstan.cache.dump_fit(compressed, fit_name)

    # `result()` method will raise exceptions, if any
    future.result()