class WidgetHandler(FileSystemEventHandler):
    """Forward watchdog filesystem events to a Jupyter widget comm.

    Each filesystem event is wrapped in a ``comm_msg`` shell message and
    sent over a DEALER socket to the kernel, addressed to the comm
    identified by the module-level ``model_id``.

    NOTE(review): relies on module-level globals ``port``, ``session_key``,
    ``session`` and ``model_id`` being defined before instantiation —
    confirm against the enclosing script.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.context = zmq.Context()
        # zmq.XREQ is a deprecated alias of zmq.DEALER (same constant);
        # use the modern name.
        self.socket = self.context.socket(zmq.DEALER)
        # Finite linger so pending messages are flushed (up to 1s) on close
        # without hanging the process.
        self.socket.linger = 1000
        self.socket.connect("tcp://127.0.0.1:%s" % port)
        self.session = Session(key=session_key.encode())

    def on_any_event(self, event):
        """Serialize *event* and send it to the widget comm."""
        msg = self.msg(event)
        self.session.send(self.socket, msg)

    def msg(self, event):
        """Build a Jupyter ``comm_msg`` dict describing a watchdog *event*.

        ``dest_path`` is only present on move events, hence the ``getattr``
        fallback to ``None``.
        """
        return {
            "buffers": [],
            "channel": "shell",
            "metadata": {},
            "parent_header": {},
            "header": {
                "msg_id": str(uuid4()),
                "username": "******",
                "session": session,
                "msg_type": "comm_msg",
                "version": "5.0"
            },
            "content": {
                "comm_id": model_id,
                "data": {
                    "method": "custom",
                    "content": {
                        "event": event.event_type,
                        "is_directory": event.is_directory,
                        "src_path": event.src_path,
                        "dest_path": getattr(event, "dest_path", None)
                    }
                },
            },
        }
class IPEngine(BaseParallelApplication):
    """IPython parallel engine application.

    Connects to an IPController, registers itself, starts a Kernel to
    execute user code, and maintains a heartbeat so the controller can
    detect engine failure. Optionally runs behind a nanny subprocess for
    remote signaling and shutdown notification.
    """

    name = 'ipengine'
    description = _description
    examples = _examples
    classes = List([ZMQInteractiveShell, ProfileDir, Session, Kernel])
    _deprecated_classes = ["EngineFactory", "IPEngineApp"]

    enable_nanny = Bool(
        True,
        config=True,
        help="""Enable the nanny process.

        The nanny process enables remote signaling of single engines
        and more responsive notification of engine shutdown.

        .. versionadded:: 7.0
        """,
    )

    startup_script = Unicode(
        '', config=True, help='specify a script to be run at startup'
    )
    startup_command = Unicode(
        '', config=True, help='specify a command to be run at startup'
    )

    url_file = Unicode(
        '',
        config=True,
        help="""The full location of the file containing the connection information for
        the controller. If this is not given, the file must be in the
        security directory of the cluster directory. This location is
        resolved using the `profile` or `profile_dir` options.""",
    )
    wait_for_url_file = Float(
        10,
        config=True,
        help="""The maximum number of seconds to wait for url_file to exist.
        This is useful for batch-systems and shared-filesystems where the
        controller and engine are started at the same time and it may take
        a moment for the controller to write the connector files.""",
    )

    url_file_name = Unicode('ipcontroller-engine.json', config=True)

    # Connection info may also be supplied via environment instead of a file.
    connection_info_env = Unicode()

    @default("connection_info_env")
    def _default_connection_file_env(self):
        return os.environ.get("IPP_CONNECTION_INFO", "")

    @observe('cluster_id')
    def _cluster_id_changed(self, change):
        # Keep the connection-file name in sync with the cluster id.
        if change['new']:
            base = 'ipcontroller-{}'.format(change['new'])
        else:
            base = 'ipcontroller'
        self.url_file_name = "%s-engine.json" % base

    log_url = Unicode(
        '',
        config=True,
        help="""The URL for the iploggerapp instance, for forwarding
        logging to a central location.""",
    )

    registration_url = Unicode(
        config=True,
        help="""Override the registration URL""",
    )

    out_stream_factory = Type(
        'ipykernel.iostream.OutStream',
        config=True,
        help="""The OutStream for handling stdout/err.
        Typically 'ipykernel.iostream.OutStream'""",
    )
    display_hook_factory = Type(
        'ipykernel.displayhook.ZMQDisplayHook',
        config=True,
        help="""The class for handling displayhook.
        Typically 'ipykernel.displayhook.ZMQDisplayHook'""",
    )
    location = Unicode(
        config=True,
        help="""The location (an IP address) of the controller. This is
        used for disambiguating URLs, to determine whether
        loopback should be used to connect or the public address.""",
    )
    timeout = Float(
        5.0,
        config=True,
        help="""The time (in seconds) to wait for the Controller to respond
        to registration requests before giving up.""",
    )
    max_heartbeat_misses = Integer(
        50,
        config=True,
        help="""The maximum number of times a check for the heartbeat ping of a
        controller can be missed before shutting down the engine.

        If set to 0, the check is disabled.""",
    )
    sshserver = Unicode(
        config=True,
        help="""The SSH server to use for tunneling connections to the Controller.""",
    )
    sshkey = Unicode(
        config=True,
        help="""The SSH private key file to use when tunneling connections to the Controller.""",
    )
    paramiko = Bool(
        sys.platform == 'win32',
        config=True,
        help="""Whether to use paramiko instead of openssh for tunnels.""",
    )
    use_mpi = Bool(
        False,
        config=True,
        help="""Enable MPI integration.

        If set, MPI rank will be requested for my rank,
        and additionally `mpi_init` will be executed in the interactive shell.
        """,
    )
    init_mpi = Unicode(
        DEFAULT_MPI_INIT,
        config=True,
        help="""Code to execute in the user namespace when initializing MPI""",
    )
    mpi_registration_delay = Float(
        0.02,
        config=True,
        help="""Per-engine delay for mpiexec-launched engines

        avoids flooding the controller with registrations,
        which can stall under heavy load.

        Default: .02 (50 engines/sec, or 3000 engines/minute)
        """,
    )

    # not configurable:
    user_ns = Dict()

    id = Integer(
        None,
        allow_none=True,
        config=True,
        help="""Request this engine ID.

        If run in MPI, will use the MPI rank.
        Otherwise, let the Hub decide what our rank should be.
        """,
    )

    @default('id')
    def _id_default(self):
        # Without MPI there is no sensible default; let the Hub assign one.
        if not self.use_mpi:
            return None
        from mpi4py import MPI

        if MPI.COMM_WORLD.size > 1:
            self.log.debug("MPI rank = %i", MPI.COMM_WORLD.rank)
            return MPI.COMM_WORLD.rank

    registrar = Instance('zmq.eventloop.zmqstream.ZMQStream', allow_none=True)
    kernel = Instance(Kernel, allow_none=True)
    hb_check_period = Integer()

    # States for the heartbeat monitoring
    # Initial values for monitored and pinged must satisfy "monitored > pinged == False" so that
    # during the first check no "missed" ping is reported. Must be floats for Python 3 compatibility.
    _hb_last_pinged = 0.0
    _hb_last_monitored = 0.0
    _hb_missed_beats = 0
    # The zmq Stream which receives the pings from the Heart
    _hb_listener = None

    # Engine identity: `ident` is the session uuid; `bident` is its bytes form,
    # used as the zmq socket IDENTITY.
    bident = Bytes()
    ident = Unicode()

    @default("ident")
    def _default_ident(self):
        return self.session.session

    @default("bident")
    def _default_bident(self):
        return self.ident.encode("utf8")

    @observe("ident")
    def _ident_changed(self, change):
        # Keep the bytes identity in sync with the unicode one.
        self.bident = self._default_bident()

    using_ssh = Bool(False)

    context = Instance(zmq.Context)

    @default("context")
    def _default_context(self):
        return zmq.Context.instance()

    # an IPKernelApp instance, used to setup listening for shell frontends
    kernel_app = Instance(IPKernelApp, allow_none=True)

    aliases = Dict(aliases)
    flags = Dict(flags)

    # CURVE credentials; all default from environment variables and are
    # auto-generated on launch when missing (see _ensure_curve_keypair).
    curve_serverkey = Bytes(
        config=True, help="The controller's public key for CURVE security"
    )
    curve_secretkey = Bytes(
        config=True,
        help="""The engine's secret key for CURVE security.
        Usually autogenerated on launch.""",
    )
    curve_publickey = Bytes(
        config=True,
        help="""The engine's public key for CURVE security.
        Usually autogenerated on launch.""",
    )

    @default("curve_serverkey")
    def _default_curve_serverkey(self):
        return os.environ.get("IPP_CURVE_SERVERKEY", "").encode("ascii")

    @default("curve_secretkey")
    def _default_curve_secretkey(self):
        return os.environ.get("IPP_CURVE_SECRETKEY", "").encode("ascii")

    @default("curve_publickey")
    def _default_curve_publickey(self):
        return os.environ.get("IPP_CURVE_PUBLICKEY", "").encode("ascii")

    @validate("curve_publickey", "curve_secretkey", "curve_serverkey")
    def _cast_bytes(self, proposal):
        # Allow str values from config files; zmq requires bytes.
        if isinstance(proposal.value, str):
            return proposal.value.encode("ascii")
        return proposal.value

    def _ensure_curve_keypair(self):
        """Generate a CURVE keypair if one isn't already set."""
        if not self.curve_secretkey or not self.curve_publickey:
            self.log.info("Generating new CURVE credentials")
            self.curve_publickey, self.curve_secretkey = zmq.curve_keypair()

    def find_connection_file(self):
        """Set the url file.

        Here we don't try to actually see if it exists or is valid,
        as that is handled by the connection logic.
        """
        # Find the actual ipcontroller-engine.json connection file
        if not self.url_file:
            self.url_file = os.path.join(
                self.profile_dir.security_dir, self.url_file_name
            )

    def load_connection_file(self):
        """load config from a JSON connector file,
        at a *lower* priority than command-line/config files.

        Same content can be specified in $IPP_CONNECTION_INFO env
        """
        config = self.config

        if self.connection_info_env:
            self.log.info("Loading connection info from $IPP_CONNECTION_INFO")
            d = json.loads(self.connection_info_env)
        else:
            self.log.info("Loading connection file %r", self.url_file)
            with open(self.url_file) as f:
                d = json.load(f)

        # allow hand-override of location for disambiguation
        # and ssh-server
        if 'IPEngine.location' not in self.cli_config:
            self.location = d['location']
        if 'ssh' in d and not self.sshserver:
            self.sshserver = d.get("ssh")

        # Resolve the interface IP relative to our location (loopback vs public).
        proto, ip = d['interface'].split('://')
        ip = disambiguate_ip_address(ip, self.location)
        d['interface'] = f'{proto}://{ip}'

        if d.get('curve_serverkey'):
            # connection file takes precedence over env, if present and defined
            self.curve_serverkey = d['curve_serverkey'].encode('ascii')
        if self.curve_serverkey:
            self.log.info("Using CurveZMQ security")
            self._ensure_curve_keypair()
        else:
            self.log.warning("Not using CurveZMQ security")

        # DO NOT allow override of basic URLs, serialization, or key
        # JSON file takes top priority there
        if d.get('key') or 'key' not in config.Session:
            config.Session.key = d.get('key', '').encode('utf8')
        config.Session.signature_scheme = d['signature_scheme']

        self.registration_url = f"{d['interface']}:{d['registration']}"

        config.Session.packer = d['pack']
        config.Session.unpacker = d['unpack']
        self.session = Session(parent=self)

        self.log.debug("Config changed:")
        self.log.debug("%r", config)
        self.connection_info = d

    def bind_kernel(self, **kwargs):
        """Promote engine to listening kernel, accessible to frontends."""
        if self.kernel_app is not None:
            # already bound; this is a no-op on repeat calls
            return

        self.log.info("Opening ports for direct connections as an IPython kernel")
        if self.curve_serverkey:
            self.log.warning("Bound kernel does not support CURVE security")

        kernel = self.kernel

        kwargs.setdefault('config', self.config)
        kwargs.setdefault('log', self.log)
        kwargs.setdefault('profile_dir', self.profile_dir)
        kwargs.setdefault('session', self.session)

        app = self.kernel_app = IPKernelApp(**kwargs)

        # allow IPKernelApp.instance():
        IPKernelApp._instance = app

        app.init_connection_file()
        # relevant contents of init_sockets:

        app.shell_port = app._bind_socket(kernel.shell_streams[0], app.shell_port)
        app.log.debug("shell ROUTER Channel on port: %i", app.shell_port)

        iopub_socket = kernel.iopub_socket
        # ipykernel 4.3 iopub_socket is an IOThread wrapper:
        if hasattr(iopub_socket, 'socket'):
            iopub_socket = iopub_socket.socket

        app.iopub_port = app._bind_socket(iopub_socket, app.iopub_port)
        app.log.debug("iopub PUB Channel on port: %i", app.iopub_port)

        kernel.stdin_socket = self.context.socket(zmq.ROUTER)
        app.stdin_port = app._bind_socket(kernel.stdin_socket, app.stdin_port)
        app.log.debug("stdin ROUTER Channel on port: %i", app.stdin_port)

        # start the heartbeat, and log connection info:

        app.init_heartbeat()

        app.log_connection_info()
        app.connection_dir = self.profile_dir.security_dir
        app.write_connection_file()

    @property
    def tunnel_mod(self):
        # Lazy import: zmq.ssh pulls in paramiko/subprocess machinery only
        # when SSH tunneling is actually used.
        from zmq.ssh import tunnel

        return tunnel

    def init_connector(self):
        """construct connection function, which handles tunnels."""
        self.using_ssh = bool(self.sshkey or self.sshserver)

        if self.sshkey and not self.sshserver:
            # We are using ssh directly to the controller, tunneling localhost to localhost
            self.sshserver = self.registration_url.split('://')[1].split(':')[0]

        if self.using_ssh:
            if self.tunnel_mod.try_passwordless_ssh(
                self.sshserver, self.sshkey, self.paramiko
            ):
                password = False
            else:
                password = getpass("SSH Password for %s: " % self.sshserver)
        else:
            password = False

        def connect(s, url, curve_serverkey=None):
            """Connect socket *s* to *url*, applying CURVE and/or SSH tunneling."""
            url = disambiguate_url(url, self.location)
            if curve_serverkey is None:
                curve_serverkey = self.curve_serverkey
            if curve_serverkey:
                s.setsockopt(zmq.CURVE_SERVERKEY, curve_serverkey)
                s.setsockopt(zmq.CURVE_SECRETKEY, self.curve_secretkey)
                s.setsockopt(zmq.CURVE_PUBLICKEY, self.curve_publickey)

            if self.using_ssh:
                self.log.debug("Tunneling connection to %s via %s", url, self.sshserver)
                return self.tunnel_mod.tunnel_connection(
                    s,
                    url,
                    self.sshserver,
                    keyfile=self.sshkey,
                    paramiko=self.paramiko,
                    password=password,
                )
            else:
                return s.connect(url)

        def maybe_tunnel(url):
            """like connect, but don't complete the connection (for use by heartbeat)"""
            url = disambiguate_url(url, self.location)
            if self.using_ssh:
                self.log.debug("Tunneling connection to %s via %s", url, self.sshserver)
                url, tunnelobj = self.tunnel_mod.open_tunnel(
                    url,
                    self.sshserver,
                    keyfile=self.sshkey,
                    paramiko=self.paramiko,
                    password=password,
                )
            return str(url)

        return connect, maybe_tunnel

    def register(self):
        """send the registration_request"""
        if self.use_mpi and self.id and self.id >= 100 and self.mpi_registration_delay:
            # Some launchers implement delay at the Launcher level,
            # but mpiexec must implement it in the engine process itself
            # delay based on our rank
            delay = self.id * self.mpi_registration_delay
            self.log.info(f"Delaying registration for {self.id} by {int(delay * 1000)}ms")
            time.sleep(delay)

        self.log.info("Registering with controller at %s" % self.registration_url)
        ctx = self.context
        connect, maybe_tunnel = self.init_connector()
        reg = ctx.socket(zmq.DEALER)
        reg.setsockopt(zmq.IDENTITY, self.bident)
        connect(reg, self.registration_url)

        self.registrar = zmqstream.ZMQStream(reg, self.loop)

        content = dict(uuid=self.ident)
        if self.id is not None:
            self.log.info("Requesting id: %i", self.id)
            content['id'] = self.id
        self._registration_completed = False
        # complete_registration fires when the controller replies
        self.registrar.on_recv(
            lambda msg: self.complete_registration(msg, connect, maybe_tunnel)
        )

        self.session.send(self.registrar, "registration_request", content=content)

    def _report_ping(self, msg):
        """Callback for when the heartmonitor.Heart receives a ping"""
        # self.log.debug("Received a ping: %s", msg)
        self._hb_last_pinged = time.time()

    def complete_registration(self, msg, connect, maybe_tunnel):
        # Thin error-boundary wrapper: any failure during registration is fatal.
        try:
            self._complete_registration(msg, connect, maybe_tunnel)
        except Exception as e:
            self.log.critical(f"Error completing registration: {e}", exc_info=True)
            self.exit(255)

    def _complete_registration(self, msg, connect, maybe_tunnel):
        """Finish registration: wire up all streams and start the Kernel.

        Called with the controller's reply to our registration_request.
        """
        ctx = self.context
        loop = self.loop
        identity = self.bident
        idents, msg = self.session.feed_identities(msg)
        msg = self.session.deserialize(msg)
        content = msg['content']
        info = self.connection_info

        def url(key):
            """get zmq url for given channel"""
            return str(info["interface"] + ":%i" % info[key])

        def urls(key):
            # like url(), but for channels that carry a list of ports
            return [f'{info["interface"]}:{port}' for port in info[key]]

        if content['status'] == 'ok':
            requested_id = self.id
            self.id = content['id']
            if requested_id is not None and self.id != requested_id:
                self.log.warning(
                    f"Did not get the requested id: {self.id} != {requested_id}"
                )
                self.log.name = self.log.name.rsplit(".", 1)[0] + f".{self.id}"
            elif self.id is None:
                # NOTE(review): this condition looks like it was meant to be
                # `requested_id is None` (append the newly assigned id when
                # none was requested); as written it only fires when the hub
                # assigned id=None and appends ".None" — TODO confirm.
                self.log.name += f".{self.id}"

            # create Shell Connections (MUX, Task, etc.):

            # select which broadcast endpoint to connect to
            # use rank % len(broadcast_leaves)
            broadcast_urls = urls('broadcast')
            # NOTE(review): broadcast_leaves is unused; len(broadcast_urls)
            # is recomputed on the next line.
            broadcast_leaves = len(broadcast_urls)
            broadcast_index = self.id % len(broadcast_urls)
            broadcast_url = broadcast_urls[broadcast_index]

            shell_addrs = [url('mux'), url('task'), broadcast_url]
            self.log.info(f'Shell_addrs: {shell_addrs}')

            # Use only one shell stream for mux and tasks
            stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
            stream.setsockopt(zmq.IDENTITY, identity)
            # TODO: enable PROBE_ROUTER when schedulers can handle the empty message
            # stream.setsockopt(zmq.PROBE_ROUTER, 1)
            self.log.debug("Setting shell identity %r", identity)

            shell_streams = [stream]
            for addr in shell_addrs:
                self.log.info("Connecting shell to %s", addr)
                connect(stream, addr)

            # control stream:
            control_url = url('control')
            curve_serverkey = self.curve_serverkey
            if self.enable_nanny:
                # Interpose the nanny between us and the controller's
                # control channel so it can intercept signal requests.
                nanny_url, self.nanny_pipe = self.start_nanny(
                    control_url=control_url,
                )
                control_url = nanny_url
                # nanny uses our curve_publickey, not the controller's publickey
                curve_serverkey = self.curve_publickey
            control_stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
            control_stream.setsockopt(zmq.IDENTITY, identity)
            connect(control_stream, control_url, curve_serverkey=curve_serverkey)

            # create iopub stream:
            iopub_addr = url('iopub')
            iopub_socket = ctx.socket(zmq.PUB)
            # never drop iopub messages due to high-water mark
            iopub_socket.SNDHWM = 0
            iopub_socket.setsockopt(zmq.IDENTITY, identity)
            connect(iopub_socket, iopub_addr)

            try:
                from ipykernel.iostream import IOPubThread
            except ImportError:
                pass
            else:
                iopub_socket = IOPubThread(iopub_socket)
                iopub_socket.start()

            # disable history:
            self.config.HistoryManager.hist_file = ':memory:'

            # Redirect input streams and set a display hook.
            if self.out_stream_factory:
                sys.stdout = self.out_stream_factory(self.session, iopub_socket, 'stdout')
                sys.stdout.topic = f"engine.{self.id}.stdout".encode("ascii")
                sys.stderr = self.out_stream_factory(self.session, iopub_socket, 'stderr')
                sys.stderr.topic = f"engine.{self.id}.stderr".encode("ascii")

                # copied from ipykernel 6, which captures sys.__stderr__ at the FD-level
                if getattr(sys.stderr, "_original_stdstream_copy", None) is not None:
                    for handler in self.log.handlers:
                        if isinstance(handler, StreamHandler) and (
                            handler.stream.buffer.fileno() == 2
                        ):
                            self.log.debug(
                                "Seeing logger to stderr, rerouting to raw filedescriptor."
                            )
                            handler.stream = TextIOWrapper(
                                FileIO(sys.stderr._original_stdstream_copy, "w")
                            )
            if self.display_hook_factory:
                sys.displayhook = self.display_hook_factory(self.session, iopub_socket)
                sys.displayhook.topic = f"engine.{self.id}.execute_result".encode("ascii")

            # patch Session to always send engine uuid metadata
            original_send = self.session.send

            def send_with_metadata(
                stream,
                msg_or_type,
                content=None,
                parent=None,
                ident=None,
                buffers=None,
                track=False,
                header=None,
                metadata=None,
                **kwargs,
            ):
                """Ensure all messages set engine uuid metadata"""
                metadata = metadata or {}
                metadata.setdefault("engine", self.ident)
                return original_send(
                    stream,
                    msg_or_type,
                    content=content,
                    parent=parent,
                    ident=ident,
                    buffers=buffers,
                    track=track,
                    header=header,
                    metadata=metadata,
                    **kwargs,
                )

            self.session.send = send_with_metadata

            self.kernel = Kernel.instance(
                parent=self,
                engine_id=self.id,
                ident=self.ident,
                session=self.session,
                control_stream=control_stream,
                shell_streams=shell_streams,
                iopub_socket=iopub_socket,
                loop=loop,
                user_ns=self.user_ns,
                log=self.log,
            )

            self.kernel.shell.display_pub.topic = f"engine.{self.id}.displaypub".encode(
                "ascii"
            )

            # FIXME: This is a hack until IPKernelApp and IPEngineApp can be fully merged
            self.init_signal()
            app = IPKernelApp(
                parent=self, shell=self.kernel.shell, kernel=self.kernel, log=self.log
            )
            if self.use_mpi and self.init_mpi:
                app.exec_lines.insert(0, self.init_mpi)
            app.init_profile_dir()
            app.init_code()

            self.kernel.start()
        else:
            self.log.fatal("Registration Failed: %s" % msg)
            raise Exception("Registration Failed: %s" % msg)

        # Only reached on success (the else branch raises).
        self.start_heartbeat(
            maybe_tunnel(url('hb_ping')),
            maybe_tunnel(url('hb_pong')),
            content['hb_period'],
            identity,
        )
        self.log.info("Completed registration with id %i" % self.id)
        self.loop.remove_timeout(self._abort_timeout)

    def start_nanny(self, control_url):
        """Launch the nanny subprocess; returns (nanny_url, pipe)."""
        self.log.info("Starting nanny")
        config = Config()
        # The nanny only needs Session config (keys, packers) to talk to us.
        config.Session = self.config.Session
        return start_nanny(
            engine_id=self.id,
            identity=self.bident,
            control_url=control_url,
            curve_serverkey=self.curve_serverkey,
            curve_secretkey=self.curve_secretkey,
            curve_publickey=self.curve_publickey,
            registration_url=self.registration_url,
            config=config,
        )

    def start_heartbeat(self, hb_ping, hb_pong, hb_period, identity):
        """Start our heart beating"""

        hb_monitor = None
        if self.max_heartbeat_misses > 0:
            # Add a monitor socket which will record the last time a ping was seen
            mon = self.context.socket(zmq.SUB)
            if self.curve_serverkey:
                mon.setsockopt(zmq.CURVE_SERVER, 1)
                mon.setsockopt(zmq.CURVE_SECRETKEY, self.curve_secretkey)
            mport = mon.bind_to_random_port('tcp://%s' % localhost())
            mon.setsockopt(zmq.SUBSCRIBE, b"")
            self._hb_listener = zmqstream.ZMQStream(mon, self.loop)
            self._hb_listener.on_recv(self._report_ping)

            hb_monitor = "tcp://%s:%i" % (localhost(), mport)

        heart = Heart(
            hb_ping,
            hb_pong,
            hb_monitor,
            heart_id=identity,
            curve_serverkey=self.curve_serverkey,
            curve_secretkey=self.curve_secretkey,
            curve_publickey=self.curve_publickey,
        )
        heart.start()

        # periodically check the heartbeat pings of the controller
        # Should be started here and not in "start()" so that the right period can be taken
        # from the hubs HeartBeatMonitor.period
        if self.max_heartbeat_misses > 0:
            # Use a slightly bigger check period than the hub signal period to not warn unnecessary
            self.hb_check_period = hb_period + 500
            self.log.info(
                "Starting to monitor the heartbeat signal from the hub every %i ms.",
                self.hb_check_period,
            )
            self._hb_reporter = ioloop.PeriodicCallback(
                self._hb_monitor, self.hb_check_period
            )
            self._hb_reporter.start()
        else:
            self.log.info(
                "Monitoring of the heartbeat signal from the hub is not enabled."
            )

    def abort(self):
        """Registration timed out: unregister and exit with status 255."""
        self.log.fatal("Registration timed out after %.1f seconds" % self.timeout)
        if "127." in self.registration_url:
            self.log.fatal(
                """
            If the controller and engines are not on the same machine,
            you will have to instruct the controller to listen on an external IP (in ipcontroller_config.py):
                c.IPController.ip = '0.0.0.0' # for all interfaces, internal and external
                c.IPController.ip = '192.168.1.101' # or any interface that the engines can see
            or tunnel connections via ssh.
            """
            )
        self.session.send(
            self.registrar, "unregistration_request", content=dict(id=self.id)
        )
        # give the unregistration message a moment to be delivered
        time.sleep(1)
        sys.exit(255)

    def _hb_monitor(self):
        """Callback to monitor the heartbeat from the controller"""
        self._hb_listener.flush()
        if self._hb_last_monitored > self._hb_last_pinged:
            self._hb_missed_beats += 1
            self.log.warning(
                "No heartbeat in the last %s ms (%s time(s) in a row).",
                self.hb_check_period,
                self._hb_missed_beats,
            )
        else:
            # self.log.debug("Heartbeat received (after missing %s beats).", self._hb_missed_beats)
            self._hb_missed_beats = 0

        if self._hb_missed_beats >= self.max_heartbeat_misses:
            self.log.fatal(
                "Maximum number of heartbeats misses reached (%s times %s ms), shutting down.",
                self.max_heartbeat_misses,
                self.hb_check_period,
            )
            self.session.send(
                self.registrar, "unregistration_request", content=dict(id=self.id)
            )
            self.loop.stop()

        self._hb_last_monitored = time.time()

    def init_engine(self):
        """Locate/await the connection file and assemble startup exec config."""
        # This is the working dir by now.
        sys.path.insert(0, '')
        config = self.config

        if not self.connection_info_env:
            self.find_connection_file()
            if self.wait_for_url_file and not os.path.exists(self.url_file):
                self.log.warning(f"Connection file {self.url_file!r} not found")
                self.log.warning(
                    "Waiting up to %.1f seconds for it to arrive.",
                    self.wait_for_url_file,
                )
                tic = time.monotonic()
                while not os.path.exists(self.url_file) and (
                    time.monotonic() - tic < self.wait_for_url_file
                ):
                    # wait for url_file to exist, or until time limit
                    time.sleep(0.1)

            if not os.path.exists(self.url_file):
                self.log.fatal(f"Fatal: connection file never arrived: {self.url_file}")
                self.exit(1)

        self.load_connection_file()

        # Accept exec_lines/exec_files from either app section, first match wins.
        exec_lines = []
        for app in ('IPKernelApp', 'InteractiveShellApp'):
            if '%s.exec_lines' % app in config:
                exec_lines = config[app].exec_lines
                break

        exec_files = []
        for app in ('IPKernelApp', 'InteractiveShellApp'):
            if '%s.exec_files' % app in config:
                exec_files = config[app].exec_files
                break

        config.IPKernelApp.exec_lines = exec_lines
        config.IPKernelApp.exec_files = exec_files

        if self.startup_script:
            exec_files.append(self.startup_script)
        if self.startup_command:
            exec_lines.append(self.startup_command)

    def forward_logging(self):
        """Forward our logging to the configured iologger URL, if any."""
        if self.log_url:
            self.log.info("Forwarding logging to %s", self.log_url)
            context = self.context
            lsock = context.socket(zmq.PUB)
            lsock.connect(self.log_url)
            # NOTE(review): `self.engine` is not defined anywhere in this
            # class — this line likely raises AttributeError when log_url is
            # set; possibly should be `self` or the kernel. TODO confirm.
            handler = EnginePUBHandler(self.engine, lsock)
            handler.setLevel(self.log_level)
            self.log.addHandler(handler)

    @catch_config_error
    def initialize(self, argv=None):
        super().initialize(argv)
        self.init_engine()
        self.forward_logging()

    def init_signal(self):
        # SIGINT is ignored (interrupts arrive via the control channel);
        # SIGTERM triggers a clean loop stop.
        signal.signal(signal.SIGINT, self._signal_sigint)
        signal.signal(signal.SIGTERM, self._signal_stop)

    def _signal_sigint(self, sig, frame):
        self.log.warning("Ignoring SIGINT. Terminate with SIGTERM.")

    def _signal_stop(self, sig, frame):
        self.log.critical(f"received signal {sig}, stopping")
        # add_callback_from_signal is the only loop API safe in a signal handler
        self.loop.add_callback_from_signal(self.loop.stop)

    def start(self):
        """Register with the controller and run the event loop."""
        if self.id is not None:
            self.log.name += f".{self.id}"
        loop = self.loop

        def _start():
            self.register()
            # abort if registration does not complete within self.timeout
            self._abort_timeout = loop.add_timeout(
                loop.time() + self.timeout, self.abort
            )

        self.loop.add_callback(_start)
        try:
            self.loop.start()
        except KeyboardInterrupt:
            self.log.critical("Engine Interrupted, shutting down...\n")
class KernelNanny:
    """Object for monitoring the engine process.

    Must be a child process of the engine.

    Handles signal messages (intercepted from the control channel) and
    watches the Engine process for exiting, notifying the Hub when it does.
    """

    def __init__(
        self,
        *,
        pid: int,
        engine_id: int,
        control_url: str,
        registration_url: str,
        identity: bytes,
        curve_serverkey: bytes,
        curve_publickey: bytes,
        curve_secretkey: bytes,
        config: Config,
        pipe,
        log_level: int = logging.INFO,
    ):
        self.pid = pid
        self.engine_id = engine_id
        self.parent_process = psutil.Process(self.pid)
        self.control_url = control_url
        self.registration_url = registration_url
        self.identity = identity
        self.curve_serverkey = curve_serverkey
        self.curve_publickey = curve_publickey
        self.curve_secretkey = curve_secretkey
        self.config = config
        self.pipe = pipe
        self.session = Session(config=self.config)

        self.log = local_logger(f"{self.__class__.__name__}.{engine_id}", log_level)
        self.log.propagate = False

        # msg_types we intercept instead of relaying to the parent engine
        self.control_handlers = {
            "signal_request": self.signal_request,
        }
        self._finish_called = False

    def wait_for_parent_thread(self):
        """Wait for my parent to exit, then I'll notify the controller and shut down"""
        self.log.info(f"Nanny watching parent pid {self.pid}.")
        while True:
            try:
                # finite timeout so the thread wakes periodically and
                # can be interrupted by process teardown
                exit_code = self.parent_process.wait(60)
            except psutil.TimeoutExpired:
                continue
            else:
                break
        self.log.critical(f"Parent {self.pid} exited with status {exit_code}.")
        # hop back onto the event loop thread to shut down cleanly
        self.loop.add_callback(self.finish)

    def pipe_handler(self, fd, events):
        """Handle the parent's stdout pipe closing: parent is shutting down."""
        self.log.debug(f"Pipe event {events}")
        self.loop.remove_handler(fd)
        try:
            fd.close()
        except BrokenPipeError:
            pass
        try:
            status = self.parent_process.wait(0)
        except psutil.TimeoutExpired:
            try:
                status = self.parent_process.status()
            except psutil.NoSuchProcess:
                status = "exited"
        self.log.critical(f"Pipe closed, parent {self.pid} has status: {status}")
        self.finish()

    def notify_exit(self):
        """Notify the Hub that our parent has exited"""
        self.log.info("Notifying Hub that our parent has shut down")
        s = self.context.socket(zmq.DEALER)
        # finite, nonzero LINGER to prevent hang without dropping message during exit
        s.LINGER = 3000
        util.connect(
            s,
            self.registration_url,
            curve_serverkey=self.curve_serverkey,
            curve_secretkey=self.curve_secretkey,
            curve_publickey=self.curve_publickey,
        )
        self.session.send(s, "unregistration_request", content={"id": self.engine_id})
        s.close()

    def finish(self):
        """Prepare to exit and stop our event loop."""
        if self._finish_called:
            # idempotent: may be triggered by both the watcher thread
            # and the pipe handler
            return
        self._finish_called = True
        self.notify_exit()
        self.loop.add_callback(self.loop.stop)

    def dispatch_control(self, stream, raw_msg):
        """Dispatch message from the control scheduler

        If we have a handler registered, handle it here;
        otherwise relay the message unmodified to the parent engine.
        """
        try:
            idents, msg_frames = self.session.feed_identities(raw_msg)
        except Exception:
            self.log.error(f"Bad control message: {raw_msg}", exc_info=True)
            return

        try:
            msg = self.session.deserialize(msg_frames, content=True)
        except Exception:
            self.log.error("Bad control message: %r", msg_frames, exc_info=True)
            return

        msg_type = msg['header']['msg_type']
        if msg_type.endswith("_request"):
            # BUG FIX: was `msg_type[-len("_request"):]`, which evaluates to
            # the literal "_request" for every message. The reply type is the
            # request stem plus "_reply", e.g. signal_request -> signal_reply.
            reply_type = msg_type[: -len("_request")] + "_reply"
        else:
            reply_type = "error"
        self.log.debug(f"Client {idents[-1]} requested {msg_type}")

        handler = self.control_handlers.get(msg_type, None)
        if handler is None:
            # don't have an intercept handler, relay original message to parent
            self.log.debug(f"Relaying {msg_type} {msg['header']['msg_id']}")
            self.parent_stream.send_multipart(raw_msg)
            return

        try:
            content = handler(msg['content'])
        except Exception:
            content = error.wrap_exception()
            self.log.error("Error handling request: %r", msg_type, exc_info=True)

        self.session.send(stream, reply_type, ident=idents, content=content, parent=msg)

    def dispatch_parent(self, stream, raw_msg):
        """Relay messages from parent directly to control stream"""
        self.control_stream.send_multipart(raw_msg)

    # intercept message handlers

    def signal_request(self, content):
        """Handle a signal request: send signal to parent process"""
        sig = content['sig']
        if isinstance(sig, str):
            sig = getattr(signal, sig)
        self.log.info(f"Sending signal {sig} to pid {self.pid}")
        # exception will be caught and wrapped by the caller
        self.parent_process.send_signal(sig)
        return {"status": "ok"}

    def start(self):
        """Set up sockets and run the nanny event loop until shutdown."""
        self.log.info(
            f"Starting kernel nanny for engine {self.engine_id}, pid={self.pid}, nanny pid={os.getpid()}"
        )
        self._watcher_thread = Thread(
            target=self.wait_for_parent_thread, name="WatchParent", daemon=True
        )
        self._watcher_thread.start()
        # ignore SIGINT sent to parent
        signal.signal(signal.SIGINT, signal.SIG_IGN)

        self.loop = IOLoop.current()
        self.context = zmq.Context()

        # set up control socket (connection to Scheduler)
        self.control_socket = self.context.socket(zmq.ROUTER)
        self.control_socket.identity = self.identity
        util.connect(
            self.control_socket,
            self.control_url,
            curve_serverkey=self.curve_serverkey,
        )
        self.control_stream = ZMQStream(self.control_socket)
        self.control_stream.on_recv_stream(self.dispatch_control)

        # set up relay socket (connection to parent's control socket)
        self.parent_socket = self.context.socket(zmq.DEALER)
        if self.curve_secretkey:
            # we are the CURVE "server" for our parent's connection
            self.parent_socket.setsockopt(zmq.CURVE_SERVER, 1)
            self.parent_socket.setsockopt(zmq.CURVE_SECRETKEY, self.curve_secretkey)

        port = self.parent_socket.bind_to_random_port("tcp://127.0.0.1")
        # now that we've bound, pass port to parent via AsyncResult
        self.pipe.write(f"tcp://127.0.0.1:{port}\n")

        if not sys.platform.startswith("win"):
            # watch for the stdout pipe to close
            # as a signal that our parent is shutting down
            self.loop.add_handler(
                self.pipe, self.pipe_handler, IOLoop.READ | IOLoop.ERROR
            )
        self.parent_stream = ZMQStream(self.parent_socket)
        self.parent_stream.on_recv_stream(self.dispatch_parent)
        try:
            self.loop.start()
        finally:
            self.loop.close(all_fds=True)
            self.context.term()
            try:
                self.pipe.close()
            except BrokenPipeError:
                pass
        self.log.debug("exiting")

    @classmethod
    def main(cls, *args, **kwargs):
        """Main body function.

        Instantiates and starts a nanny.

        Args and keyword args passed to the constructor.

        Should be called in a subprocess.
        """
        # start a new event loop for the forked process
        asyncio.set_event_loop(asyncio.new_event_loop())
        IOLoop().make_current()
        self = cls(*args, **kwargs)
        self.start()