Example #1
def new_proc(
    name: str,
    actor: Actor,
    # passed through to actor main
    bind_addr: Tuple[str, int],
    parent_addr: Tuple[str, int],
) -> mp.Process:
    """Create a new ``multiprocessing.Process`` using the
    spawn method as configured using ``try_set_start_method()``.
    """
    start_method = _ctx.get_start_method()
    if start_method == 'forkserver':
        # XXX do our hackery on the stdlib to avoid multiple
        # forkservers (one at each subproc layer).
        fs = forkserver._forkserver
        curr_actor = current_actor()
        if is_main_process() and not curr_actor._forkserver_info:
            # if we're the "main" process start the forkserver only once
            # and pass its ipc info to downstream children
            # forkserver.set_forkserver_preload(rpc_module_paths)
            forkserver.ensure_running()
            fs_info = (
                fs._forkserver_address,
                fs._forkserver_alive_fd,
                getattr(fs, '_forkserver_pid', None),
                getattr(semaphore_tracker._semaphore_tracker, '_pid', None),
                semaphore_tracker._semaphore_tracker._fd,
            )
        else:
            assert curr_actor._forkserver_info
            fs_info = (
                fs._forkserver_address,
                fs._forkserver_alive_fd,
                fs._forkserver_pid,
                semaphore_tracker._semaphore_tracker._pid,
                semaphore_tracker._semaphore_tracker._fd,
            ) = curr_actor._forkserver_info
    else:
        fs_info = (None, None, None, None, None)

    return _ctx.Process(
        target=actor._fork_main,
        args=(
            bind_addr,
            fs_info,
            start_method,
            parent_addr
        ),
        # daemon=True,
        name=name,
    )
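
A note on the ``fs_info = (...) = curr_actor._forkserver_info`` idiom in the ``else:`` branch above: a chained assignment binds each target list, left to right, to the same right-hand value, so ``fs_info`` ends up referencing the original tuple while the inner targets simultaneously unpack it back onto the stdlib singletons. A minimal sketch of just that mechanism (the names here are illustrative, not from the snippet):

class State:
    pass

state = State()
info = ('addr', 42)

# targets are assigned left to right from the same right-hand side:
# first ``copied = info``, then the ``(state.address, state.fd)`` unpacking
copied = (state.address, state.fd) = info

assert copied is info
assert state.address == 'addr' and state.fd == 42
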
Example #2
def main(args: Iterable[str]) -> int:
    """
    The main program loop

    :param args: Command line arguments
    :return: The program exit code
    """
    # Handle command line arguments
    args = handle_args(args)
    set_verbosity_logger(logger, args.verbosity)

    # Go to the working directory
    config_file = os.path.realpath(args.config)
    os.chdir(os.path.dirname(config_file))

    try:
        # Read the configuration
        config = config_parser.load_config(config_file)
    except (ConfigurationSyntaxError, DataConversionError) as e:
        # Make the config exceptions a bit more readable
        msg = e.message
        if e.lineno and e.lineno != -1:
            msg += ' on line {}'.format(e.lineno)
        if e.url:
            parts = urlparse(e.url)
            msg += ' in {}'.format(parts.path)
        logger.critical(msg)
        return 1
    except ValueError as e:
        logger.critical(e)
        return 1

    # Immediately drop privileges in a non-permanent way so we create logs with the correct owner
    drop_privileges(config.user, config.group, permanent=False)

    # Trigger the forkserver at this point, with dropped privileges, and ignoring KeyboardInterrupt
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    multiprocessing.set_start_method('forkserver')
    forkserver.ensure_running()

    # Initialise the logger
    config.logging.configure(logger, verbosity=args.verbosity)
    logger.info("Starting Python DHCPv6 server v{}".format(
        dhcpkit.__version__))

    # Create our selector
    sel = selectors.DefaultSelector()

    # Convert signals to messages on a pipe
    signal_r, signal_w = os.pipe()
    flags = fcntl.fcntl(signal_w, fcntl.F_GETFL, 0)
    flags = flags | os.O_NONBLOCK
    fcntl.fcntl(signal_w, fcntl.F_SETFL, flags)
    signal.set_wakeup_fd(signal_w)
    sel.register(signal_r, selectors.EVENT_READ)

    # Ignore normal signal handling by attaching dummy handlers (SIG_IGN will not put messages on the pipe)
    signal.signal(signal.SIGINT, lambda signum, frame: None)
    signal.signal(signal.SIGTERM, lambda signum, frame: None)
    signal.signal(signal.SIGHUP, lambda signum, frame: None)
    signal.signal(signal.SIGUSR1, lambda signum, frame: None)

    # Excessive exception catcher
    exception_history = []

    # Some stats
    message_count = 0

    # Create a queue for our children to log to
    logging_queue = multiprocessing.Queue()

    statistics = ServerStatistics()
    listeners = []
    control_socket = None
    stopping = False

    while not stopping:
        # Safety first: assume we want to quit when we break the inner loop unless told otherwise
        stopping = True

        # Initialise the logger again
        lowest_log_level = config.logging.configure(logger,
                                                    verbosity=args.verbosity)

        # Enable multiprocessing logging, mostly useful for development
        mp_logger = get_logger()
        mp_logger.propagate = config.logging.log_multiprocessing

        global logging_thread
        if logging_thread:
            logging_thread.stop()

        logging_thread = queue_logger.QueueLevelListener(
            logging_queue, *logger.handlers)
        logging_thread.start()

        # Use the logging queue in the main process as well so messages don't get out of order
        logging_handler = WorkerQueueHandler(logging_queue)
        logging_handler.setLevel(lowest_log_level)
        logger.handlers = [logging_handler]

        # Restore our privileges while we write the PID file and open network listeners
        restore_privileges()

        # Open the network listeners
        old_listeners = listeners
        listeners = []
        for listener_factory in config.listener_factories:
            # Create new listener while trying to re-use existing sockets
            listeners.append(listener_factory(old_listeners + listeners))

        # Forget old listeners
        del old_listeners

        # Write the PID file
        pid_filename = create_pidfile(args=args, config=config)

        # Create a control socket
        if control_socket:
            sel.unregister(control_socket)
            control_socket.close()

        control_socket = create_control_socket(args=args, config=config)
        if control_socket:
            sel.register(control_socket, selectors.EVENT_READ)

        # And drop privileges again
        drop_privileges(config.user, config.group, permanent=False)

        # Remove any file descriptors from the previous config
        for fd, key in list(sel.get_map().items()):
            # Don't remove our signal handling pipe, control socket, still existing listeners and control connections
            if key.fileobj is signal_r \
                    or (control_socket and key.fileobj is control_socket) \
                    or key.fileobj in listeners \
                    or isinstance(key.fileobj, ControlConnection):
                continue

            # Seems we don't need this one anymore
            sel.unregister(key.fileobj)

        # Collect all the file descriptors we want to listen to
        existing_listeners = [key.fileobj for key in sel.get_map().values()]
        for listener in listeners:
            if listener not in existing_listeners:
                sel.register(listener, selectors.EVENT_READ)

        # Configuration tree
        try:
            message_handler = config.create_message_handler()
        except Exception as e:
            if args.verbosity >= 3:
                logger.exception("Error initialising DHCPv6 server")
            else:
                logger.critical(
                    "Error initialising DHCPv6 server: {}".format(e))
            return 1

        # Make sure we have space to store all the interface statistics
        statistics.set_categories(config.statistics)

        # Start worker processes
        my_pid = os.getpid()
        with NonBlockingPool(processes=config.workers,
                             initializer=setup_worker,
                             initargs=(message_handler, logging_queue,
                                       lowest_log_level, statistics,
                                       my_pid)) as pool:

            logger.info("Python DHCPv6 server is ready to handle requests")

            running = True
            while running:
                count_exception = False

                # noinspection PyBroadException
                try:
                    events = sel.select()
                    for key, mask in events:
                        if isinstance(key.fileobj, Listener):
                            try:
                                packet, replier = key.fileobj.recv_request()

                                # Update stats
                                message_count += 1

                                # Dispatch
                                pool.apply_async(handle_message,
                                                 args=(packet, replier),
                                                 error_callback=error_callback)
                            except IgnoreMessage:
                                # Message isn't complete, leave it for now
                                pass
                            except ClosedListener:
                                # This listener is closed (at least TCP shutdown for incoming data), so forget about it
                                sel.unregister(key.fileobj)
                                listeners.remove(key.fileobj)

                        elif isinstance(key.fileobj, ListenerCreator):
                            # Activity on this object means we have a new listener
                            new_listener = key.fileobj.create_listener()
                            if new_listener:
                                sel.register(new_listener,
                                             selectors.EVENT_READ)
                                listeners.append(new_listener)

                        # Handle signal notifications
                        elif key.fileobj == signal_r:
                            signal_nr = os.read(signal_r, 1)
                            if signal_nr[0] in (signal.SIGHUP, ):
                                # SIGHUP tells the server to reload
                                try:
                                    # Read the new configuration
                                    config = config_parser.load_config(
                                        config_file)
                                except (ConfigurationSyntaxError,
                                        DataConversionError) as e:
                                    # Make the config exceptions a bit more readable
                                    msg = "Not reloading: " + str(e.message)
                                    if e.lineno and e.lineno != -1:
                                        msg += ' on line {}'.format(e.lineno)
                                    if e.url:
                                        parts = urlparse(e.url)
                                        msg += ' in {}'.format(parts.path)
                                    logger.critical(msg)
                                    continue

                                except ValueError as e:
                                    logger.critical("Not reloading: " + str(e))
                                    continue

                                logger.info(
                                    "DHCPv6 server restarting after configuration change"
                                )
                                running = False
                                stopping = False
                                continue

                            elif signal_nr[0] in (signal.SIGINT,
                                                  signal.SIGTERM):
                                logger.debug("Received termination request")

                                running = False
                                stopping = True
                                break

                            elif signal_nr[0] in (signal.SIGUSR1, ):
                                # The USR1 signal is used to indicate initialisation errors in worker processes
                                count_exception = True

                        elif isinstance(key.fileobj, ControlSocket):
                            # A new control connection request
                            control_connection = key.fileobj.accept()
                            if control_connection:
                                # We got a connection, listen to events
                                sel.register(control_connection,
                                             selectors.EVENT_READ)

                        elif isinstance(key.fileobj, ControlConnection):
                            # Let the connection handle received data
                            control_connection = key.fileobj
                            commands = control_connection.get_commands()
                            for command in commands:
                                if command:
                                    logger.debug(
                                        "Received control command '{}'".format(
                                            command))

                                if command == 'help':
                                    control_connection.send(
                                        "Recognised commands:")
                                    control_connection.send("  help")
                                    control_connection.send("  stats")
                                    control_connection.send("  stats-json")
                                    control_connection.send("  reload")
                                    control_connection.send("  shutdown")
                                    control_connection.send("  quit")
                                    control_connection.acknowledge()

                                elif command == 'stats':
                                    control_connection.send(str(statistics))
                                    control_connection.acknowledge()

                                elif command == 'stats-json':
                                    control_connection.send(
                                        json.dumps(statistics.export()))
                                    control_connection.acknowledge()

                                elif command == 'reload':
                                    # Simulate a SIGHUP to reload
                                    os.write(signal_w, bytes([signal.SIGHUP]))
                                    control_connection.acknowledge('Reloading')

                                elif command == 'shutdown':
                                    # Simulate a SIGTERM to shut down
                                    control_connection.acknowledge(
                                        'Shutting down')
                                    control_connection.close()
                                    sel.unregister(control_connection)

                                    os.write(signal_w, bytes([signal.SIGTERM]))
                                    break

                                elif command == 'quit' or command is None:
                                    if command == 'quit':
                                        # User nicely signing off
                                        control_connection.acknowledge()

                                    control_connection.close()
                                    sel.unregister(control_connection)
                                    break

                                else:
                                    logger.warning(
                                        "Rejecting unknown control command '{}'"
                                        .format(command))
                                    control_connection.reject()

                except Exception as e:
                    # Catch-all exception handler
                    logger.exception(
                        "Caught unexpected exception {!r}".format(e))
                    count_exception = True

                if count_exception:
                    now = time.monotonic()

                    # Add new exception time to the history
                    exception_history.append(now)

                    # Remove exceptions outside the window from the history
                    cutoff = now - config.exception_window
                    while exception_history and exception_history[0] < cutoff:
                        exception_history.pop(0)

                    # Did we receive too many exceptions shortly after each other?
                    if len(exception_history) > config.max_exceptions:
                        logger.critical(
                            "Received more than {} exceptions in {} seconds, "
                            "exiting".format(config.max_exceptions,
                                             config.exception_window))
                        running = False
                        stopping = True

            pool.close()
            pool.join()

        # Regain root so we can delete the PID file and control socket
        restore_privileges()
        try:
            if pid_filename:
                os.unlink(pid_filename)
                logger.info("Removing PID-file {}".format(pid_filename))
        except OSError:
            pass

        try:
            if control_socket:
                os.unlink(control_socket.socket_path)
                logger.info("Removing control socket {}".format(
                    control_socket.socket_path))
        except OSError:
            pass

    logger.info("Shutting down Python DHCPv6 server v{}".format(
        dhcpkit.__version__))

    return 0
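
The signal plumbing near the top of ``main()`` is the classic self-pipe trick: ``signal.set_wakeup_fd()`` writes each delivered signal's number as one byte into a non-blocking pipe, and the no-op handlers keep the signals from killing the process while still triggering the wakeup write (``SIG_IGN`` would suppress the byte entirely, as the comment notes). A minimal, self-contained sketch of the same setup:

import os
import fcntl
import signal
import selectors

sel = selectors.DefaultSelector()

# non-blocking pipe that receives one byte per delivered signal
signal_r, signal_w = os.pipe()
flags = fcntl.fcntl(signal_w, fcntl.F_GETFL, 0)
fcntl.fcntl(signal_w, fcntl.F_SETFL, flags | os.O_NONBLOCK)
signal.set_wakeup_fd(signal_w)
sel.register(signal_r, selectors.EVENT_READ)

# dummy handlers: the default handlers would terminate the process
for sig in (signal.SIGINT, signal.SIGTERM, signal.SIGHUP):
    signal.signal(sig, lambda signum, frame: None)

print("waiting for SIGINT/SIGTERM/SIGHUP, pid", os.getpid())
for key, mask in sel.select():
    if key.fileobj == signal_r:
        signal_nr = os.read(signal_r, 1)[0]
        print("received signal", signal_nr)
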
Example #3
async def mp_new_proc(
        name: str,
        actor_nursery: 'ActorNursery',  # type: ignore  # noqa
        subactor: Actor,
        errors: dict[tuple[str, str], Exception],
        # passed through to actor main
        bind_addr: tuple[str, int],
        parent_addr: tuple[str, int],
        _runtime_vars: dict[str, Any],  # serialized and sent to _child
        *,
        infect_asyncio: bool = False,
        task_status: TaskStatus[Portal] = trio.TASK_STATUS_IGNORED) -> None:

    # uggh zone
    try:
        from multiprocessing import semaphore_tracker  # type: ignore
        resource_tracker = semaphore_tracker
        resource_tracker._resource_tracker = resource_tracker._semaphore_tracker  # noqa
    except ImportError:
        # 3.8 introduces a more general version that also tracks shared mems
        from multiprocessing import resource_tracker  # type: ignore

    assert _ctx
    start_method = _ctx.get_start_method()
    if start_method == 'forkserver':
        from multiprocessing import forkserver  # type: ignore
        # XXX do our hackery on the stdlib to avoid multiple
        # forkservers (one at each subproc layer).
        fs = forkserver._forkserver
        curr_actor = current_actor()
        if is_main_process() and not curr_actor._forkserver_info:
            # if we're the "main" process start the forkserver
            # only once and pass its ipc info to downstream
            # children
            # forkserver.set_forkserver_preload(enable_modules)
            forkserver.ensure_running()
            fs_info = (
                fs._forkserver_address,
                fs._forkserver_alive_fd,
                getattr(fs, '_forkserver_pid', None),
                getattr(resource_tracker._resource_tracker, '_pid', None),
                resource_tracker._resource_tracker._fd,
            )
        else:
            assert curr_actor._forkserver_info
            fs_info = (
                fs._forkserver_address,
                fs._forkserver_alive_fd,
                fs._forkserver_pid,
                resource_tracker._resource_tracker._pid,
                resource_tracker._resource_tracker._fd,
            ) = curr_actor._forkserver_info
    else:
        fs_info = (None, None, None, None, None)

    proc: mp.Process = _ctx.Process(  # type: ignore
        target=_mp_main,
        args=(
            subactor,
            bind_addr,
            fs_info,
            start_method,
            parent_addr,
            infect_asyncio,
        ),
        # daemon=True,
        name=name,
    )

    # `multiprocessing` only (since no async interface):
    # register the process before start in case we get a cancel
    # request before the actor has fully spawned - then we can wait
    # for it to fully come up before sending a cancel request
    actor_nursery._children[subactor.uid] = (subactor, proc, None)

    proc.start()
    if not proc.is_alive():
        raise ActorFailure("Couldn't start sub-actor?")

    log.runtime(f"Started {proc}")

    try:
        # wait for actor to spawn and connect back to us
        # channel should have handshake completed by the
        # local actor by the time we get a ref to it
        event, chan = await actor_nursery._actor.wait_for_peer(subactor.uid)

        # XXX: monkey patch poll API to match the ``subprocess`` API..
        # not sure why they don't expose this but kk.
        proc.poll = lambda: proc.exitcode  # type: ignore

        # except:
        # TODO: in the case we were cancelled before the sub-proc
        # registered itself back we must be sure to try and clean
        # any process we may have started.

        portal = Portal(chan)
        actor_nursery._children[subactor.uid] = (subactor, proc, portal)

        # unblock parent task
        task_status.started(portal)

        # wait for ``ActorNursery`` block to signal that
        # subprocesses can be waited upon.
        # This is required to ensure synchronization
        # with user code that may want to manually await results
        # from nursery spawned sub-actors. We don't want the
        # containing nurseries here to collect results or error
        # while user code is still doing its thing. Only after the
        # nursery block closes do we allow subactor results to be
        # awaited and reported upwards to the supervisor.
        with trio.CancelScope(shield=True):
            await actor_nursery._join_procs.wait()

        async with trio.open_nursery() as nursery:
            if portal in actor_nursery._cancel_after_result_on_exit:
                nursery.start_soon(cancel_on_completion, portal, subactor,
                                   errors)

            # This is a "soft" (cancellable) join/reap which
            # will remote cancel the actor on a ``trio.Cancelled``
            # condition.
            await soft_wait(proc, proc_waiter, portal)

            # cancel result waiter that may have been spawned in
            # tandem if not done already
            log.warning("Cancelling existing result waiter task for "
                        f"{subactor.uid}")
            nursery.cancel_scope.cancel()

    finally:
        # hard reap sequence
        if proc.is_alive():
            log.cancel(f"Attempting to hard kill {proc}")
            with trio.move_on_after(0.1) as cs:
                cs.shield = True
                await proc_waiter(proc)

            if cs.cancelled_caught:
                proc.terminate()

        proc.join()
        log.debug(f"Joined {proc}")

        # pop child entry to indicate we are no longer managing subactor
        subactor, proc, portal = actor_nursery._children.pop(subactor.uid)
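
The ``finally:`` block's hard-reap sequence is a reusable trio pattern: wait for the process inside a shielded ``move_on_after`` so outer cancellation can't interrupt the grace period, then escalate to ``terminate()`` if it doesn't exit in time. A minimal sketch, with the snippet's ``proc_waiter`` approximated by ``trio.to_thread.run_sync(proc.join)``:

import multiprocessing as mp
import trio

async def hard_reap(proc: mp.Process, grace: float = 0.1) -> None:
    if proc.is_alive():
        # shielded grace period: outer cancellation can't cut it short
        with trio.move_on_after(grace) as cs:
            cs.shield = True
            await trio.to_thread.run_sync(proc.join)

        if cs.cancelled_caught:
            # the process didn't exit within the grace period: escalate
            proc.terminate()

    proc.join()
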
Example #4
def main(args: Iterable[str]) -> int:
    """
    The main program loop

    :param args: Command line arguments
    :return: The program exit code
    """
    # Handle command line arguments
    args = handle_args(args)
    set_verbosity_logger(logger, args.verbosity)

    # Go to the working directory
    config_file = os.path.realpath(args.config)
    os.chdir(os.path.dirname(config_file))

    try:
        # Read the configuration
        config = config_parser.load_config(config_file)
    except (ConfigurationSyntaxError, DataConversionError) as e:
        # Make the config exceptions a bit more readable
        msg = e.message
        if e.lineno and e.lineno != -1:
            msg += ' on line {}'.format(e.lineno)
        if e.url:
            parts = urlparse(e.url)
            msg += ' in {}'.format(parts.path)
        logger.critical(msg)
        return 1
    except ValueError as e:
        logger.critical(e)
        return 1

    # Immediately drop privileges in a non-permanent way so we create logs with the correct owner
    drop_privileges(config.user, config.group, permanent=False)

    # Trigger the forkserver at this point, with dropped privileges, and ignoring KeyboardInterrupt
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    multiprocessing.set_start_method('forkserver')
    forkserver.ensure_running()

    # Initialise the logger
    config.logging.configure(logger, verbosity=args.verbosity)
    logger.info("Starting Python DHCPv6 server v{}".format(dhcpkit.__version__))

    # Create our selector
    sel = selectors.DefaultSelector()

    # Convert signals to messages on a pipe
    signal_r, signal_w = os.pipe()
    flags = fcntl.fcntl(signal_w, fcntl.F_GETFL, 0)
    flags = flags | os.O_NONBLOCK
    fcntl.fcntl(signal_w, fcntl.F_SETFL, flags)
    signal.set_wakeup_fd(signal_w)
    sel.register(signal_r, selectors.EVENT_READ)

    # Ignore normal signal handling by attaching dummy handlers (SIG_IGN will not put messages on the pipe)
    signal.signal(signal.SIGINT, lambda signum, frame: None)
    signal.signal(signal.SIGTERM, lambda signum, frame: None)
    signal.signal(signal.SIGHUP, lambda signum, frame: None)
    signal.signal(signal.SIGINFO, lambda signum, frame: None)

    # Excessive exception catcher
    exception_history = []

    # Some stats
    message_count = 0

    # Initialise the logger again
    config.logging.configure(logger, verbosity=args.verbosity)

    # Create a queue for our children to log to
    logging_queue = multiprocessing.Queue()

    global logging_thread
    logging_thread = queue_logger.QueueLevelListener(logging_queue, *logger.handlers)
    logging_thread.start()

    # Enable multiprocessing logging, mostly useful for development
    if config.logging.log_multiprocessing:
        mp_logger = get_logger()
        mp_logger.propagate = True

    # This will be where we store the new config after a reload
    listeners = []
    stopping = False
    while not stopping:
        # Safety first: assume we want to quit when we break the inner loop unless told otherwise
        stopping = True

        # Restore our privileges while we write the PID file and open network listeners
        restore_privileges()

        # Open the network listeners
        old_listeners = listeners
        listeners = []
        for listener_factory in config.listener_factories:
            # Create new listener while trying to re-use existing sockets
            listeners.append(listener_factory(old_listeners + listeners))

        # Write the PID file
        pid_filename = create_pidfile(args=args, config=config)

        # And drop privileges again
        drop_privileges(config.user, config.group, permanent=False)

        # Remove any file descriptors from the previous config
        for fd, key in list(sel.get_map().items()):
            # Don't remove our signal handling pipe and still existing listeners
            if key.fileobj == signal_r or key.fileobj in listeners:
                continue

            # Seems we don't need this one anymore
            sel.unregister(key.fileobj)

        # Collect all the file descriptors we want to listen to
        existing_listeners = [key.fileobj for key in sel.get_map().values()]
        for listener in listeners:
            if listener not in existing_listeners:
                sel.register(listener, selectors.EVENT_READ)

        # Configuration tree
        message_handler = config.create_message_handler()

        # Start worker processes
        with multiprocessing.Pool(processes=config.workers,
                                  initializer=setup_worker, initargs=(message_handler, logging_queue)) as pool:

            logger.info("Python DHCPv6 server is ready to handle requests")

            running = True
            while running:
                # noinspection PyBroadException
                try:
                    events = sel.select()
                    for key, mask in events:
                        # Handle signal notifications
                        if key.fileobj == signal_r:
                            signal_nr = os.read(signal_r, 1)
                            if signal_nr[0] in (signal.SIGHUP,):
                                # SIGHUP tells the server to reload
                                try:
                                    # Read the new configuration
                                    config = config_parser.load_config(config_file)

                                    running = False
                                    stopping = False

                                    logger.info("DHCPv6 server restarting after configuration change")

                                    break

                                except (ConfigurationSyntaxError, DataConversionError) as e:
                                    # Make the config exceptions a bit more readable
                                    msg = "Not reloading: " + str(e.message)
                                    if e.lineno and e.lineno != -1:
                                        msg += ' on line {}'.format(e.lineno)
                                    if e.url:
                                        parts = urlparse(e.url)
                                        msg += ' in {}'.format(parts.path)
                                    logger.critical(msg)
                                    return 1
                                except ValueError as e:
                                    logger.critical("Not reloading: " + str(e))
                                    return 1

                            elif signal_nr[0] in (signal.SIGINT, signal.SIGTERM):
                                logger.debug("Received termination request")

                                running = False
                                stopping = True
                                break

                            elif signal_nr[0] in (signal.SIGINFO,):
                                logger.info("Server has processed {} messages".format(message_count))

                            # Unknown signal: ignore
                            continue

                        elif isinstance(key.fileobj, Listener):
                            packet = key.fileobj.recv_request()

                            # Update stats
                            message_count += 1

                            # Create the callback
                            callback, error_callback = create_handler_callbacks(key.fileobj, packet.message_id)

                            # Dispatch
                            pool.apply_async(handle_message, args=(packet,),
                                             callback=callback, error_callback=error_callback)

                except Exception as e:
                    # Catch-all exception handler
                    logger.exception("Caught unexpected exception {!r}".format(e))

                    now = time.monotonic()

                    # Add new exception time to the history
                    exception_history.append(now)

                    # Remove exceptions outside the window from the history
                    cutoff = now - config.exception_window
                    while exception_history and exception_history[0] < cutoff:
                        exception_history.pop(0)

                    # Did we receive too many exceptions shortly after each other?
                    if len(exception_history) > config.max_exceptions:
                        logger.critical("Received more than {} exceptions in {} seconds, "
                                        "exiting".format(config.max_exceptions, config.exception_window))
                        running = False
                        stopping = True

            pool.close()
            pool.join()

        # Regain root so we can delete the PID file
        restore_privileges()
        try:
            if pid_filename:
                os.unlink(pid_filename)
                logger.info("Removing PID-file {}".format(pid_filename))
        except OSError:
            pass

    logger.info("Shutting down Python DHCPv6 server v{}".format(dhcpkit.__version__))

    return 0
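
Both server variants wrap the inner loop in the same sliding-window circuit breaker: append a ``time.monotonic()`` stamp per exception, evict stamps older than the window, and stop once the count exceeds the limit. The same logic isolated as a helper (a sketch; the real code inlines it with ``config.exception_window`` and ``config.max_exceptions``):

import time
from collections import deque

def too_many_exceptions(history: deque, window: float, limit: int) -> bool:
    """Record one exception and report whether the rate limit was hit."""
    now = time.monotonic()
    history.append(now)

    # evict exception times that fell outside the window
    cutoff = now - window
    while history and history[0] < cutoff:
        history.popleft()

    return len(history) > limit

history = deque()
if too_many_exceptions(history, window=300.0, limit=10):
    raise SystemExit("received too many exceptions, exiting")
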
Example #5
async def mp_new_proc(
        name: str,
        actor_nursery: 'ActorNursery',  # type: ignore  # noqa
        subactor: Actor,
        errors: Dict[Tuple[str, str], Exception],
        # passed through to actor main
        bind_addr: Tuple[str, int],
        parent_addr: Tuple[str, int],
        _runtime_vars: Dict[str, Any],  # serialized and sent to _child
        *,
        use_trio_run_in_process: bool = False,
        task_status: TaskStatus[Portal] = trio.TASK_STATUS_IGNORED) -> None:
    async with trio.open_nursery() as nursery:
        assert _ctx
        start_method = _ctx.get_start_method()
        if start_method == 'forkserver':
            # XXX do our hackery on the stdlib to avoid multiple
            # forkservers (one at each subproc layer).
            fs = forkserver._forkserver
            curr_actor = current_actor()
            if is_main_process() and not curr_actor._forkserver_info:
                # if we're the "main" process start the forkserver
                # only once and pass its ipc info to downstream
                # children
                # forkserver.set_forkserver_preload(enable_modules)
                forkserver.ensure_running()
                fs_info = (
                    fs._forkserver_address,
                    fs._forkserver_alive_fd,
                    getattr(fs, '_forkserver_pid', None),
                    getattr(resource_tracker._resource_tracker, '_pid', None),
                    resource_tracker._resource_tracker._fd,
                )
            else:
                assert curr_actor._forkserver_info
                fs_info = (
                    fs._forkserver_address,
                    fs._forkserver_alive_fd,
                    fs._forkserver_pid,
                    resource_tracker._resource_tracker._pid,
                    resource_tracker._resource_tracker._fd,
                ) = curr_actor._forkserver_info
        else:
            fs_info = (None, None, None, None, None)

        proc: mp.Process = _ctx.Process(  # type: ignore
            target=_mp_main,
            args=(
                subactor,
                bind_addr,
                fs_info,
                start_method,
                parent_addr,
            ),
            # daemon=True,
            name=name,
        )
        # `multiprocessing` only (since no async interface):
        # register the process before start in case we get a cancel
        # request before the actor has fully spawned - then we can wait
        # for it to fully come up before sending a cancel request
        actor_nursery._children[subactor.uid] = (subactor, proc, None)

        proc.start()
        if not proc.is_alive():
            raise ActorFailure("Couldn't start sub-actor?")

        log.info(f"Started {proc}")

        try:
            # wait for actor to spawn and connect back to us
            # channel should have handshake completed by the
            # local actor by the time we get a ref to it
            event, chan = await actor_nursery._actor.wait_for_peer(
                subactor.uid)
            portal = Portal(chan)
            actor_nursery._children[subactor.uid] = (subactor, proc, portal)

            # unblock parent task
            task_status.started(portal)

            # wait for ``ActorNursery`` block to signal that
            # subprocesses can be waited upon.
            # This is required to ensure synchronization
            # with user code that may want to manually await results
            # from nursery spawned sub-actors. We don't want the
            # containing nurseries here to collect results or error
            # while user code is still doing its thing. Only after the
            # nursery block closes do we allow subactor results to be
            # awaited and reported upwards to the supervisor.
            await actor_nursery._join_procs.wait()

        finally:
            # XXX: in the case we were cancelled before the sub-proc
            # registered itself back we must be sure to try and clean
            # any process we may have started.

            reaping_cancelled: bool = False
            cancel_scope: Optional[trio.CancelScope] = None
            cancel_exc: Optional[trio.Cancelled] = None

            if portal in actor_nursery._cancel_after_result_on_exit:
                try:
                    # async with trio.open_nursery() as n:
                    # n.cancel_scope.shield = True
                    cancel_scope = await nursery.start(cancel_on_completion,
                                                       portal, subactor,
                                                       errors)
                except trio.Cancelled as err:
                    cancel_exc = err

                    # if the reaping task was cancelled we may have hit
                    # a race where the subproc disconnected before we
                    # could send it a message to cancel (classic 2 generals)
                    # in that case, wait shortly then kill the process.
                    reaping_cancelled = True

                    if proc.is_alive():
                        with trio.move_on_after(0.1) as cs:
                            cs.shield = True
                            await proc_waiter(proc)

                        if cs.cancelled_caught:
                            proc.terminate()

            if not reaping_cancelled and proc.is_alive():
                await proc_waiter(proc)

            # TODO: timeout block here?
            proc.join()

            log.debug(f"Joined {proc}")
            # pop child entry to indicate we are no longer managing subactor
            subactor, proc, portal = actor_nursery._children.pop(subactor.uid)

            # cancel result waiter that may have been spawned in
            # tandem if not done already
            if cancel_scope:
                log.warning("Cancelling existing result waiter task for "
                            f"{subactor.uid}")
                cancel_scope.cancel()

            elif reaping_cancelled:  # let the cancellation bubble up
                assert cancel_exc
                raise cancel_exc
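
The ``cancel_scope = await nursery.start(cancel_on_completion, ...)`` call works because the started task hands its own ``CancelScope`` back through ``task_status.started()``; the parent can then cancel just that task later, as the ``finally:`` block does. A minimal sketch of the handoff (the task body is illustrative):

import trio

async def result_waiter(task_status=trio.TASK_STATUS_IGNORED):
    with trio.CancelScope() as cs:
        # hand our scope to whoever called ``nursery.start()``
        task_status.started(cs)
        await trio.sleep_forever()  # stand-in for awaiting a result

async def main():
    async with trio.open_nursery() as nursery:
        cancel_scope = await nursery.start(result_waiter)
        # ... later: cancel only the waiter, not the whole nursery
        cancel_scope.cancel()

trio.run(main)
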
Example #6
utmain = sys.modules['__main__']
if utmain.__package__ == "unittest" and utmain.__spec__ is None:
    from collections import namedtuple
    ModuleSpec = namedtuple("ModuleSpec", ["name"])
    utmain.__spec__ = ModuleSpec("unittest.__main__")
    del ModuleSpec
del utmain

if "grpc" in sys.modules:
    # use lazy_grpc as in bblfsh_roles.py if you really need it above
    raise RuntimeError("grpc may not be imported before fork()")
if multiprocessing.get_start_method() != "forkserver":
    try:
        multiprocessing.set_start_method("forkserver", force=True)
        forkserver.ensure_running()
    except ValueError:
        multiprocessing.set_start_method("spawn", force=True)
    except RuntimeError:
        raise RuntimeError(
            "multiprocessing start method is already set to \"%s\"" %
            multiprocessing.get_start_method()) from None


def setup():
    setup_logging("INFO")
    global ENRY
    if ENRY is not None:
        return
    ENRY = os.path.join(tempfile.mkdtemp(), "enry")
    if os.path.isfile("enry"):
Example #7
async def new_proc(
        name: str,
        actor_nursery: 'ActorNursery',  # type: ignore
        subactor: Actor,
        errors: Dict[Tuple[str, str], Exception],
        # passed through to actor main
        bind_addr: Tuple[str, int],
        parent_addr: Tuple[str, int],
        use_trio_run_in_process: bool = False,
        task_status: TaskStatus[Portal] = trio.TASK_STATUS_IGNORED) -> None:
    """Create a new ``multiprocessing.Process`` using the
    spawn method as configured using ``try_set_start_method()``.
    """
    cancel_scope = None

    # mark the new actor with the global spawn method
    subactor._spawn_method = _spawn_method

    async with trio.open_nursery() as nursery:
        if use_trio_run_in_process or _spawn_method == 'trio_run_in_process':
            # trio_run_in_process
            async with trio_run_in_process.open_in_process(
                    subactor._trip_main,
                    bind_addr,
                    parent_addr,
            ) as proc:
                log.info(f"Started {proc}")

                # wait for actor to spawn and connect back to us
                # channel should have handshake completed by the
                # local actor by the time we get a ref to it
                event, chan = await actor_nursery._actor.wait_for_peer(
                    subactor.uid)
                portal = Portal(chan)
                actor_nursery._children[subactor.uid] = (subactor, proc,
                                                         portal)
                task_status.started(portal)

                # wait for ActorNursery.wait() to be called
                await actor_nursery._join_procs.wait()

                if portal in actor_nursery._cancel_after_result_on_exit:
                    cancel_scope = await nursery.start(cancel_on_completion,
                                                       portal, subactor,
                                                       errors)

                # TRIP blocks here until process is complete
        else:
            # `multiprocessing`
            assert _ctx
            start_method = _ctx.get_start_method()
            if start_method == 'forkserver':
                # XXX do our hackery on the stdlib to avoid multiple
                # forkservers (one at each subproc layer).
                fs = forkserver._forkserver
                curr_actor = current_actor()
                if is_main_process() and not curr_actor._forkserver_info:
                    # if we're the "main" process start the forkserver
                    # only once and pass its ipc info to downstream
                    # children
                    # forkserver.set_forkserver_preload(rpc_module_paths)
                    forkserver.ensure_running()
                    fs_info = (
                        fs._forkserver_address,
                        fs._forkserver_alive_fd,
                        getattr(fs, '_forkserver_pid', None),
                        getattr(resource_tracker._resource_tracker, '_pid',
                                None),
                        resource_tracker._resource_tracker._fd,
                    )
                else:
                    assert curr_actor._forkserver_info
                    fs_info = (
                        fs._forkserver_address,
                        fs._forkserver_alive_fd,
                        fs._forkserver_pid,
                        resource_tracker._resource_tracker._pid,
                        resource_tracker._resource_tracker._fd,
                    ) = curr_actor._forkserver_info
            else:
                fs_info = (None, None, None, None, None)

            proc = _ctx.Process(  # type: ignore
                target=subactor._mp_main,
                args=(bind_addr, fs_info, start_method, parent_addr),
                # daemon=True,
                name=name,
            )
            # `multiprocessing` only (since no async interface):
            # register the process before start in case we get a cancel
            # request before the actor has fully spawned - then we can wait
            # for it to fully come up before sending a cancel request
            actor_nursery._children[subactor.uid] = (subactor, proc, None)

            proc.start()
            if not proc.is_alive():
                raise ActorFailure("Couldn't start sub-actor?")

            log.info(f"Started {proc}")

            # wait for actor to spawn and connect back to us
            # channel should have handshake completed by the
            # local actor by the time we get a ref to it
            event, chan = await actor_nursery._actor.wait_for_peer(
                subactor.uid)
            portal = Portal(chan)
            actor_nursery._children[subactor.uid] = (subactor, proc, portal)

            # unblock parent task
            task_status.started(portal)

            # wait for ``ActorNursery`` block to signal that
            # subprocesses can be waited upon.
            # This is required to ensure synchronization
            # with user code that may want to manually await results
            # from nursery spawned sub-actors. We don't want the
            # containing nurseries here to collect results or error
            # while user code is still doing its thing. Only after the
            # nursery block closes do we allow subactor results to be
            # awaited and reported upwards to the supervisor.
            await actor_nursery._join_procs.wait()

            if portal in actor_nursery._cancel_after_result_on_exit:
                cancel_scope = await nursery.start(cancel_on_completion,
                                                   portal, subactor, errors)

            # TODO: timeout block here?
            if proc.is_alive():
                await proc_waiter(proc)
            proc.join()

        log.debug(f"Joined {proc}")
        # pop child entry to indicate we are no longer managing this subactor
        subactor, proc, portal = actor_nursery._children.pop(subactor.uid)
        # cancel result waiter that may have been spawned in
        # tandem if not done already
        if cancel_scope:
            log.warning(
                f"Cancelling existing result waiter task for {subactor.uid}")
            cancel_scope.cancel()
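
Across all the forkserver branches above, the parent packs the same five-tuple of stdlib internals into ``fs_info``. The child-side entry points (``_fork_main`` / ``_mp_main``) are not shown, but the non-main-parent ``else:`` branches already demonstrate the restore step: unpack the tuple back onto the singletons. A hypothetical child-side helper mirroring that layout (not taken from the snippets; these attribute names are CPython internals and may vary across versions):

from multiprocessing import forkserver, resource_tracker  # 3.8+ names

def load_forkserver_info(fs_info) -> None:
    # re-point this process at the parent's already-running
    # forkserver and resource tracker instead of spawning new ones
    fs = forkserver._forkserver
    (
        fs._forkserver_address,
        fs._forkserver_alive_fd,
        fs._forkserver_pid,
        resource_tracker._resource_tracker._pid,
        resource_tracker._resource_tracker._fd,
    ) = fs_info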