コード例 #1
0
class IPEngine(BaseParallelApplication):

    name = 'ipengine'
    description = _description
    examples = _examples
    classes = List([ZMQInteractiveShell, ProfileDir, Session, Kernel])
    _deprecated_classes = ["EngineFactory", "IPEngineApp"]

    enable_nanny = Bool(
        True,
        config=True,
        help="""Enable the nanny process.

    The nanny process enables remote signaling of single engines
    and more responsive notification of engine shutdown.

    .. versionadded:: 7.0

    """,
    )

    startup_script = Unicode('',
                             config=True,
                             help='specify a script to be run at startup')
    startup_command = Unicode('',
                              config=True,
                              help='specify a command to be run at startup')

    url_file = Unicode(
        '',
        config=True,
        help=
        """The full location of the file containing the connection information for
        the controller. If this is not given, the file must be in the
        security directory of the cluster directory.  This location is
        resolved using the `profile` or `profile_dir` options.""",
    )
    wait_for_url_file = Float(
        10,
        config=True,
        help="""The maximum number of seconds to wait for url_file to exist.
        This is useful for batch-systems and shared-filesystems where the
        controller and engine are started at the same time and it
        may take a moment for the controller to write the connector files.""",
    )

    url_file_name = Unicode('ipcontroller-engine.json', config=True)

    connection_info_env = Unicode()

    @default("connection_info_env")
    def _default_connection_file_env(self):
        """Default ``connection_info_env`` to ``$IPP_CONNECTION_INFO``.

        Returns the empty string when the environment variable is not set.
        """
        env_value = os.environ.get("IPP_CONNECTION_INFO", "")
        return env_value

    @observe('cluster_id')
    def _cluster_id_changed(self, change):
        """Recompute ``url_file_name`` whenever ``cluster_id`` changes.

        A non-empty cluster id is inserted into the file name; an empty one
        falls back to the plain ``ipcontroller`` prefix.
        """
        cluster_id = change['new']
        prefix = 'ipcontroller-{}'.format(cluster_id) if cluster_id else 'ipcontroller'
        self.url_file_name = "%s-engine.json" % prefix

    log_url = Unicode(
        '',
        config=True,
        help="""The URL for the iploggerapp instance, for forwarding
        logging to a central location.""",
    )

    registration_url = Unicode(
        config=True,
        help="""Override the registration URL""",
    )
    out_stream_factory = Type(
        'ipykernel.iostream.OutStream',
        config=True,
        help="""The OutStream for handling stdout/err.
        Typically 'ipykernel.iostream.OutStream'""",
    )
    display_hook_factory = Type(
        'ipykernel.displayhook.ZMQDisplayHook',
        config=True,
        help="""The class for handling displayhook.
        Typically 'ipykernel.displayhook.ZMQDisplayHook'""",
    )
    location = Unicode(
        config=True,
        help="""The location (an IP address) of the controller.  This is
        used for disambiguating URLs, to determine whether
        loopback should be used to connect or the public address.""",
    )
    timeout = Float(
        5.0,
        config=True,
        help="""The time (in seconds) to wait for the Controller to respond
        to registration requests before giving up.""",
    )
    max_heartbeat_misses = Integer(
        50,
        config=True,
        help="""The maximum number of times a check for the heartbeat ping of a
        controller can be missed before shutting down the engine.

        If set to 0, the check is disabled.""",
    )
    sshserver = Unicode(
        config=True,
        help=
        """The SSH server to use for tunneling connections to the Controller.""",
    )
    sshkey = Unicode(
        config=True,
        help=
        """The SSH private key file to use when tunneling connections to the Controller.""",
    )
    paramiko = Bool(
        sys.platform == 'win32',
        config=True,
        help="""Whether to use paramiko instead of openssh for tunnels.""",
    )

    use_mpi = Bool(
        False,
        config=True,
        help="""Enable MPI integration.

        If set, MPI rank will be requested for my rank,
        and additionally `mpi_init` will be executed in the interactive shell.
        """,
    )
    init_mpi = Unicode(
        DEFAULT_MPI_INIT,
        config=True,
        help="""Code to execute in the user namespace when initializing MPI""",
    )
    mpi_registration_delay = Float(
        0.02,
        config=True,
        help="""Per-engine delay for mpiexec-launched engines

        avoids flooding the controller with registrations,
        which can stall under heavy load.

        Default: .02 (50 engines/sec, or 3000 engines/minute)
        """,
    )

    # user_ns is not configurable; id (declared below) can still be requested via config:
    user_ns = Dict()
    id = Integer(
        None,
        allow_none=True,
        config=True,
        help="""Request this engine ID.

        If run in MPI, will use the MPI rank.
        Otherwise, let the Hub decide what our rank should be.
        """,
    )

    @default('id')
    def _id_default(self):
        """Default engine id: the MPI rank when running under MPI, else None.

        ``mpi4py`` is only imported when ``use_mpi`` is set. The rank is used
        only if the MPI world has more than one process; otherwise None is
        returned.
        """
        if self.use_mpi:
            from mpi4py import MPI

            if MPI.COMM_WORLD.size > 1:
                self.log.debug("MPI rank = %i", MPI.COMM_WORLD.rank)
                return MPI.COMM_WORLD.rank
        return None

    registrar = Instance('zmq.eventloop.zmqstream.ZMQStream', allow_none=True)
    kernel = Instance(Kernel, allow_none=True)
    hb_check_period = Integer()

    # States for the heartbeat monitoring
    # Initial values for monitored and pinged must satisfy "not (monitored > pinged)" so that
    # the first check does not report a "missed" ping. Must be floats for Python 3 compatibility.
    _hb_last_pinged = 0.0
    _hb_last_monitored = 0.0
    _hb_missed_beats = 0
    # The zmq Stream which receives the pings from the Heart
    _hb_listener = None

    bident = Bytes()
    ident = Unicode()

    @default("ident")
    def _default_ident(self):
        """Default ``ident`` to the ``session`` attribute of our Session object."""
        session = self.session
        return session.session

    @default("bident")
    def _default_bident(self):
        """Default ``bident`` to the utf8-encoded bytes of ``ident``."""
        ident_text = self.ident
        return ident_text.encode("utf8")

    @observe("ident")
    def _ident_changed(self, change):
        """Keep ``bident`` in sync with ``ident`` by re-encoding it on every change."""
        encoded = self._default_bident()
        self.bident = encoded

    using_ssh = Bool(False)

    context = Instance(zmq.Context)

    @default("context")
    def _default_context(self):
        """Default ``context`` to the object returned by ``zmq.Context.instance()``."""
        shared_context = zmq.Context.instance()
        return shared_context

    # an IPKernelApp instance, used to setup listening for shell frontends
    kernel_app = Instance(IPKernelApp, allow_none=True)

    aliases = Dict(aliases)
    flags = Dict(flags)

    curve_serverkey = Bytes(
        config=True, help="The controller's public key for CURVE security")
    curve_secretkey = Bytes(
        config=True,
        help="""The engine's secret key for CURVE security.
        Usually autogenerated on launch.""",
    )
    curve_publickey = Bytes(
        config=True,
        help="""The engine's public key for CURVE security.
        Usually autogenerated on launch.""",
    )

    @default("curve_serverkey")
    def _default_curve_serverkey(self):
        """Default the controller's public key from ``$IPP_CURVE_SERVERKEY`` (empty bytes if unset)."""
        raw = os.environ.get("IPP_CURVE_SERVERKEY", "")
        return raw.encode("ascii")

    @default("curve_secretkey")
    def _default_curve_secretkey(self):
        """Default the engine's secret key from ``$IPP_CURVE_SECRETKEY`` (empty bytes if unset)."""
        raw = os.environ.get("IPP_CURVE_SECRETKEY", "")
        return raw.encode("ascii")

    @default("curve_publickey")
    def _default_curve_publickey(self):
        """Default the engine's public key from ``$IPP_CURVE_PUBLICKEY`` (empty bytes if unset)."""
        raw = os.environ.get("IPP_CURVE_PUBLICKEY", "")
        return raw.encode("ascii")

    @validate("curve_publickey", "curve_secretkey", "curve_serverkey")
    def _cast_bytes(self, proposal):
        """Validator for the CURVE key traits: ascii-encode ``str`` values.

        Any value that is not a ``str`` is returned unchanged.
        """
        value = proposal.value
        if not isinstance(value, str):
            return value
        return value.encode("ascii")

    def _ensure_curve_keypair(self):
        """Generate a fresh CURVE keypair unless both secret and public key are already set."""
        if self.curve_secretkey and self.curve_publickey:
            return
        self.log.info("Generating new CURVE credentials")
        public_key, secret_key = zmq.curve_keypair()
        self.curve_publickey = public_key
        self.curve_secretkey = secret_key

    def find_connection_file(self):
        """Fill in ``url_file`` when it has not been set explicitly.

        The default is ``url_file_name`` inside the profile's security
        directory. No check is made here that the file exists or is valid.
        """
        if self.url_file:
            # an explicit location was given; leave it alone
            return
        security_dir = self.profile_dir.security_dir
        self.url_file = os.path.join(security_dir, self.url_file_name)

    def load_connection_file(self):
        """Apply the controller's JSON connection info to this engine.

        The info is parsed from ``connection_info_env`` when that is set
        (the same content may be given via $IPP_CONNECTION_INFO), otherwise
        from the file at ``url_file``. The location and ssh server from the
        file only apply when not set by hand; the key, signature scheme and
        (un)packer from the file are written into ``config.Session``.
        The parsed dict is stored as ``self.connection_info``.
        """
        config = self.config

        if not self.connection_info_env:
            self.log.info("Loading connection file %r", self.url_file)
            with open(self.url_file) as f:
                info = json.load(f)
        else:
            self.log.info("Loading connection info from $IPP_CONNECTION_INFO")
            info = json.loads(self.connection_info_env)

        # location (for disambiguation) and ssh server may be overridden by hand
        if 'IPEngine.location' not in self.cli_config:
            self.location = info['location']
        if 'ssh' in info and not self.sshserver:
            self.sshserver = info.get("ssh")

        # rewrite the interface with the disambiguated ip
        proto, ip = info['interface'].split('://')
        ip = disambiguate_ip_address(ip, self.location)
        info['interface'] = f'{proto}://{ip}'

        file_serverkey = info.get('curve_serverkey')
        if file_serverkey:
            # connection file takes precedence over env, if present and defined
            self.curve_serverkey = file_serverkey.encode('ascii')

        if not self.curve_serverkey:
            self.log.warning("Not using CurveZMQ security")
        else:
            self.log.info("Using CurveZMQ security")
            self._ensure_curve_keypair()

        # DO NOT allow override of basic URLs, serialization, or key
        # JSON file takes top priority there
        if info.get('key') or 'key' not in config.Session:
            config.Session.key = info.get('key', '').encode('utf8')
        config.Session.signature_scheme = info['signature_scheme']

        self.registration_url = "{}:{}".format(info['interface'], info['registration'])

        config.Session.packer = info['pack']
        config.Session.unpacker = info['unpack']
        self.session = Session(parent=self)

        self.log.debug("Config changed:")
        self.log.debug("%r", config)
        self.connection_info = info

    def bind_kernel(self, **kwargs):
        """Promote engine to listening kernel, accessible to frontends.

        Does nothing when a kernel app already exists. Otherwise an
        ``IPKernelApp`` is created (``kwargs`` are passed through, with
        config, log, profile_dir and session defaulting to this app's),
        the kernel's shell, iopub and a new stdin socket are bound,
        the heartbeat is started and a connection file is written into
        the profile's security directory.
        """
        if self.kernel_app is not None:
            return

        self.log.info(
            "Opening ports for direct connections as an IPython kernel")
        if self.curve_serverkey:
            self.log.warning("Bound kernel does not support CURVE security")

        kernel = self.kernel

        inherited = (
            ('config', self.config),
            ('log', self.log),
            ('profile_dir', self.profile_dir),
            ('session', self.session),
        )
        for option, value in inherited:
            kwargs.setdefault(option, value)

        app = IPKernelApp(**kwargs)
        self.kernel_app = app

        # allow IPKernelApp.instance():
        IPKernelApp._instance = app

        app.init_connection_file()

        # relevant contents of init_sockets:
        shell_stream = kernel.shell_streams[0]
        app.shell_port = app._bind_socket(shell_stream, app.shell_port)
        app.log.debug("shell ROUTER Channel on port: %i", app.shell_port)

        # ipykernel 4.3 iopub_socket is an IOThread wrapper:
        iopub_socket = kernel.iopub_socket
        if hasattr(iopub_socket, 'socket'):
            iopub_socket = iopub_socket.socket
        app.iopub_port = app._bind_socket(iopub_socket, app.iopub_port)
        app.log.debug("iopub PUB Channel on port: %i", app.iopub_port)

        kernel.stdin_socket = self.context.socket(zmq.ROUTER)
        app.stdin_port = app._bind_socket(kernel.stdin_socket, app.stdin_port)
        app.log.debug("stdin ROUTER Channel on port: %i", app.stdin_port)

        # start the heartbeat, and log connection info:
        app.init_heartbeat()
        app.log_connection_info()
        app.connection_dir = self.profile_dir.security_dir
        app.write_connection_file()

    @property
    def tunnel_mod(self):
        """The ``zmq.ssh.tunnel`` module, imported on first access of this property."""
        from zmq.ssh import tunnel as tunnel_module

        return tunnel_module

    def init_connector(self):
        """Construct the connection functions, which handle ssh tunnels.

        Sets ``using_ssh`` (and ``sshserver`` when only a key was given),
        asks for an ssh password if passwordless login does not work, and
        returns the pair ``(connect, maybe_tunnel)``:

        - ``connect(s, url, curve_serverkey=None)`` connects socket ``s``.
        - ``maybe_tunnel(url)`` returns the url to use, without connecting.
        """
        self.using_ssh = bool(self.sshkey or self.sshserver)

        if self.sshkey and not self.sshserver:
            # We are using ssh directly to the controller, tunneling localhost to localhost
            host_and_port = self.registration_url.split('://')[1]
            self.sshserver = host_and_port.split(':')[0]

        password = False
        if self.using_ssh:
            passwordless = self.tunnel_mod.try_passwordless_ssh(
                self.sshserver, self.sshkey, self.paramiko)
            if not passwordless:
                password = getpass("SSH Password for %s: " % self.sshserver)

        def connect(s, url, curve_serverkey=None):
            """connect socket ``s`` to ``url``, directly or through an ssh tunnel"""
            url = disambiguate_url(url, self.location)
            if curve_serverkey is None:
                # fall back to the key configured on the app
                curve_serverkey = self.curve_serverkey
            if curve_serverkey:
                for option, value in (
                    (zmq.CURVE_SERVERKEY, curve_serverkey),
                    (zmq.CURVE_SECRETKEY, self.curve_secretkey),
                    (zmq.CURVE_PUBLICKEY, self.curve_publickey),
                ):
                    s.setsockopt(option, value)

            if not self.using_ssh:
                return s.connect(url)

            self.log.debug("Tunneling connection to %s via %s", url,
                           self.sshserver)
            return self.tunnel_mod.tunnel_connection(
                s,
                url,
                self.sshserver,
                keyfile=self.sshkey,
                paramiko=self.paramiko,
                password=password,
            )

        def maybe_tunnel(url):
            """like connect, but don't complete the connection (for use by heartbeat)"""
            url = disambiguate_url(url, self.location)
            if self.using_ssh:
                self.log.debug("Tunneling connection to %s via %s", url,
                               self.sshserver)
                # only the forwarded url is needed; the tunnel object is not kept
                url, _tunnel = self.tunnel_mod.open_tunnel(
                    url,
                    self.sshserver,
                    keyfile=self.sshkey,
                    paramiko=self.paramiko,
                    password=password,
                )
            return str(url)

        return connect, maybe_tunnel

    def register(self):
        """Send the registration_request to the controller.

        Under MPI, engines with id >= 100 first sleep for
        ``id * mpi_registration_delay`` seconds. A DEALER socket is then
        connected to ``registration_url``, wrapped in a ZMQStream stored as
        ``self.registrar``, and the request is sent; the reply is handled by
        ``complete_registration``.
        """
        should_delay = (self.use_mpi and self.id and self.id >= 100
                        and self.mpi_registration_delay)
        if should_delay:
            # Some launchers implement delay at the Launcher level,
            # but mpiexec must implement it in the engine process itself
            # delay based on our rank
            delay = self.id * self.mpi_registration_delay
            self.log.info(
                f"Delaying registration for {self.id} by {int(delay * 1000)}ms"
            )
            time.sleep(delay)

        self.log.info("Registering with controller at %s" %
                      self.registration_url)
        connect, maybe_tunnel = self.init_connector()
        reg_socket = self.context.socket(zmq.DEALER)
        reg_socket.setsockopt(zmq.IDENTITY, self.bident)
        connect(reg_socket, self.registration_url)

        self.registrar = zmqstream.ZMQStream(reg_socket, self.loop)

        content = {"uuid": self.ident}
        if self.id is not None:
            self.log.info("Requesting id: %i", self.id)
            content['id'] = self.id
        self._registration_completed = False

        def on_reply(msg):
            return self.complete_registration(msg, connect, maybe_tunnel)

        self.registrar.on_recv(on_reply)

        self.session.send(self.registrar,
                          "registration_request",
                          content=content)

    def _report_ping(self, msg):
        """Callback for when the heartmonitor.Heart receives a ping.

        Records the current time as the moment of the last ping; ``msg``
        itself is not inspected.
        """
        now = time.time()
        self._hb_last_pinged = now

    def complete_registration(self, msg, connect, maybe_tunnel):
        """Run ``_complete_registration``; on any exception, log it and exit with status 255."""
        try:
            self._complete_registration(msg, connect, maybe_tunnel)
        except Exception as error:
            message = f"Error completing registration: {error}"
            self.log.critical(message, exc_info=True)
            self.exit(255)

    def _complete_registration(self, msg, connect, maybe_tunnel):
        """Handle the controller's registration reply and bring the engine up.

        On an ``ok`` reply this adopts the assigned engine id, connects the
        shell, control and iopub channels, redirects stdout/stderr and the
        display hook, creates and starts the Kernel, then starts the
        heartbeat and cancels the registration timeout.

        Parameters
        ----------
        msg : list
            The raw multipart registration reply.
        connect : callable
            ``connect(socket, url, curve_serverkey=None)`` from ``init_connector``.
        maybe_tunnel : callable
            ``maybe_tunnel(url)`` from ``init_connector``; used for the heartbeat urls.

        Raises
        ------
        Exception
            If the reply's status is not ``ok``.
        """
        ctx = self.context
        loop = self.loop
        identity = self.bident
        _idents, msg = self.session.feed_identities(msg)
        msg = self.session.deserialize(msg)
        content = msg['content']
        info = self.connection_info

        def url(key):
            """get zmq url for given channel"""
            return str(info["interface"] + ":%i" % info[key])

        def urls(key):
            """get the list of zmq urls for a channel with several ports"""
            return [f'{info["interface"]}:{port}' for port in info[key]]

        if content['status'] != 'ok':
            self.log.fatal("Registration Failed: %s" % msg)
            raise Exception("Registration Failed: %s" % msg)

        requested_id = self.id
        self.id = content['id']
        if requested_id is not None and self.id != requested_id:
            self.log.warning(
                f"Did not get the requested id: {self.id} != {requested_id}"
            )
            # start() already appended the requested id; swap in the real one
            self.log.name = self.log.name.rsplit(".", 1)[0] + f".{self.id}"
        elif requested_id is None:
            # No id was requested, so the logger name has no id suffix yet.
            # (This must test requested_id: self.id was just assigned above.)
            self.log.name += f".{self.id}"

        # create Shell Connections (MUX, Task, etc.):

        # select which broadcast endpoint to connect to
        # use rank % len(broadcast_leaves)
        broadcast_urls = urls('broadcast')
        broadcast_index = self.id % len(broadcast_urls)
        broadcast_url = broadcast_urls[broadcast_index]

        shell_addrs = [url('mux'), url('task'), broadcast_url]
        self.log.info(f'Shell_addrs: {shell_addrs}')

        # Use only one shell stream for mux and tasks
        stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
        stream.setsockopt(zmq.IDENTITY, identity)
        # TODO: enable PROBE_ROUTER when schedulers can handle the empty message
        # stream.setsockopt(zmq.PROBE_ROUTER, 1)
        self.log.debug("Setting shell identity %r", identity)

        shell_streams = [stream]
        for addr in shell_addrs:
            self.log.info("Connecting shell to %s", addr)
            connect(stream, addr)

        # control stream:
        control_url = url('control')
        curve_serverkey = self.curve_serverkey
        if self.enable_nanny:
            nanny_url, self.nanny_pipe = self.start_nanny(
                control_url=control_url, )
            control_url = nanny_url
            # nanny uses our curve_publickey, not the controller's publickey
            curve_serverkey = self.curve_publickey
        control_stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
        control_stream.setsockopt(zmq.IDENTITY, identity)
        connect(control_stream,
                control_url,
                curve_serverkey=curve_serverkey)

        # create iopub stream:
        iopub_addr = url('iopub')
        iopub_socket = ctx.socket(zmq.PUB)
        iopub_socket.SNDHWM = 0
        iopub_socket.setsockopt(zmq.IDENTITY, identity)
        connect(iopub_socket, iopub_addr)
        try:
            from ipykernel.iostream import IOPubThread
        except ImportError:
            # no IOPubThread available: keep using the bare socket
            pass
        else:
            iopub_socket = IOPubThread(iopub_socket)
            iopub_socket.start()

        # disable history:
        self.config.HistoryManager.hist_file = ':memory:'

        # Redirect input streams and set a display hook.
        if self.out_stream_factory:
            sys.stdout = self.out_stream_factory(self.session,
                                                 iopub_socket, 'stdout')
            sys.stdout.topic = f"engine.{self.id}.stdout".encode("ascii")
            sys.stderr = self.out_stream_factory(self.session,
                                                 iopub_socket, 'stderr')
            sys.stderr.topic = f"engine.{self.id}.stderr".encode("ascii")

            # copied from ipykernel 6, which captures sys.__stderr__ at the FD-level
            if getattr(sys.stderr, "_original_stdstream_copy",
                       None) is not None:
                for handler in self.log.handlers:
                    if isinstance(handler, StreamHandler) and (
                            handler.stream.buffer.fileno() == 2):
                        self.log.debug(
                            "Seeing logger to stderr, rerouting to raw filedescriptor."
                        )

                        handler.stream = TextIOWrapper(
                            FileIO(sys.stderr._original_stdstream_copy,
                                   "w"))
        if self.display_hook_factory:
            sys.displayhook = self.display_hook_factory(
                self.session, iopub_socket)
            sys.displayhook.topic = f"engine.{self.id}.execute_result".encode(
                "ascii")

        # patch Session to always send engine uuid metadata
        original_send = self.session.send

        def send_with_metadata(
            stream,
            msg_or_type,
            content=None,
            parent=None,
            ident=None,
            buffers=None,
            track=False,
            header=None,
            metadata=None,
            **kwargs,
        ):
            """Ensure all messages set engine uuid metadata"""
            metadata = metadata or {}
            metadata.setdefault("engine", self.ident)
            return original_send(
                stream,
                msg_or_type,
                content=content,
                parent=parent,
                ident=ident,
                buffers=buffers,
                track=track,
                header=header,
                metadata=metadata,
                **kwargs,
            )

        self.session.send = send_with_metadata

        self.kernel = Kernel.instance(
            parent=self,
            engine_id=self.id,
            ident=self.ident,
            session=self.session,
            control_stream=control_stream,
            shell_streams=shell_streams,
            iopub_socket=iopub_socket,
            loop=loop,
            user_ns=self.user_ns,
            log=self.log,
        )

        self.kernel.shell.display_pub.topic = f"engine.{self.id}.displaypub".encode(
            "ascii")

        # FIXME: This is a hack until IPKernelApp and IPEngineApp can be fully merged
        self.init_signal()
        app = IPKernelApp(parent=self,
                          shell=self.kernel.shell,
                          kernel=self.kernel,
                          log=self.log)
        if self.use_mpi and self.init_mpi:
            # run the MPI init code before any other exec_lines
            app.exec_lines.insert(0, self.init_mpi)
        app.init_profile_dir()
        app.init_code()

        self.kernel.start()

        self.start_heartbeat(
            maybe_tunnel(url('hb_ping')),
            maybe_tunnel(url('hb_pong')),
            content['hb_period'],
            identity,
        )
        self.log.info("Completed registration with id %i" % self.id)
        # registration succeeded: cancel the pending abort() timeout
        self.loop.remove_timeout(self._abort_timeout)

    def start_nanny(self, control_url):
        """Start the nanny process for this engine.

        Only the ``Session`` section of our config is handed to the nanny,
        together with our id, identity, CURVE keys and the control and
        registration urls. Returns whatever the module-level ``start_nanny``
        returns.
        """
        self.log.info("Starting nanny")
        nanny_config = Config()
        nanny_config.Session = self.config.Session
        nanny_options = dict(
            engine_id=self.id,
            identity=self.bident,
            control_url=control_url,
            curve_serverkey=self.curve_serverkey,
            curve_secretkey=self.curve_secretkey,
            curve_publickey=self.curve_publickey,
            registration_url=self.registration_url,
            config=nanny_config,
        )
        return start_nanny(**nanny_options)

    def start_heartbeat(self, hb_ping, hb_pong, hb_period, identity):
        """Start our heart beating.

        A ``Heart`` is started on the given ping/pong urls. Unless
        ``max_heartbeat_misses`` is 0, a local SUB socket also records every
        ping, and a periodic callback (every ``hb_period + 500`` ms) checks
        that pings keep arriving.
        """
        monitoring = self.max_heartbeat_misses > 0

        hb_monitor = None
        if monitoring:
            # Add a monitor socket which will record the last time a ping was seen
            mon = self.context.socket(zmq.SUB)
            if self.curve_serverkey:
                mon.setsockopt(zmq.CURVE_SERVER, 1)
                mon.setsockopt(zmq.CURVE_SECRETKEY, self.curve_secretkey)
            mport = mon.bind_to_random_port('tcp://%s' % localhost())
            mon.setsockopt(zmq.SUBSCRIBE, b"")
            listener = zmqstream.ZMQStream(mon, self.loop)
            self._hb_listener = listener
            listener.on_recv(self._report_ping)

            hb_monitor = "tcp://%s:%i" % (localhost(), mport)

        heart = Heart(
            hb_ping,
            hb_pong,
            hb_monitor,
            heart_id=identity,
            curve_serverkey=self.curve_serverkey,
            curve_secretkey=self.curve_secretkey,
            curve_publickey=self.curve_publickey,
        )
        heart.start()

        # periodically check the heartbeat pings of the controller
        # Should be started here and not in "start()" so that the right period can be taken
        # from the hubs HeartBeatMonitor.period
        if not monitoring:
            self.log.info(
                "Monitoring of the heartbeat signal from the hub is not enabled."
            )
            return

        # Use a slightly bigger check period than the hub signal period to not warn unnecessary
        self.hb_check_period = hb_period + 500
        self.log.info(
            "Starting to monitor the heartbeat signal from the hub every %i ms.",
            self.hb_check_period,
        )
        self._hb_reporter = ioloop.PeriodicCallback(
            self._hb_monitor, self.hb_check_period)
        self._hb_reporter.start()

    def abort(self):
        """Give up on registration: log the timeout, send an unregistration request, exit 255.

        When the registration url contains "127.", an extra hint about
        listening on an external interface is logged.
        """
        self.log.fatal("Registration timed out after %.1f seconds" %
                       self.timeout)
        looks_like_loopback = "127." in self.registration_url
        if looks_like_loopback:
            self.log.fatal("""
            If the controller and engines are not on the same machine,
            you will have to instruct the controller to listen on an external IP (in ipcontroller_config.py):
                c.IPController.ip = '0.0.0.0' # for all interfaces, internal and external
                c.IPController.ip = '192.168.1.101' # or any interface that the engines can see
            or tunnel connections via ssh.
            """)
        self.session.send(self.registrar,
                          "unregistration_request",
                          content={"id": self.id})
        # pause for a second between sending the request and exiting
        time.sleep(1)
        sys.exit(255)

    def _hb_monitor(self):
        """Callback to monitor the heartbeat from the controller.

        A beat counts as missed when no ping arrived since the previous
        check. Once ``max_heartbeat_misses`` consecutive beats are missed,
        an unregistration request is sent and the loop is stopped.
        """
        self._hb_listener.flush()
        beat_missed = self._hb_last_monitored > self._hb_last_pinged
        if beat_missed:
            self._hb_missed_beats += 1
            self.log.warning(
                "No heartbeat in the last %s ms (%s time(s) in a row).",
                self.hb_check_period,
                self._hb_missed_beats,
            )
        else:
            # a ping arrived since the last check: reset the counter
            self._hb_missed_beats = 0

        too_many_misses = self._hb_missed_beats >= self.max_heartbeat_misses
        if too_many_misses:
            self.log.fatal(
                "Maximum number of heartbeats misses reached (%s times %s ms), shutting down.",
                self.max_heartbeat_misses,
                self.hb_check_period,
            )
            self.session.send(self.registrar,
                              "unregistration_request",
                              content={"id": self.id})
            self.loop.stop()

        self._hb_last_monitored = time.time()

    def init_engine(self):
        """Locate and load the connection info, then set up startup code.

        Unless connection info came from the environment, the connection
        file is located and waited for (up to ``wait_for_url_file``
        seconds); the app exits with status 1 if it never appears.
        ``exec_lines``/``exec_files`` from IPKernelApp or InteractiveShellApp
        config are copied to ``config.IPKernelApp``, with ``startup_command``
        and ``startup_script`` appended.
        """
        # This is the working dir by now.
        sys.path.insert(0, '')
        config = self.config

        if not self.connection_info_env:
            self.find_connection_file()
            if self.wait_for_url_file and not os.path.exists(self.url_file):
                self.log.warning(
                    f"Connection file {self.url_file!r} not found")
                self.log.warning(
                    "Waiting up to %.1f seconds for it to arrive.",
                    self.wait_for_url_file,
                )
                started = time.monotonic()
                # wait for url_file to exist, or until time limit
                while not os.path.exists(self.url_file):
                    if not (time.monotonic() - started < self.wait_for_url_file):
                        break
                    time.sleep(0.1)

            if not os.path.exists(self.url_file):
                self.log.fatal(
                    f"Fatal: connection file never arrived: {self.url_file}")
                self.exit(1)

        self.load_connection_file()

        def first_configured(trait_name):
            """value of trait_name from the first app section that sets it, else a new list"""
            for app_name in ('IPKernelApp', 'InteractiveShellApp'):
                if f'{app_name}.{trait_name}' in config:
                    return getattr(config[app_name], trait_name)
            return []

        exec_lines = first_configured('exec_lines')
        exec_files = first_configured('exec_files')

        config.IPKernelApp.exec_lines = exec_lines
        config.IPKernelApp.exec_files = exec_files

        if self.startup_script:
            exec_files.append(self.startup_script)
        if self.startup_command:
            exec_lines.append(self.startup_command)

    def forward_logging(self):
        """Forward log records to ``log_url`` over a PUB socket, if ``log_url`` is set."""
        if not self.log_url:
            return
        self.log.info("Forwarding logging to %s", self.log_url)
        log_socket = self.context.socket(zmq.PUB)
        log_socket.connect(self.log_url)
        # NOTE(review): `self.engine` is not defined anywhere in the visible class;
        # confirm the base class provides it, otherwise this raises AttributeError.
        handler = EnginePUBHandler(self.engine, log_socket)
        handler.setLevel(self.log_level)
        self.log.addHandler(handler)

    @catch_config_error
    def initialize(self, argv=None):
        """Run the base-class initialization, then ``init_engine`` and ``forward_logging``."""
        super().initialize(argv)
        for step in (self.init_engine, self.forward_logging):
            step()

    def init_signal(self):
        """Install our handlers for SIGINT and SIGTERM."""
        handlers = (
            (signal.SIGINT, self._signal_sigint),
            (signal.SIGTERM, self._signal_stop),
        )
        for signum, handler in handlers:
            signal.signal(signum, handler)

    def _signal_sigint(self, sig, frame):
        """SIGINT handler: log a warning and otherwise do nothing."""
        message = "Ignoring SIGINT. Terminate with SIGTERM."
        self.log.warning(message)

    def _signal_stop(self, sig, frame):
        """SIGTERM handler: log the signal and schedule the loop to stop."""
        self.log.critical(f"received signal {sig}, stopping")
        loop = self.loop
        loop.add_callback_from_signal(loop.stop)

    def start(self):
        """Register with the controller and run the event loop.

        If an id is already known it is appended to the logger name.
        Registration is scheduled on the loop together with a timeout that
        calls ``abort`` after ``timeout`` seconds. A KeyboardInterrupt
        raised from the loop is logged rather than propagated.
        """
        if self.id is not None:
            self.log.name += f".{self.id}"
        loop = self.loop

        def register_with_timeout():
            self.register()
            deadline = loop.time() + self.timeout
            self._abort_timeout = loop.add_timeout(deadline, self.abort)

        self.loop.add_callback(register_with_timeout)
        try:
            self.loop.start()
        except KeyboardInterrupt:
            self.log.critical("Engine Interrupted, shutting down...\n")
コード例 #2
0
class KernelNanny:
    """Object for monitoring an engine process.

    Must be child of engine.

    Handles signal messages and watches Engine process for exiting.

    The nanny sits between the control scheduler and the engine's own
    control socket: messages it has a handler for (see
    ``control_handlers``) are answered by the nanny itself, everything else
    is relayed to the engine and the engine's replies are relayed back.
    When the watched process exits, or the pipe to it closes, the nanny
    sends an ``unregistration_request`` and stops its event loop.
    """
    def __init__(
        self,
        *,
        pid: int,
        engine_id: int,
        control_url: str,
        registration_url: str,
        identity: bytes,
        curve_serverkey: bytes,
        curve_publickey: bytes,
        curve_secretkey: bytes,
        config: Config,
        pipe,
        log_level: int = logging.INFO,
    ):
        """Record connection details and set up logging.

        No sockets are created here; that happens in :meth:`start`.

        Parameters
        ----------
        pid
            Process id of the engine to watch (wrapped in ``psutil.Process``).
        engine_id
            Engine id, used in log output and in the unregistration message.
        control_url
            URL the control socket connects to.
        registration_url
            URL the unregistration notice is sent to.
        identity
            ZMQ identity assigned to the control socket.
        curve_serverkey, curve_publickey, curve_secretkey
            CURVE key material passed on to the sockets; a falsy secret key
            leaves the relay socket unencrypted.
        config
            Configuration used to build the ``Session``.
        pipe
            Writable file-like object; the relay URL is written to it and,
            on non-Windows platforms, it is watched for closing.
        log_level
            Level for the nanny's own logger.
        """
        self.pid = pid
        self.engine_id = engine_id
        self.parent_process = psutil.Process(self.pid)
        self.control_url = control_url
        self.registration_url = registration_url
        self.identity = identity
        self.curve_serverkey = curve_serverkey
        self.curve_publickey = curve_publickey
        self.curve_secretkey = curve_secretkey
        self.config = config
        self.pipe = pipe
        self.session = Session(config=self.config)

        self.log = local_logger(f"{self.__class__.__name__}.{engine_id}",
                                log_level)
        self.log.propagate = False

        # msg_type -> callable(content) for requests the nanny answers itself
        self.control_handlers = {
            "signal_request": self.signal_request,
        }
        self._finish_called = False

    def wait_for_parent_thread(self):
        """Wait for my parent to exit, then I'll notify the controller and shut down"""
        self.log.info(f"Nanny watching parent pid {self.pid}.")
        while True:
            try:
                exit_code = self.parent_process.wait(60)
            except psutil.TimeoutExpired:
                # still running; keep waiting
                continue
            else:
                break
        self.log.critical(f"Parent {self.pid} exited with status {exit_code}.")
        # this runs in the watcher thread: hand off to the loop thread
        self.loop.add_callback(self.finish)

    def pipe_handler(self, fd, events):
        """Handle an event on the pipe: treat it as the parent going away.

        Unregisters and closes ``fd``, logs the parent's exit status (or its
        current process status if it has not exited yet) and calls
        :meth:`finish`.
        """
        self.log.debug(f"Pipe event {events}")
        self.loop.remove_handler(fd)
        try:
            fd.close()
        except BrokenPipeError:
            pass
        try:
            status = self.parent_process.wait(0)
        except psutil.TimeoutExpired:
            try:
                status = self.parent_process.status()
            except psutil.NoSuchProcess:
                status = "exited"

        self.log.critical(
            f"Pipe closed, parent {self.pid} has status: {status}")
        self.finish()

    def notify_exit(self):
        """Notify the Hub that our parent has exited"""
        self.log.info("Notifying Hub that our parent has shut down")
        s = self.context.socket(zmq.DEALER)
        try:
            # finite, nonzero LINGER to prevent hang without dropping message during exit
            s.LINGER = 3000
            util.connect(
                s,
                self.registration_url,
                curve_serverkey=self.curve_serverkey,
                curve_secretkey=self.curve_secretkey,
                curve_publickey=self.curve_publickey,
            )
            self.session.send(s,
                              "unregistration_request",
                              content={"id": self.engine_id})
        finally:
            # always release the socket, even if connect/send failed
            s.close()

    def finish(self):
        """Prepare to exit and stop our event loop.

        Only the first call has any effect.
        """
        if self._finish_called:
            return
        self._finish_called = True
        try:
            self.notify_exit()
        finally:
            # finish() never runs twice, so the loop must be told to stop
            # even when the notification fails; otherwise the nanny would
            # stay alive after the process it watches is gone.
            self.loop.add_callback(self.loop.stop)

    def dispatch_control(self, stream, raw_msg):
        """Dispatch message from the control scheduler

        If we have a handler registered for the message type, call it and
        send its result back on ``stream`` as ``<name>_reply`` (or
        ``error`` when the type does not end in ``_request``).  Otherwise
        relay the original frames to the parent unchanged.  Messages that
        cannot be parsed are logged and dropped.
        """
        try:
            idents, msg_frames = self.session.feed_identities(raw_msg)
        except Exception:
            self.log.error(f"Bad control message: {raw_msg}", exc_info=True)
            return

        try:
            msg = self.session.deserialize(msg_frames, content=True)
        except Exception:
            self.log.error("Bad control message: %r",
                           msg_frames,
                           exc_info=True)
            return

        msg_type = msg['header']['msg_type']
        if msg_type.endswith("_request"):
            # foo_request -> foo_reply
            reply_type = msg_type[:-len("_request")] + "_reply"
        else:
            reply_type = "error"
        self.log.debug(f"Client {idents[-1]} requested {msg_type}")

        handler = self.control_handlers.get(msg_type, None)
        if handler is None:
            # don't have an intercept handler, relay original message to parent
            self.log.debug(f"Relaying {msg_type} {msg['header']['msg_id']}")
            self.parent_stream.send_multipart(raw_msg)
            return

        try:
            content = handler(msg['content'])
        except Exception:
            # report the failure to the requester instead of dropping it
            content = error.wrap_exception()
            self.log.error("Error handling request: %r",
                           msg_type,
                           exc_info=True)

        self.session.send(stream,
                          reply_type,
                          ident=idents,
                          content=content,
                          parent=msg)

    def dispatch_parent(self, stream, raw_msg):
        """Relay messages from parent directly to control stream"""
        self.control_stream.send_multipart(raw_msg)

    # intercept message handlers

    def signal_request(self, content):
        """Handle a signal request: send signal to parent process

        ``content['sig']`` is either a signal number or the name of an
        attribute of the ``signal`` module (e.g. ``"SIGINT"``).
        Returns ``{"status": "ok"}`` on success.
        """
        sig = content['sig']
        if isinstance(sig, str):
            sig = getattr(signal, sig)
        self.log.info(f"Sending signal {sig} to pid {self.pid}")
        # exception will be caught and wrapped by the caller
        self.parent_process.send_signal(sig)
        return {"status": "ok"}

    def start(self):
        """Create the sockets, start watching the parent and run the loop.

        Blocks until the event loop is stopped, then closes the loop,
        terminates the ZMQ context and closes the pipe.
        """
        self.log.info(
            f"Starting kernel nanny for engine {self.engine_id}, pid={self.pid}, nanny pid={os.getpid()}"
        )
        # The watcher thread calls self.loop.add_callback() as soon as the
        # parent has exited, so self.loop must exist before it is started.
        self.loop = IOLoop.current()
        self.context = zmq.Context()

        self._watcher_thread = Thread(target=self.wait_for_parent_thread,
                                      name="WatchParent",
                                      daemon=True)
        self._watcher_thread.start()
        # ignore SIGINT sent to parent
        signal.signal(signal.SIGINT, signal.SIG_IGN)

        # set up control socket (connection to Scheduler)
        self.control_socket = self.context.socket(zmq.ROUTER)
        self.control_socket.identity = self.identity
        util.connect(
            self.control_socket,
            self.control_url,
            curve_serverkey=self.curve_serverkey,
        )
        self.control_stream = ZMQStream(self.control_socket)
        self.control_stream.on_recv_stream(self.dispatch_control)

        # set up relay socket (connection to parent's control socket)
        self.parent_socket = self.context.socket(zmq.DEALER)
        if self.curve_secretkey:
            self.parent_socket.setsockopt(zmq.CURVE_SERVER, 1)
            self.parent_socket.setsockopt(zmq.CURVE_SECRETKEY,
                                          self.curve_secretkey)

        port = self.parent_socket.bind_to_random_port("tcp://127.0.0.1")

        # now that we've bound, pass the relay URL to the parent over the pipe
        self.pipe.write(f"tcp://127.0.0.1:{port}\n")
        if not sys.platform.startswith("win"):
            # watch for the stdout pipe to close
            # as a signal that our parent is shutting down
            self.loop.add_handler(self.pipe, self.pipe_handler,
                                  IOLoop.READ | IOLoop.ERROR)
        self.parent_stream = ZMQStream(self.parent_socket)
        self.parent_stream.on_recv_stream(self.dispatch_parent)
        try:
            self.loop.start()
        finally:
            self.loop.close(all_fds=True)
            self.context.term()
            try:
                self.pipe.close()
            except BrokenPipeError:
                pass
            self.log.debug("exiting")

    @classmethod
    def main(cls, *args, **kwargs):
        """Main body function.

        Instantiates and starts a nanny.

        Args and keyword args passed to the constructor.

        Should be called in a subprocess.
        """
        # start a new event loop for the forked process
        asyncio.set_event_loop(asyncio.new_event_loop())
        IOLoop().make_current()
        self = cls(*args, **kwargs)
        self.start()