class Authenticator(LoggingConfigurable): """A class for authentication. The API is one method, `authenticate`, a tornado gen.coroutine. """ db = Any() whitelist = Set(config=True, help="""Username whitelist. Use this to restrict which users can login. If empty, allow any user to attempt login. """ ) custom_html = Unicode('') @gen.coroutine def authenticate(self, handler, data): """Authenticate a user with login form data. This must be a tornado gen.coroutine. It must return the username on successful authentication, and return None on failed authentication. """ def add_user(self, user): """Add a new user By default, this just adds the user to the whitelist. Subclasses may do more extensive things, such as adding actual unix users. """ if self.whitelist: self.whitelist.add(user.name) def delete_user(self, user): """Triggered when a user is deleted. Removes the user from the whitelist. """ if user.name in self.whitelist: self.whitelist.remove(user.name) def login_url(self, base_url): """Override to register a custom login handler""" return url_path_join(base_url, 'login') def logout_url(self, base_url): """Override to register a custom logout handler""" return url_path_join(base_url, 'logout') def get_handlers(self, app): """Return any custom handlers the authenticator needs to register (e.g. for OAuth) """ return [ ('/login', LoginHandler), ]
class Containers(Configurable): lis = List(config=True) def _lis_default(self): return [-1] s = Set(config=True) def _s_default(self): return {'a'} d = Dict(config=True) def _d_default(self): return {'a': 'b'}
class EngineConnector(HasTraits): """A simple object for accessing the various zmq connections of an object. Attributes are: id (int): engine ID uuid (str): uuid (unused?) queue (str): identity of queue's XREQ socket registration (str): identity of registration XREQ socket heartbeat (str): identity of heartbeat XREQ socket """ id = Int(0) queue = Str() control = Str() registration = Str() heartbeat = Str() pending = Set()
class HeartMonitor(LoggingConfigurable): """A basic HeartMonitor class pingstream: a PUB stream pongstream: an XREP stream period: the period of the heartbeat in milliseconds""" period = Integer(3000, config=True, help='The frequency at which the Hub pings the engines for heartbeats ' '(in ms)', ) pingstream=Instance('zmq.eventloop.zmqstream.ZMQStream') pongstream=Instance('zmq.eventloop.zmqstream.ZMQStream') loop = Instance('zmq.eventloop.ioloop.IOLoop') def _loop_default(self): return ioloop.IOLoop.instance() # not settable: hearts=Set() responses=Set() on_probation=Set() last_ping=CFloat(0) _new_handlers = Set() _failure_handlers = Set() lifetime = CFloat(0) tic = CFloat(0) def __init__(self, **kwargs): super(HeartMonitor, self).__init__(**kwargs) self.pongstream.on_recv(self.handle_pong) def start(self): self.tic = time.time() self.caller = ioloop.PeriodicCallback(self.beat, self.period, self.loop) self.caller.start() def add_new_heart_handler(self, handler): """add a new handler for new hearts""" self.log.debug("heartbeat::new_heart_handler: %s", handler) self._new_handlers.add(handler) def add_heart_failure_handler(self, handler): """add a new handler for heart failure""" self.log.debug("heartbeat::new heart failure handler: %s", handler) self._failure_handlers.add(handler) def beat(self): self.pongstream.flush() self.last_ping = self.lifetime toc = time.time() self.lifetime += toc-self.tic self.tic = toc self.log.debug("heartbeat::sending %s", self.lifetime) goodhearts = self.hearts.intersection(self.responses) missed_beats = self.hearts.difference(goodhearts) heartfailures = self.on_probation.intersection(missed_beats) newhearts = self.responses.difference(goodhearts) map(self.handle_new_heart, newhearts) map(self.handle_heart_failure, heartfailures) self.on_probation = missed_beats.intersection(self.hearts) self.responses = set() # print self.on_probation, self.hearts # self.log.debug("heartbeat::beat %.3f, %i beating hearts", self.lifetime, len(self.hearts)) self.pingstream.send(asbytes(str(self.lifetime))) # flush stream to force immediate socket send self.pingstream.flush() def handle_new_heart(self, heart): if self._new_handlers: for handler in self._new_handlers: handler(heart) else: self.log.info("heartbeat::yay, got new heart %s!", heart) self.hearts.add(heart) def handle_heart_failure(self, heart): if self._failure_handlers: for handler in self._failure_handlers: try: handler(heart) except Exception as e: self.log.error("heartbeat::Bad Handler! %s", handler, exc_info=True) pass else: self.log.info("heartbeat::Heart %s failed :(", heart) self.hearts.remove(heart) @log_errors def handle_pong(self, msg): "a heart just beat" current = asbytes(str(self.lifetime)) last = asbytes(str(self.last_ping)) if msg[1] == current: delta = time.time()-self.tic # self.log.debug("heartbeat::heart %r took %.2f ms to respond"%(msg[0], 1000*delta)) self.responses.add(msg[0]) elif msg[1] == last: delta = time.time()-self.tic + (self.lifetime-self.last_ping) self.log.warn("heartbeat::heart %r missed a beat, and took %.2f ms to respond", msg[0], 1000*delta) self.responses.add(msg[0]) else: self.log.warn("heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)", msg[1], self.lifetime)
class TaskScheduler(SessionFactory): """Python TaskScheduler object. This is the simplest object that supports msg_id based DAG dependencies. *Only* task msg_ids are checked, not msg_ids of jobs submitted via the MUX queue. """ hwm = Integer(1, config=True, help="""specify the High Water Mark (HWM) for the downstream socket in the Task scheduler. This is the maximum number of allowed outstanding tasks on each engine. The default (1) means that only one task can be outstanding on each engine. Setting TaskScheduler.hwm=0 means there is no limit, and the engines continue to be assigned tasks while they are working, effectively hiding network latency behind computation, but can result in an imbalance of work when submitting many heterogenous tasks all at once. Any positive value greater than one is a compromise between the two. """) scheme_name = Enum( ('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'), 'leastload', config=True, allow_none=False, help="""select the task scheduler scheme [default: Python LRU] Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'""" ) def _scheme_name_changed(self, old, new): self.log.debug("Using scheme %r" % new) self.scheme = globals()[new] # input arguments: scheme = Instance(FunctionType) # function for determining the destination def _scheme_default(self): return leastload client_stream = Instance(zmqstream.ZMQStream) # client-facing stream engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream # internals: graph = Dict() # dict by msg_id of [ msg_ids that depend on key ] retries = Dict() # dict by msg_id of retries remaining (non-neg ints) # waiting = List() # list of msg_ids ready to run, but haven't due to HWM depending = Dict() # dict by msg_id of Jobs pending = Dict() # dict by engine_uuid of submitted tasks completed = Dict() # dict by engine_uuid of completed tasks failed = Dict() # dict by engine_uuid of failed tasks destinations = Dict( ) # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed) clients = Dict() # dict by msg_id for who submitted the task targets = List() # list of target IDENTs loads = List() # list of engine loads # full = Set() # set of IDENTs that have HWM outstanding tasks all_completed = Set() # set of all completed tasks all_failed = Set() # set of all failed tasks all_done = Set() # set of all finished tasks=union(completed,failed) all_ids = Set() # set of all submitted task IDs auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback') ident = CBytes() # ZMQ identity. This should just be self.session.session # but ensure Bytes def _ident_default(self): return self.session.bsession def start(self): self.query_stream.on_recv(self.dispatch_query_reply) self.session.send(self.query_stream, "connection_request", {}) self.engine_stream.on_recv(self.dispatch_result, copy=False) self.client_stream.on_recv(self.dispatch_submission, copy=False) self._notification_handlers = dict( registration_notification=self._register_engine, unregistration_notification=self._unregister_engine) self.notifier_stream.on_recv(self.dispatch_notification) self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # 1 Hz self.auditor.start() self.log.info("Scheduler started [%s]" % self.scheme_name) def resume_receiving(self): """Resume accepting jobs.""" self.client_stream.on_recv(self.dispatch_submission, copy=False) def stop_receiving(self): """Stop accepting jobs while there are no engines. Leave them in the ZMQ queue.""" self.client_stream.on_recv(None) #----------------------------------------------------------------------- # [Un]Registration Handling #----------------------------------------------------------------------- def dispatch_query_reply(self, msg): """handle reply to our initial connection request""" try: idents, msg = self.session.feed_identities(msg) except ValueError: self.log.warn("task::Invalid Message: %r", msg) return try: msg = self.session.unserialize(msg) except ValueError: self.log.warn("task::Unauthorized message from: %r" % idents) return content = msg['content'] for uuid in content.get('engines', {}).values(): self._register_engine(cast_bytes(uuid)) @util.log_errors def dispatch_notification(self, msg): """dispatch register/unregister events.""" try: idents, msg = self.session.feed_identities(msg) except ValueError: self.log.warn("task::Invalid Message: %r", msg) return try: msg = self.session.unserialize(msg) except ValueError: self.log.warn("task::Unauthorized message from: %r" % idents) return msg_type = msg['header']['msg_type'] handler = self._notification_handlers.get(msg_type, None) if handler is None: self.log.error("Unhandled message type: %r" % msg_type) else: try: handler(cast_bytes(msg['content']['uuid'])) except Exception: self.log.error("task::Invalid notification msg: %r", msg, exc_info=True) def _register_engine(self, uid): """New engine with ident `uid` became available.""" # head of the line: self.targets.insert(0, uid) self.loads.insert(0, 0) # initialize sets self.completed[uid] = set() self.failed[uid] = set() self.pending[uid] = {} # rescan the graph: self.update_graph(None) def _unregister_engine(self, uid): """Existing engine with ident `uid` became unavailable.""" if len(self.targets) == 1: # this was our only engine pass # handle any potentially finished tasks: self.engine_stream.flush() # don't pop destinations, because they might be used later # map(self.destinations.pop, self.completed.pop(uid)) # map(self.destinations.pop, self.failed.pop(uid)) # prevent this engine from receiving work idx = self.targets.index(uid) self.targets.pop(idx) self.loads.pop(idx) # wait 5 seconds before cleaning up pending jobs, since the results might # still be incoming if self.pending[uid]: dc = ioloop.DelayedCallback( lambda: self.handle_stranded_tasks(uid), 5000, self.loop) dc.start() else: self.completed.pop(uid) self.failed.pop(uid) def handle_stranded_tasks(self, engine): """Deal with jobs resident in an engine that died.""" lost = self.pending[engine] for msg_id in lost.keys(): if msg_id not in self.pending[engine]: # prevent double-handling of messages continue raw_msg = lost[msg_id].raw_msg idents, msg = self.session.feed_identities(raw_msg, copy=False) parent = self.session.unpack(msg[1].bytes) idents = [engine, idents[0]] # build fake error reply try: raise error.EngineError( "Engine %r died while running task %r" % (engine, msg_id)) except: content = error.wrap_exception() # build fake metadata md = dict( status=u'error', engine=engine, date=datetime.now(), ) msg = self.session.msg('apply_reply', content, parent=parent, metadata=md) raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents)) # and dispatch it self.dispatch_result(raw_reply) # finally scrub completed/failed lists self.completed.pop(engine) self.failed.pop(engine) #----------------------------------------------------------------------- # Job Submission #----------------------------------------------------------------------- @util.log_errors def dispatch_submission(self, raw_msg): """Dispatch job submission to appropriate handlers.""" # ensure targets up to date: self.notifier_stream.flush() try: idents, msg = self.session.feed_identities(raw_msg, copy=False) msg = self.session.unserialize(msg, content=False, copy=False) except Exception: self.log.error("task::Invaid task msg: %r" % raw_msg, exc_info=True) return # send to monitor self.mon_stream.send_multipart([b'intask'] + raw_msg, copy=False) header = msg['header'] md = msg['metadata'] msg_id = header['msg_id'] self.all_ids.add(msg_id) # get targets as a set of bytes objects # from a list of unicode objects targets = md.get('targets', []) targets = map(cast_bytes, targets) targets = set(targets) retries = md.get('retries', 0) self.retries[msg_id] = retries # time dependencies after = md.get('after', None) if after: after = Dependency(after) if after.all: if after.success: after = Dependency( after.difference(self.all_completed), success=after.success, failure=after.failure, all=after.all, ) if after.failure: after = Dependency( after.difference(self.all_failed), success=after.success, failure=after.failure, all=after.all, ) if after.check(self.all_completed, self.all_failed): # recast as empty set, if `after` already met, # to prevent unnecessary set comparisons after = MET else: after = MET # location dependencies follow = Dependency(md.get('follow', [])) # turn timeouts into datetime objects: timeout = md.get('timeout', None) if timeout: # cast to float, because jsonlib returns floats as decimal.Decimal, # which timedelta does not accept timeout = datetime.now() + timedelta(0, float(timeout), 0) job = Job( msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg, header=header, targets=targets, after=after, follow=follow, timeout=timeout, metadata=md, ) # validate and reduce dependencies: for dep in after, follow: if not dep: # empty dependency continue # check valid: if msg_id in dep or dep.difference(self.all_ids): self.depending[msg_id] = job return self.fail_unreachable(msg_id, error.InvalidDependency) # check if unreachable: if dep.unreachable(self.all_completed, self.all_failed): self.depending[msg_id] = job return self.fail_unreachable(msg_id) if after.check(self.all_completed, self.all_failed): # time deps already met, try to run if not self.maybe_run(job): # can't run yet if msg_id not in self.all_failed: # could have failed as unreachable self.save_unmet(job) else: self.save_unmet(job) def audit_timeouts(self): """Audit all waiting tasks for expired timeouts.""" now = datetime.now() for msg_id in self.depending.keys(): # must recheck, in case one failure cascaded to another: if msg_id in self.depending: job = self.depending[msg_id] if job.timeout and job.timeout < now: self.fail_unreachable(msg_id, error.TaskTimeout) def fail_unreachable(self, msg_id, why=error.ImpossibleDependency): """a task has become unreachable, send a reply with an ImpossibleDependency error.""" if msg_id not in self.depending: self.log.error("msg %r already failed!", msg_id) return job = self.depending.pop(msg_id) for mid in job.dependents: if mid in self.graph: self.graph[mid].remove(msg_id) try: raise why() except: content = error.wrap_exception() self.all_done.add(msg_id) self.all_failed.add(msg_id) msg = self.session.send(self.client_stream, 'apply_reply', content, parent=job.header, ident=job.idents) self.session.send(self.mon_stream, msg, ident=[b'outtask'] + job.idents) self.update_graph(msg_id, success=False) def maybe_run(self, job): """check location dependencies, and run if they are met.""" msg_id = job.msg_id self.log.debug("Attempting to assign task %s", msg_id) if not self.targets: # no engines, definitely can't run return False if job.follow or job.targets or job.blacklist or self.hwm: # we need a can_run filter def can_run(idx): # check hwm if self.hwm and self.loads[idx] == self.hwm: return False target = self.targets[idx] # check blacklist if target in job.blacklist: return False # check targets if job.targets and target not in job.targets: return False # check follow return job.follow.check(self.completed[target], self.failed[target]) indices = filter(can_run, range(len(self.targets))) if not indices: # couldn't run if job.follow.all: # check follow for impossibility dests = set() relevant = set() if job.follow.success: relevant = self.all_completed if job.follow.failure: relevant = relevant.union(self.all_failed) for m in job.follow.intersection(relevant): dests.add(self.destinations[m]) if len(dests) > 1: self.depending[msg_id] = job self.fail_unreachable(msg_id) return False if job.targets: # check blacklist+targets for impossibility job.targets.difference_update(job.blacklist) if not job.targets or not job.targets.intersection( self.targets): self.depending[msg_id] = job self.fail_unreachable(msg_id) return False return False else: indices = None self.submit_task(job, indices) return True def save_unmet(self, job): """Save a message for later submission when its dependencies are met.""" msg_id = job.msg_id self.depending[msg_id] = job # track the ids in follow or after, but not those already finished for dep_id in job.after.union(job.follow).difference(self.all_done): if dep_id not in self.graph: self.graph[dep_id] = set() self.graph[dep_id].add(msg_id) def submit_task(self, job, indices=None): """Submit a task to any of a subset of our targets.""" if indices: loads = [self.loads[i] for i in indices] else: loads = self.loads idx = self.scheme(loads) if indices: idx = indices[idx] target = self.targets[idx] # print (target, map(str, msg[:3])) # send job to the engine self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False) self.engine_stream.send_multipart(job.raw_msg, copy=False) # update load self.add_job(idx) self.pending[target][job.msg_id] = job # notify Hub content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii')) self.session.send(self.mon_stream, 'task_destination', content=content, ident=[b'tracktask', self.ident]) #----------------------------------------------------------------------- # Result Handling #----------------------------------------------------------------------- @util.log_errors def dispatch_result(self, raw_msg): """dispatch method for result replies""" try: idents, msg = self.session.feed_identities(raw_msg, copy=False) msg = self.session.unserialize(msg, content=False, copy=False) engine = idents[0] try: idx = self.targets.index(engine) except ValueError: pass # skip load-update for dead engines else: self.finish_job(idx) except Exception: self.log.error("task::Invaid result: %r", raw_msg, exc_info=True) return md = msg['metadata'] parent = msg['parent_header'] if md.get('dependencies_met', True): success = (md['status'] == 'ok') msg_id = parent['msg_id'] retries = self.retries[msg_id] if not success and retries > 0: # failed self.retries[msg_id] = retries - 1 self.handle_unmet_dependency(idents, parent) else: del self.retries[msg_id] # relay to client and update graph self.handle_result(idents, parent, raw_msg, success) # send to Hub monitor self.mon_stream.send_multipart([b'outtask'] + raw_msg, copy=False) else: self.handle_unmet_dependency(idents, parent) def handle_result(self, idents, parent, raw_msg, success=True): """handle a real task result, either success or failure""" # first, relay result to client engine = idents[0] client = idents[1] # swap_ids for ROUTER-ROUTER mirror raw_msg[:2] = [client, engine] # print (map(str, raw_msg[:4])) self.client_stream.send_multipart(raw_msg, copy=False) # now, update our data structures msg_id = parent['msg_id'] self.pending[engine].pop(msg_id) if success: self.completed[engine].add(msg_id) self.all_completed.add(msg_id) else: self.failed[engine].add(msg_id) self.all_failed.add(msg_id) self.all_done.add(msg_id) self.destinations[msg_id] = engine self.update_graph(msg_id, success) def handle_unmet_dependency(self, idents, parent): """handle an unmet dependency""" engine = idents[0] msg_id = parent['msg_id'] job = self.pending[engine].pop(msg_id) job.blacklist.add(engine) if job.blacklist == job.targets: self.depending[msg_id] = job self.fail_unreachable(msg_id) elif not self.maybe_run(job): # resubmit failed if msg_id not in self.all_failed: # put it back in our dependency tree self.save_unmet(job) if self.hwm: try: idx = self.targets.index(engine) except ValueError: pass # skip load-update for dead engines else: if self.loads[idx] == self.hwm - 1: self.update_graph(None) def update_graph(self, dep_id=None, success=True): """dep_id just finished. Update our dependency graph and submit any jobs that just became runable. Called with dep_id=None to update entire graph for hwm, but without finishing a task. """ # print ("\n\n***********") # pprint (dep_id) # pprint (self.graph) # pprint (self.depending) # pprint (self.all_completed) # pprint (self.all_failed) # print ("\n\n***********\n\n") # update any jobs that depended on the dependency jobs = self.graph.pop(dep_id, []) # recheck *all* jobs if # a) we have HWM and an engine just become no longer full # or b) dep_id was given as None if dep_id is None or self.hwm and any( [load == self.hwm - 1 for load in self.loads]): jobs = self.depending.keys() for msg_id in sorted( jobs, key=lambda msg_id: self.depending[msg_id].timestamp): job = self.depending[msg_id] if job.after.unreachable(self.all_completed, self.all_failed)\ or job.follow.unreachable(self.all_completed, self.all_failed): self.fail_unreachable(msg_id) elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run if self.maybe_run(job): self.depending.pop(msg_id) for mid in job.dependents: if mid in self.graph: self.graph[mid].remove(msg_id) #---------------------------------------------------------------------- # methods to be overridden by subclasses #---------------------------------------------------------------------- def add_job(self, idx): """Called after self.targets[idx] just got the job with header. Override with subclasses. The default ordering is simple LRU. The default loads are the number of outstanding jobs.""" self.loads[idx] += 1 for lis in (self.targets, self.loads): lis.append(lis.pop(idx)) def finish_job(self, idx): """Called after self.targets[idx] just finished a job. Override with subclasses.""" self.loads[idx] -= 1
class Authenticator(LoggingConfigurable): """A class for authentication. The API is one method, `authenticate`, a tornado gen.coroutine. """ db = Any() whitelist = Set(config=True, help="""Username whitelist. Use this to restrict which users can login. If empty, allow any user to attempt login. """ ) custom_html = Unicode('', help="""HTML login form for custom handlers. Override in form-based custom authenticators that don't use username+password, or need custom branding. """ ) login_service = Unicode('', help="""Name of the login service for external login services (e.g. 'GitHub'). """ ) @gen.coroutine def authenticate(self, handler, data): """Authenticate a user with login form data. This must be a tornado gen.coroutine. It must return the username on successful authentication, and return None on failed authentication. """ def check_whitelist(self, user): """ Return True if the whitelist is empty or user is in the whitelist. """ # Parens aren't necessary here, but they make this easier to parse. return (not self.whitelist) or (user in self.whitelist) def add_user(self, user): """Add a new user By default, this just adds the user to the whitelist. Subclasses may do more extensive things, such as adding actual unix users. """ if self.whitelist: self.whitelist.add(user.name) def delete_user(self, user): """Triggered when a user is deleted. Removes the user from the whitelist. """ self.whitelist.discard(user.name) def login_url(self, base_url): """Override to register a custom login handler""" return url_path_join(base_url, 'login') def logout_url(self, base_url): """Override to register a custom logout handler""" return url_path_join(base_url, 'logout') def get_handlers(self, app): """Return any custom handlers the authenticator needs to register (e.g. for OAuth) """ return [ ('/login', LoginHandler), ]
class Session(Configurable): """Object for handling serialization and sending of messages. The Session object handles building messages and sending them with ZMQ sockets or ZMQStream objects. Objects can communicate with each other over the network via Session objects, and only need to work with the dict-based IPython message spec. The Session will handle serialization/deserialization, security, and metadata. Sessions support configurable serialiization via packer/unpacker traits, and signing with HMAC digests via the key/keyfile traits. Parameters ---------- debug : bool whether to trigger extra debugging statements packer/unpacker : str : 'json', 'pickle' or import_string importstrings for methods to serialize message parts. If just 'json' or 'pickle', predefined JSON and pickle packers will be used. Otherwise, the entire importstring must be used. The functions must accept at least valid JSON input, and output *bytes*. For example, to use msgpack: packer = 'msgpack.packb', unpacker='msgpack.unpackb' pack/unpack : callables You can also set the pack/unpack callables for serialization directly. session : bytes the ID of this Session object. The default is to generate a new UUID. username : unicode username added to message headers. The default is to ask the OS. key : bytes The key used to initialize an HMAC signature. If unset, messages will not be signed or checked. keyfile : filepath The file containing a key. If this is set, `key` will be initialized to the contents of the file. """ debug = Bool(False, config=True, help="""Debug output in the Session""") packer = DottedObjectName( 'json', config=True, help="""The name of the packer for serializing messages. Should be one of 'json', 'pickle', or an import name for a custom callable serializer.""") def _packer_changed(self, name, old, new): if new.lower() == 'json': self.pack = json_packer self.unpack = json_unpacker elif new.lower() == 'pickle': self.pack = pickle_packer self.unpack = pickle_unpacker else: self.pack = import_item(str(new)) unpacker = DottedObjectName( 'json', config=True, help="""The name of the unpacker for unserializing messages. Only used with custom functions for `packer`.""") def _unpacker_changed(self, name, old, new): if new.lower() == 'json': self.pack = json_packer self.unpack = json_unpacker elif new.lower() == 'pickle': self.pack = pickle_packer self.unpack = pickle_unpacker else: self.unpack = import_item(str(new)) session = CBytes(b'', config=True, help="""The UUID identifying this session.""") def _session_default(self): return bytes(uuid.uuid4()) username = Unicode( os.environ.get('USER', 'username'), config=True, help="""Username for the Session. Default is your system username.""") # message signature related traits: key = CBytes(b'', config=True, help="""execution key, for extra authentication.""") def _key_changed(self, name, old, new): if new: self.auth = hmac.HMAC(new) else: self.auth = None auth = Instance(hmac.HMAC) digest_history = Set() keyfile = Unicode('', config=True, help="""path to file containing execution key.""") def _keyfile_changed(self, name, old, new): with open(new, 'rb') as f: self.key = f.read().strip() pack = Any(default_packer) # the actual packer function def _pack_changed(self, name, old, new): if not callable(new): raise TypeError("packer must be callable, not %s" % type(new)) unpack = Any(default_unpacker) # the actual packer function def _unpack_changed(self, name, old, new): # unpacker is not checked - it is assumed to be if not callable(new): raise TypeError("unpacker must be callable, not %s" % type(new)) def __init__(self, **kwargs): """create a Session object Parameters ---------- debug : bool whether to trigger extra debugging statements packer/unpacker : str : 'json', 'pickle' or import_string importstrings for methods to serialize message parts. If just 'json' or 'pickle', predefined JSON and pickle packers will be used. Otherwise, the entire importstring must be used. The functions must accept at least valid JSON input, and output *bytes*. For example, to use msgpack: packer = 'msgpack.packb', unpacker='msgpack.unpackb' pack/unpack : callables You can also set the pack/unpack callables for serialization directly. session : bytes the ID of this Session object. The default is to generate a new UUID. username : unicode username added to message headers. The default is to ask the OS. key : bytes The key used to initialize an HMAC signature. If unset, messages will not be signed or checked. keyfile : filepath The file containing a key. If this is set, `key` will be initialized to the contents of the file. """ super(Session, self).__init__(**kwargs) self._check_packers() self.none = self.pack({}) @property def msg_id(self): """always return new uuid""" return str(uuid.uuid4()) def _check_packers(self): """check packers for binary data and datetime support.""" pack = self.pack unpack = self.unpack # check simple serialization msg = dict(a=[1, 'hi']) try: packed = pack(msg) except Exception: raise ValueError("packer could not serialize a simple message") # ensure packed message is bytes if not isinstance(packed, bytes): raise ValueError("message packed to %r, but bytes are required" % type(packed)) # check that unpack is pack's inverse try: unpacked = unpack(packed) except Exception: raise ValueError("unpacker could not handle the packer's output") # check datetime support msg = dict(t=datetime.now()) try: unpacked = unpack(pack(msg)) except Exception: self.pack = lambda o: pack(squash_dates(o)) self.unpack = lambda s: extract_dates(unpack(s)) def msg_header(self, msg_type): return msg_header(self.msg_id, msg_type, self.username, self.session) def msg(self, msg_type, content=None, parent=None, subheader=None): """Return the nested message dict. This format is different from what is sent over the wire. The self.serialize method converts this nested message dict to the wire format, which uses a message list. """ msg = {} msg['header'] = self.msg_header(msg_type) msg['msg_id'] = msg['header']['msg_id'] msg['parent_header'] = {} if parent is None else extract_header(parent) msg['msg_type'] = msg_type msg['content'] = {} if content is None else content sub = {} if subheader is None else subheader msg['header'].update(sub) return msg def sign(self, msg_list): """Sign a message with HMAC digest. If no auth, return b''. Parameters ---------- msg_list : list The [p_header,p_parent,p_content] part of the message list. """ if self.auth is None: return b'' h = self.auth.copy() for m in msg_list: h.update(m) return h.hexdigest() def serialize(self, msg, ident=None): """Serialize the message components to bytes. Parameters ---------- msg : dict or Message The nexted message dict as returned by the self.msg method. Returns ------- msg_list : list The list of bytes objects to be sent with the format: [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content, buffer1,buffer2,...]. In this list, the p_* entities are the packed or serialized versions, so if JSON is used, these are uft8 encoded JSON strings. """ content = msg.get('content', {}) if content is None: content = self.none elif isinstance(content, dict): content = self.pack(content) elif isinstance(content, bytes): # content is already packed, as in a relayed message pass elif isinstance(content, unicode): # should be bytes, but JSON often spits out unicode content = content.encode('utf8') else: raise TypeError("Content incorrect type: %s" % type(content)) real_message = [ self.pack(msg['header']), self.pack(msg['parent_header']), content ] to_send = [] if isinstance(ident, list): # accept list of idents to_send.extend(ident) elif ident is not None: to_send.append(ident) to_send.append(DELIM) signature = self.sign(real_message) to_send.append(signature) to_send.extend(real_message) return to_send def send(self, stream, msg_or_type, content=None, parent=None, ident=None, buffers=None, subheader=None, track=False): """Build and send a message via stream or socket. The message format used by this function internally is as follows: [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content, buffer1,buffer2,...] The self.serialize method converts the nested message dict into this format. Parameters ---------- stream : zmq.Socket or ZMQStream the socket-like object used to send the data msg_or_type : str or Message/dict Normally, msg_or_type will be a msg_type unless a message is being sent more than once. content : dict or None the content of the message (ignored if msg_or_type is a message) parent : Message or dict or None the parent or parent header describing the parent of this message ident : bytes or list of bytes the zmq.IDENTITY routing path subheader : dict or None extra header keys for this message's header buffers : list or None the already-serialized buffers to be appended to the message track : bool whether to track. Only for use with Sockets, because ZMQStream objects cannot track messages. Returns ------- msg : message dict the constructed message (msg,tracker) : (message dict, MessageTracker) if track=True, then a 2-tuple will be returned, the first element being the constructed message, and the second being the MessageTracker """ if not isinstance(stream, (zmq.Socket, ZMQStream)): raise TypeError("stream must be Socket or ZMQStream, not %r" % type(stream)) elif track and isinstance(stream, ZMQStream): raise TypeError("ZMQStream cannot track messages") if isinstance(msg_or_type, (Message, dict)): # we got a Message, not a msg_type # don't build a new Message msg = msg_or_type else: msg = self.msg(msg_or_type, content, parent, subheader) buffers = [] if buffers is None else buffers to_send = self.serialize(msg, ident) flag = 0 if buffers: flag = zmq.SNDMORE _track = False else: _track = track if track: tracker = stream.send_multipart(to_send, flag, copy=False, track=_track) else: tracker = stream.send_multipart(to_send, flag, copy=False) for b in buffers[:-1]: stream.send(b, flag, copy=False) if buffers: if track: tracker = stream.send(buffers[-1], copy=False, track=track) else: tracker = stream.send(buffers[-1], copy=False) # omsg = Message(msg) if self.debug: pprint.pprint(msg) pprint.pprint(to_send) pprint.pprint(buffers) msg['tracker'] = tracker return msg def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None): """Send a raw message via ident path. This method is used to send a already serialized message. Parameters ---------- stream : ZMQStream or Socket The ZMQ stream or socket to use for sending the message. msg_list : list The serialized list of messages to send. This only includes the [p_header,p_parent,p_content,buffer1,buffer2,...] portion of the message. ident : ident or list A single ident or a list of idents to use in sending. """ to_send = [] if isinstance(ident, bytes): ident = [ident] if ident is not None: to_send.extend(ident) to_send.append(DELIM) to_send.append(self.sign(msg_list)) to_send.extend(msg_list) stream.send_multipart(msg_list, flags, copy=copy) def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True): """Receive and unpack a message. Parameters ---------- socket : ZMQStream or Socket The socket or stream to use in receiving. Returns ------- [idents], msg [idents] is a list of idents and msg is a nested message dict of same format as self.msg returns. """ if isinstance(socket, ZMQStream): socket = socket.socket try: msg_list = socket.recv_multipart(mode) except zmq.ZMQError as e: if e.errno == zmq.EAGAIN: # We can convert EAGAIN to None as we know in this case # recv_multipart won't return None. return None, None else: raise # split multipart message into identity list and message dict # invalid large messages can cause very expensive string comparisons idents, msg_list = self.feed_identities(msg_list, copy) try: return idents, self.unpack_message(msg_list, content=content, copy=copy) except Exception as e: print(idents, msg_list) # TODO: handle it raise e def feed_identities(self, msg_list, copy=True): """Split the identities from the rest of the message. Feed until DELIM is reached, then return the prefix as idents and remainder as msg_list. This is easily broken by setting an IDENT to DELIM, but that would be silly. Parameters ---------- msg_list : a list of Message or bytes objects The message to be split. copy : bool flag determining whether the arguments are bytes or Messages Returns ------- (idents,msg_list) : two lists idents will always be a list of bytes - the indentity prefix msg_list will be a list of bytes or Messages, unchanged from input msg_list should be unpackable via self.unpack_message at this point. """ if copy: idx = msg_list.index(DELIM) return msg_list[:idx], msg_list[idx + 1:] else: failed = True for idx, m in enumerate(msg_list): if m.bytes == DELIM: failed = False break if failed: raise ValueError("DELIM not in msg_list") idents, msg_list = msg_list[:idx], msg_list[idx + 1:] return [m.bytes for m in idents], msg_list def unpack_message(self, msg_list, content=True, copy=True): """Return a message object from the format sent by self.send. Parameters: ----------- content : bool (True) whether to unpack the content dict (True), or leave it serialized (False) copy : bool (True) whether to return the bytes (True), or the non-copying Message object in each place (False) """ minlen = 4 message = {} if not copy: for i in range(minlen): msg_list[i] = msg_list[i].bytes if self.auth is not None: signature = msg_list[0] if signature in self.digest_history: raise ValueError("Duplicate Signature: %r" % signature) self.digest_history.add(signature) check = self.sign(msg_list[1:4]) if not signature == check: raise ValueError("Invalid Signature: %r" % signature) if not len(msg_list) >= minlen: raise TypeError( "malformed message, must have at least %i elements" % minlen) message['header'] = self.unpack(msg_list[1]) message['msg_type'] = message['header']['msg_type'] message['parent_header'] = self.unpack(msg_list[2]) if content: message['content'] = self.unpack(msg_list[3]) else: message['content'] = msg_list[3] message['buffers'] = msg_list[4:] return message
class View(HasTraits): """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes. Don't use this class, use subclasses. Methods ------- spin flushes incoming results and registration state changes control methods spin, and requesting `ids` also ensures up to date wait wait on one or more msg_ids execution methods apply legacy: execute, run data movement push, pull, scatter, gather query methods get_result, queue_status, purge_results, result_status control methods abort, shutdown """ # flags block = Bool(False) track = Bool(True) targets = Any() history = List() outstanding = Set() results = Dict() client = Instance('IPython.parallel.Client') _socket = Instance('zmq.Socket') _flag_names = List(['targets', 'block', 'track']) _targets = Any() _idents = Any() def __init__(self, client=None, socket=None, **flags): super(View, self).__init__(client=client, _socket=socket) self.results = client.results self.block = client.block self.set_flags(**flags) assert not self.__class__ is View, "Don't use base View objects, use subclasses" def __repr__(self): strtargets = str(self.targets) if len(strtargets) > 16: strtargets = strtargets[:12] + '...]' return "<%s %s>" % (self.__class__.__name__, strtargets) def __len__(self): if isinstance(self.targets, list): return len(self.targets) elif isinstance(self.targets, int): return 1 else: return len(self.client) def set_flags(self, **kwargs): """set my attribute flags by keyword. Views determine behavior with a few attributes (`block`, `track`, etc.). These attributes can be set all at once by name with this method. Parameters ---------- block : bool whether to wait for results track : bool whether to create a MessageTracker to allow the user to safely edit after arrays and buffers during non-copying sends. """ for name, value in kwargs.iteritems(): if name not in self._flag_names: raise KeyError("Invalid name: %r" % name) else: setattr(self, name, value) @contextmanager def temp_flags(self, **kwargs): """temporarily set flags, for use in `with` statements. See set_flags for permanent setting of flags Examples -------- >>> view.track=False ... >>> with view.temp_flags(track=True): ... ar = view.apply(dostuff, my_big_array) ... ar.tracker.wait() # wait for send to finish >>> view.track False """ # preflight: save flags, and set temporaries saved_flags = {} for f in self._flag_names: saved_flags[f] = getattr(self, f) self.set_flags(**kwargs) # yield to the with-statement block try: yield finally: # postflight: restore saved flags self.set_flags(**saved_flags) #---------------------------------------------------------------- # apply #---------------------------------------------------------------- @sync_results @save_ids def _really_apply(self, f, args, kwargs, block=None, **options): """wrapper for client.send_apply_request""" raise NotImplementedError("Implement in subclasses") def apply(self, f, *args, **kwargs): """calls f(*args, **kwargs) on remote engines, returning the result. This method sets all apply flags via this View's attributes. if self.block is False: returns AsyncResult else: returns actual result of f(*args, **kwargs) """ return self._really_apply(f, args, kwargs) def apply_async(self, f, *args, **kwargs): """calls f(*args, **kwargs) on remote engines in a nonblocking manner. returns AsyncResult """ return self._really_apply(f, args, kwargs, block=False) @spin_after def apply_sync(self, f, *args, **kwargs): """calls f(*args, **kwargs) on remote engines in a blocking manner, returning the result. returns: actual result of f(*args, **kwargs) """ return self._really_apply(f, args, kwargs, block=True) #---------------------------------------------------------------- # wrappers for client and control methods #---------------------------------------------------------------- @sync_results def spin(self): """spin the client, and sync""" self.client.spin() @sync_results def wait(self, jobs=None, timeout=-1): """waits on one or more `jobs`, for up to `timeout` seconds. Parameters ---------- jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects ints are indices to self.history strs are msg_ids default: wait on all outstanding messages timeout : float a time in seconds, after which to give up. default is -1, which means no timeout Returns ------- True : when all msg_ids are done False : timeout reached, some msg_ids still outstanding """ if jobs is None: jobs = self.history return self.client.wait(jobs, timeout) def abort(self, jobs=None, targets=None, block=None): """Abort jobs on my engines. Parameters ---------- jobs : None, str, list of strs, optional if None: abort all jobs. else: abort specific msg_id(s). """ block = block if block is not None else self.block targets = targets if targets is not None else self.targets jobs = jobs if jobs is not None else list(self.outstanding) return self.client.abort(jobs=jobs, targets=targets, block=block) def queue_status(self, targets=None, verbose=False): """Fetch the Queue status of my engines""" targets = targets if targets is not None else self.targets return self.client.queue_status(targets=targets, verbose=verbose) def purge_results(self, jobs=[], targets=[]): """Instruct the controller to forget specific results.""" if targets is None or targets == 'all': targets = self.targets return self.client.purge_results(jobs=jobs, targets=targets) def shutdown(self, targets=None, restart=False, hub=False, block=None): """Terminates one or more engine processes, optionally including the hub. """ block = self.block if block is None else block if targets is None or targets == 'all': targets = self.targets return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block) @spin_after def get_result(self, indices_or_msg_ids=None): """return one or more results, specified by history index or msg_id. See client.get_result for details. """ if indices_or_msg_ids is None: indices_or_msg_ids = -1 if isinstance(indices_or_msg_ids, int): indices_or_msg_ids = self.history[indices_or_msg_ids] elif isinstance(indices_or_msg_ids, (list, tuple, set)): indices_or_msg_ids = list(indices_or_msg_ids) for i, index in enumerate(indices_or_msg_ids): if isinstance(index, int): indices_or_msg_ids[i] = self.history[index] return self.client.get_result(indices_or_msg_ids) #------------------------------------------------------------------- # Map #------------------------------------------------------------------- def map(self, f, *sequences, **kwargs): """override in subclasses""" raise NotImplementedError def map_async(self, f, *sequences, **kwargs): """Parallel version of builtin `map`, using this view's engines. This is equivalent to map(...block=False) See `self.map` for details. """ if 'block' in kwargs: raise TypeError( "map_async doesn't take a `block` keyword argument.") kwargs['block'] = False return self.map(f, *sequences, **kwargs) def map_sync(self, f, *sequences, **kwargs): """Parallel version of builtin `map`, using this view's engines. This is equivalent to map(...block=True) See `self.map` for details. """ if 'block' in kwargs: raise TypeError( "map_sync doesn't take a `block` keyword argument.") kwargs['block'] = True return self.map(f, *sequences, **kwargs) def imap(self, f, *sequences, **kwargs): """Parallel version of `itertools.imap`. See `self.map` for details. """ return iter(self.map_async(f, *sequences, **kwargs)) #------------------------------------------------------------------- # Decorators #------------------------------------------------------------------- def remote(self, block=True, **flags): """Decorator for making a RemoteFunction""" block = self.block if block is None else block return remote(self, block=block, **flags) def parallel(self, dist='b', block=None, **flags): """Decorator for making a ParallelFunction""" block = self.block if block is None else block return parallel(self, dist=dist, block=block, **flags)
class Widget(LoggingConfigurable): #------------------------------------------------------------------------- # Class attributes #------------------------------------------------------------------------- _widget_construction_callback = None widgets = {} widget_types = {} @staticmethod def on_widget_constructed(callback): """Registers a callback to be called when a widget is constructed. The callback must have the following signature: callback(widget)""" Widget._widget_construction_callback = callback @staticmethod def _call_widget_constructed(widget): """Static method, called when a widget is constructed.""" if Widget._widget_construction_callback is not None and callable(Widget._widget_construction_callback): Widget._widget_construction_callback(widget) @staticmethod def handle_comm_opened(comm, msg): """Static method, called when a widget is constructed.""" widget_class = import_item(msg['content']['data']['widget_class']) widget = widget_class(comm=comm) #------------------------------------------------------------------------- # Traits #------------------------------------------------------------------------- _model_module = Unicode(None, allow_none=True, help="""A requirejs module name in which to find _model_name. If empty, look in the global registry.""") _model_name = Unicode('WidgetModel', help="""Name of the backbone model registered in the front-end to create and sync this widget with.""") _view_module = Unicode(help="""A requirejs module in which to find _view_name. If empty, look in the global registry.""", sync=True) _view_name = Unicode(None, allow_none=True, help="""Default view registered in the front-end to use to represent the widget.""", sync=True) comm = Instance('IPython.kernel.comm.Comm') msg_throttle = Int(3, sync=True, help="""Maximum number of msgs the front-end can send before receiving an idle msg from the back-end.""") version = Int(0, sync=True, help="""Widget's version""") keys = List() def _keys_default(self): return [name for name in self.traits(sync=True)] _property_lock = Tuple((None, None)) _send_state_lock = Int(0) _states_to_send = Set(allow_none=False) _display_callbacks = Instance(CallbackDispatcher, ()) _msg_callbacks = Instance(CallbackDispatcher, ()) #------------------------------------------------------------------------- # (Con/de)structor #------------------------------------------------------------------------- def __init__(self, **kwargs): """Public constructor""" self._model_id = kwargs.pop('model_id', None) super(Widget, self).__init__(**kwargs) Widget._call_widget_constructed(self) self.open() def __del__(self): """Object disposal""" self.close() #------------------------------------------------------------------------- # Properties #------------------------------------------------------------------------- def open(self): """Open a comm to the frontend if one isn't already open.""" if self.comm is None: args = dict(target_name='ipython.widget', data={'model_name': self._model_name, 'model_module': self._model_module}) if self._model_id is not None: args['comm_id'] = self._model_id self.comm = Comm(**args) def _comm_changed(self, name, new): """Called when the comm is changed.""" if new is None: return self._model_id = self.model_id self.comm.on_msg(self._handle_msg) Widget.widgets[self.model_id] = self # first update self.send_state() @property def model_id(self): """Gets the model id of this widget. If a Comm doesn't exist yet, a Comm will be created automagically.""" return self.comm.comm_id #------------------------------------------------------------------------- # Methods #------------------------------------------------------------------------- def close(self): """Close method. Closes the underlying comm. When the comm is closed, all of the widget views are automatically removed from the front-end.""" if self.comm is not None: Widget.widgets.pop(self.model_id, None) self.comm.close() self.comm = None def send_state(self, key=None): """Sends the widget state, or a piece of it, to the front-end. Parameters ---------- key : unicode, or iterable (optional) A single property's name or iterable of property names to sync with the front-end. """ self._send({ "method" : "update", "state" : self.get_state(key=key) }) def get_state(self, key=None): """Gets the widget state, or a piece of it. Parameters ---------- key : unicode or iterable (optional) A single property's name or iterable of property names to get. """ if key is None: keys = self.keys elif isinstance(key, string_types): keys = [key] elif isinstance(key, collections.Iterable): keys = key else: raise ValueError("key must be a string, an iterable of keys, or None") state = {} for k in keys: f = self.trait_metadata(k, 'to_json', self._trait_to_json) value = getattr(self, k) state[k] = f(value) return state def set_state(self, sync_data): """Called when a state is received from the front-end.""" for name in self.keys: if name in sync_data: json_value = sync_data[name] from_json = self.trait_metadata(name, 'from_json', self._trait_from_json) with self._lock_property(name, json_value): setattr(self, name, from_json(json_value)) def send(self, content): """Sends a custom msg to the widget model in the front-end. Parameters ---------- content : dict Content of the message to send. """ self._send({"method": "custom", "content": content}) def on_msg(self, callback, remove=False): """(Un)Register a custom msg receive callback. Parameters ---------- callback: callable callback will be passed two arguments when a message arrives:: callback(widget, content) remove: bool True if the callback should be unregistered.""" self._msg_callbacks.register_callback(callback, remove=remove) def on_displayed(self, callback, remove=False): """(Un)Register a widget displayed callback. Parameters ---------- callback: method handler Must have a signature of:: callback(widget, **kwargs) kwargs from display are passed through without modification. remove: bool True if the callback should be unregistered.""" self._display_callbacks.register_callback(callback, remove=remove) #------------------------------------------------------------------------- # Support methods #------------------------------------------------------------------------- @contextmanager def _lock_property(self, key, value): """Lock a property-value pair. The value should be the JSON state of the property. NOTE: This, in addition to the single lock for all state changes, is flawed. In the future we may want to look into buffering state changes back to the front-end.""" self._property_lock = (key, value) try: yield finally: self._property_lock = (None, None) @contextmanager def hold_sync(self): """Hold syncing any state until the context manager is released""" # We increment a value so that this can be nested. Syncing will happen when # all levels have been released. self._send_state_lock += 1 try: yield finally: self._send_state_lock -=1 if self._send_state_lock == 0: self.send_state(self._states_to_send) self._states_to_send.clear() def _should_send_property(self, key, value): """Check the property lock (property_lock)""" to_json = self.trait_metadata(key, 'to_json', self._trait_to_json) if (key == self._property_lock[0] and to_json(value) == self._property_lock[1]): return False elif self._send_state_lock > 0: self._states_to_send.add(key) return False else: return True # Event handlers @_show_traceback def _handle_msg(self, msg): """Called when a msg is received from the front-end""" data = msg['content']['data'] method = data['method'] # Handle backbone sync methods CREATE, PATCH, and UPDATE all in one. if method == 'backbone': if 'sync_data' in data: sync_data = data['sync_data'] self.set_state(sync_data) # handles all methods # Handle a state request. elif method == 'request_state': self.send_state() # Handle a custom msg from the front-end. elif method == 'custom': if 'content' in data: self._handle_custom_msg(data['content']) # Catch remainder. else: self.log.error('Unknown front-end to back-end widget msg with method "%s"' % method) def _handle_custom_msg(self, content): """Called when a custom msg is received.""" self._msg_callbacks(self, content) def _notify_trait(self, name, old_value, new_value): """Called when a property has been changed.""" # Trigger default traitlet callback machinery. This allows any user # registered validation to be processed prior to allowing the widget # machinery to handle the state. LoggingConfigurable._notify_trait(self, name, old_value, new_value) # Send the state after the user registered callbacks for trait changes # have all fired (allows for user to validate values). if self.comm is not None and name in self.keys: # Make sure this isn't information that the front-end just sent us. if self._should_send_property(name, new_value): # Send new state to front-end self.send_state(key=name) def _handle_displayed(self, **kwargs): """Called when a view has been displayed for this widget instance""" self._display_callbacks(self, **kwargs) def _trait_to_json(self, x): """Convert a trait value to json Traverse lists/tuples and dicts and serialize their values as well. Replace any widgets with their model_id """ if isinstance(x, dict): return {k: self._trait_to_json(v) for k, v in x.items()} elif isinstance(x, (list, tuple)): return [self._trait_to_json(v) for v in x] elif isinstance(x, Widget): return "IPY_MODEL_" + x.model_id else: return x # Value must be JSON-able def _trait_from_json(self, x): """Convert json values to objects Replace any strings representing valid model id values to Widget references. """ if isinstance(x, dict): return {k: self._trait_from_json(v) for k, v in x.items()} elif isinstance(x, (list, tuple)): return [self._trait_from_json(v) for v in x] elif isinstance(x, string_types) and x.startswith('IPY_MODEL_') and x[10:] in Widget.widgets: # we want to support having child widgets at any level in a hierarchy # trusting that a widget UUID will not appear out in the wild return Widget.widgets[x[10:]] else: return x def _ipython_display_(self, **kwargs): """Called when `IPython.display.display` is called on the widget.""" # Show view. if self._view_name is not None: self._send({"method": "display"}) self._handle_displayed(**kwargs) def _send(self, msg): """Sends a message to the model in the front-end.""" self.comm.send(msg)
class TaskScheduler(SessionFactory): """Python TaskScheduler object. This is the simplest object that supports msg_id based DAG dependencies. *Only* task msg_ids are checked, not msg_ids of jobs submitted via the MUX queue. """ # input arguments: scheme = Instance( FunctionType, default=leastload) # function for determining the destination client_stream = Instance(zmqstream.ZMQStream) # client-facing stream engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream # internals: graph = Dict() # dict by msg_id of [ msg_ids that depend on key ] depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow) pending = Dict() # dict by engine_uuid of submitted tasks completed = Dict() # dict by engine_uuid of completed tasks failed = Dict() # dict by engine_uuid of failed tasks destinations = Dict( ) # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed) clients = Dict() # dict by msg_id for who submitted the task targets = List() # list of target IDENTs loads = List() # list of engine loads all_completed = Set() # set of all completed tasks all_failed = Set() # set of all failed tasks all_done = Set() # set of all finished tasks=union(completed,failed) all_ids = Set() # set of all submitted task IDs blacklist = Dict( ) # dict by msg_id of locations where a job has encountered UnmetDependency auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback') def start(self): self.engine_stream.on_recv(self.dispatch_result, copy=False) self._notification_handlers = dict( registration_notification=self._register_engine, unregistration_notification=self._unregister_engine) self.notifier_stream.on_recv(self.dispatch_notification) self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # 1 Hz self.auditor.start() self.log.info("Scheduler started...%r" % self) def resume_receiving(self): """Resume accepting jobs.""" self.client_stream.on_recv(self.dispatch_submission, copy=False) def stop_receiving(self): """Stop accepting jobs while there are no engines. Leave them in the ZMQ queue.""" self.client_stream.on_recv(None) #----------------------------------------------------------------------- # [Un]Registration Handling #----------------------------------------------------------------------- def dispatch_notification(self, msg): """dispatch register/unregister events.""" idents, msg = self.session.feed_identities(msg) msg = self.session.unpack_message(msg) msg_type = msg['msg_type'] handler = self._notification_handlers.get(msg_type, None) if handler is None: raise Exception("Unhandled message type: %s" % msg_type) else: try: handler(str(msg['content']['queue'])) except KeyError: self.log.error("task::Invalid notification msg: %s" % msg) @logged def _register_engine(self, uid): """New engine with ident `uid` became available.""" # head of the line: self.targets.insert(0, uid) self.loads.insert(0, 0) # initialize sets self.completed[uid] = set() self.failed[uid] = set() self.pending[uid] = {} if len(self.targets) == 1: self.resume_receiving() def _unregister_engine(self, uid): """Existing engine with ident `uid` became unavailable.""" if len(self.targets) == 1: # this was our only engine self.stop_receiving() # handle any potentially finished tasks: self.engine_stream.flush() self.completed.pop(uid) self.failed.pop(uid) # don't pop destinations, because it might be used later # map(self.destinations.pop, self.completed.pop(uid)) # map(self.destinations.pop, self.failed.pop(uid)) idx = self.targets.index(uid) self.targets.pop(idx) self.loads.pop(idx) # wait 5 seconds before cleaning up pending jobs, since the results might # still be incoming if self.pending[uid]: dc = ioloop.DelayedCallback( lambda: self.handle_stranded_tasks(uid), 5000, self.loop) dc.start() @logged def handle_stranded_tasks(self, engine): """Deal with jobs resident in an engine that died.""" lost = self.pending.pop(engine) for msg_id, (raw_msg, targets, MET, follow, timeout) in lost.iteritems(): self.all_failed.add(msg_id) self.all_done.add(msg_id) idents, msg = self.session.feed_identities(raw_msg, copy=False) msg = self.session.unpack_message(msg, copy=False, content=False) parent = msg['header'] idents = [idents[0], engine] + idents[1:] # print (idents) try: raise error.EngineError( "Engine %r died while running task %r" % (engine, msg_id)) except: content = error.wrap_exception() msg = self.session.send(self.client_stream, 'apply_reply', content, parent=parent, ident=idents) self.session.send(self.mon_stream, msg, ident=['outtask'] + idents) self.update_graph(msg_id) #----------------------------------------------------------------------- # Job Submission #----------------------------------------------------------------------- @logged def dispatch_submission(self, raw_msg): """Dispatch job submission to appropriate handlers.""" # ensure targets up to date: self.notifier_stream.flush() try: idents, msg = self.session.feed_identities(raw_msg, copy=False) msg = self.session.unpack_message(msg, content=False, copy=False) except: self.log.error("task::Invaid task: %s" % raw_msg, exc_info=True) return # send to monitor self.mon_stream.send_multipart(['intask'] + raw_msg, copy=False) header = msg['header'] msg_id = header['msg_id'] self.all_ids.add(msg_id) # targets targets = set(header.get('targets', [])) # time dependencies after = Dependency(header.get('after', [])) if after.all: if after.success: after.difference_update(self.all_completed) if after.failure: after.difference_update(self.all_failed) if after.check(self.all_completed, self.all_failed): # recast as empty set, if `after` already met, # to prevent unnecessary set comparisons after = MET # location dependencies follow = Dependency(header.get('follow', [])) # turn timeouts into datetime objects: timeout = header.get('timeout', None) if timeout: timeout = datetime.now() + timedelta(0, timeout, 0) args = [raw_msg, targets, after, follow, timeout] # validate and reduce dependencies: for dep in after, follow: # check valid: if msg_id in dep or dep.difference(self.all_ids): self.depending[msg_id] = args return self.fail_unreachable(msg_id, error.InvalidDependency) # check if unreachable: if dep.unreachable(self.all_completed, self.all_failed): self.depending[msg_id] = args return self.fail_unreachable(msg_id) if after.check(self.all_completed, self.all_failed): # time deps already met, try to run if not self.maybe_run(msg_id, *args): # can't run yet self.save_unmet(msg_id, *args) else: self.save_unmet(msg_id, *args) # @logged def audit_timeouts(self): """Audit all waiting tasks for expired timeouts.""" now = datetime.now() for msg_id in self.depending.keys(): # must recheck, in case one failure cascaded to another: if msg_id in self.depending: raw, after, targets, follow, timeout = self.depending[msg_id] if timeout and timeout < now: self.fail_unreachable(msg_id, timeout=True) @logged def fail_unreachable(self, msg_id, why=error.ImpossibleDependency): """a task has become unreachable, send a reply with an ImpossibleDependency error.""" if msg_id not in self.depending: self.log.error("msg %r already failed!" % msg_id) return raw_msg, targets, after, follow, timeout = self.depending.pop(msg_id) for mid in follow.union(after): if mid in self.graph: self.graph[mid].remove(msg_id) # FIXME: unpacking a message I've already unpacked, but didn't save: idents, msg = self.session.feed_identities(raw_msg, copy=False) msg = self.session.unpack_message(msg, copy=False, content=False) header = msg['header'] try: raise why() except: content = error.wrap_exception() self.all_done.add(msg_id) self.all_failed.add(msg_id) msg = self.session.send(self.client_stream, 'apply_reply', content, parent=header, ident=idents) self.session.send(self.mon_stream, msg, ident=['outtask'] + idents) self.update_graph(msg_id, success=False) @logged def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout): """check location dependencies, and run if they are met.""" blacklist = self.blacklist.setdefault(msg_id, set()) if follow or targets or blacklist: # we need a can_run filter def can_run(idx): target = self.targets[idx] # check targets if targets and target not in targets: return False # check blacklist if target in blacklist: return False # check follow return follow.check(self.completed[target], self.failed[target]) indices = filter(can_run, range(len(self.targets))) if not indices: # couldn't run if follow.all: # check follow for impossibility dests = set() relevant = set() if follow.success: relevant = self.all_completed if follow.failure: relevant = relevant.union(self.all_failed) for m in follow.intersection(relevant): dests.add(self.destinations[m]) if len(dests) > 1: self.fail_unreachable(msg_id) return False if targets: # check blacklist+targets for impossibility targets.difference_update(blacklist) if not targets or not targets.intersection(self.targets): self.fail_unreachable(msg_id) return False return False else: indices = None self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices) return True @logged def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout): """Save a message for later submission when its dependencies are met.""" self.depending[msg_id] = [raw_msg, targets, after, follow, timeout] # track the ids in follow or after, but not those already finished for dep_id in after.union(follow).difference(self.all_done): if dep_id not in self.graph: self.graph[dep_id] = set() self.graph[dep_id].add(msg_id) @logged def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None): """Submit a task to any of a subset of our targets.""" if indices: loads = [self.loads[i] for i in indices] else: loads = self.loads idx = self.scheme(loads) if indices: idx = indices[idx] target = self.targets[idx] # print (target, map(str, msg[:3])) self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False) self.engine_stream.send_multipart(raw_msg, copy=False) self.add_job(idx) self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout) content = dict(msg_id=msg_id, engine_id=target) self.session.send(self.mon_stream, 'task_destination', content=content, ident=['tracktask', self.session.session]) #----------------------------------------------------------------------- # Result Handling #----------------------------------------------------------------------- @logged def dispatch_result(self, raw_msg): """dispatch method for result replies""" try: idents, msg = self.session.feed_identities(raw_msg, copy=False) msg = self.session.unpack_message(msg, content=False, copy=False) except: self.log.error("task::Invaid result: %s" % raw_msg, exc_info=True) return header = msg['header'] if header.get('dependencies_met', True): success = (header['status'] == 'ok') self.handle_result(idents, msg['parent_header'], raw_msg, success) # send to Hub monitor self.mon_stream.send_multipart(['outtask'] + raw_msg, copy=False) else: self.handle_unmet_dependency(idents, msg['parent_header']) @logged def handle_result(self, idents, parent, raw_msg, success=True): """handle a real task result, either success or failure""" # first, relay result to client engine = idents[0] client = idents[1] # swap_ids for XREP-XREP mirror raw_msg[:2] = [client, engine] # print (map(str, raw_msg[:4])) self.client_stream.send_multipart(raw_msg, copy=False) # now, update our data structures msg_id = parent['msg_id'] self.blacklist.pop(msg_id, None) self.pending[engine].pop(msg_id) if success: self.completed[engine].add(msg_id) self.all_completed.add(msg_id) else: self.failed[engine].add(msg_id) self.all_failed.add(msg_id) self.all_done.add(msg_id) self.destinations[msg_id] = engine self.update_graph(msg_id, success) @logged def handle_unmet_dependency(self, idents, parent): """handle an unmet dependency""" engine = idents[0] msg_id = parent['msg_id'] if msg_id not in self.blacklist: self.blacklist[msg_id] = set() self.blacklist[msg_id].add(engine) args = self.pending[engine].pop(msg_id) raw, targets, after, follow, timeout = args if self.blacklist[msg_id] == targets: self.depending[msg_id] = args return self.fail_unreachable(msg_id) elif not self.maybe_run(msg_id, *args): # resubmit failed, put it back in our dependency tree self.save_unmet(msg_id, *args) @logged def update_graph(self, dep_id, success=True): """dep_id just finished. Update our dependency graph and submit any jobs that just became runable.""" # print ("\n\n***********") # pprint (dep_id) # pprint (self.graph) # pprint (self.depending) # pprint (self.all_completed) # pprint (self.all_failed) # print ("\n\n***********\n\n") if dep_id not in self.graph: return jobs = self.graph.pop(dep_id) for msg_id in jobs: raw_msg, targets, after, follow, timeout = self.depending[msg_id] if after.unreachable(self.all_completed, self.all_failed) or follow.unreachable( self.all_completed, self.all_failed): self.fail_unreachable(msg_id) elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout): self.depending.pop(msg_id) for mid in follow.union(after): if mid in self.graph: self.graph[mid].remove(msg_id) #---------------------------------------------------------------------- # methods to be overridden by subclasses #---------------------------------------------------------------------- def add_job(self, idx): """Called after self.targets[idx] just got the job with header. Override with subclasses. The default ordering is simple LRU. The default loads are the number of outstanding jobs.""" self.loads[idx] += 1 for lis in (self.targets, self.loads): lis.append(lis.pop(idx)) def finish_job(self, idx): """Called after self.targets[idx] just finished a job. Override with subclasses.""" self.loads[idx] -= 1
class Client(HasTraits): """A semi-synchronous client to the IPython ZMQ cluster Parameters ---------- url_or_file : bytes; zmq url or path to ipcontroller-client.json Connection information for the Hub's registration. If a json connector file is given, then likely no further configuration is necessary. [Default: use profile] profile : bytes The name of the Cluster profile to be used to find connector information. [Default: 'default'] context : zmq.Context Pass an existing zmq.Context instance, otherwise the client will create its own. username : bytes set username to be passed to the Session object debug : bool flag for lots of message printing for debug purposes #-------------- ssh related args ---------------- # These are args for configuring the ssh tunnel to be used # credentials are used to forward connections over ssh to the Controller # Note that the ip given in `addr` needs to be relative to sshserver # The most basic case is to leave addr as pointing to localhost (127.0.0.1), # and set sshserver as the same machine the Controller is on. However, # the only requirement is that sshserver is able to see the Controller # (i.e. is within the same trusted network). sshserver : str A string of the form passed to ssh, i.e. 'server.tld' or '[email protected]:port' If keyfile or password is specified, and this is not, it will default to the ip given in addr. sshkey : str; path to public ssh key file This specifies a key to be used in ssh login, default None. Regular default ssh keys will be used without specifying this argument. password : str Your ssh password to sshserver. Note that if this is left None, you will be prompted for it if passwordless key based login is unavailable. paramiko : bool flag for whether to use paramiko instead of shell ssh for tunneling. [default: True on win32, False else] ------- exec authentication args ------- If even localhost is untrusted, you can have some protection against unauthorized execution by using a key. Messages are still sent as cleartext, so if someone can snoop your loopback traffic this will not help against malicious attacks. exec_key : str an authentication key or file containing a key default: None Attributes ---------- ids : list of int engine IDs requesting the ids attribute always synchronizes the registration state. To request ids without synchronization, use semi-private _ids attributes. history : list of msg_ids a list of msg_ids, keeping track of all the execution messages you have submitted in order. outstanding : set of msg_ids a set of msg_ids that have been submitted, but whose results have not yet been received. results : dict a dict of all our results, keyed by msg_id block : bool determines default behavior when block not specified in execution methods Methods ------- spin flushes incoming results and registration state changes control methods spin, and requesting `ids` also ensures up to date wait wait on one or more msg_ids execution methods apply legacy: execute, run data movement push, pull, scatter, gather query methods queue_status, get_result, purge, result_status control methods abort, shutdown """ block = Bool(False) outstanding = Set() results = Instance('collections.defaultdict', (dict, )) metadata = Instance('collections.defaultdict', (Metadata, )) history = List() debug = Bool(False) profile = CUnicode('default') _outstanding_dict = Instance('collections.defaultdict', (set, )) _ids = List() _connected = Bool(False) _ssh = Bool(False) _context = Instance('zmq.Context') _config = Dict() _engines = Instance(util.ReverseDict, (), {}) # _hub_socket=Instance('zmq.Socket') _query_socket = Instance('zmq.Socket') _control_socket = Instance('zmq.Socket') _iopub_socket = Instance('zmq.Socket') _notification_socket = Instance('zmq.Socket') _mux_socket = Instance('zmq.Socket') _task_socket = Instance('zmq.Socket') _task_scheme = Str() _closed = False _ignored_control_replies = Int(0) _ignored_hub_replies = Int(0) def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None, context=None, username=None, debug=False, exec_key=None, sshserver=None, sshkey=None, password=None, paramiko=None, timeout=10): super(Client, self).__init__(debug=debug, profile=profile) if context is None: context = zmq.Context.instance() self._context = context self._setup_cluster_dir(profile, cluster_dir, ipython_dir) if self._cd is not None: if url_or_file is None: url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json') assert url_or_file is not None, "I can't find enough information to connect to a hub!"\ " Please specify at least one of url_or_file or profile." try: util.validate_url(url_or_file) except AssertionError: if not os.path.exists(url_or_file): if self._cd: url_or_file = os.path.join(self._cd.security_dir, url_or_file) assert os.path.exists( url_or_file ), "Not a valid connection file or url: %r" % url_or_file with open(url_or_file) as f: cfg = json.loads(f.read()) else: cfg = {'url': url_or_file} # sync defaults from args, json: if sshserver: cfg['ssh'] = sshserver if exec_key: cfg['exec_key'] = exec_key exec_key = cfg['exec_key'] sshserver = cfg['ssh'] url = cfg['url'] location = cfg.setdefault('location', None) cfg['url'] = util.disambiguate_url(cfg['url'], location) url = cfg['url'] self._config = cfg self._ssh = bool(sshserver or sshkey or password) if self._ssh and sshserver is None: # default to ssh via localhost sshserver = url.split('://')[1].split(':')[0] if self._ssh and password is None: if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko): password = False else: password = getpass("SSH Password for %s: " % sshserver) ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko) if exec_key is not None and os.path.isfile(exec_key): arg = 'keyfile' else: arg = 'key' key_arg = {arg: exec_key} if username is None: self.session = ss.StreamSession(**key_arg) else: self.session = ss.StreamSession(username, **key_arg) self._query_socket = self._context.socket(zmq.XREQ) self._query_socket.setsockopt(zmq.IDENTITY, self.session.session) if self._ssh: tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs) else: self._query_socket.connect(url) self.session.debug = self.debug self._notification_handlers = { 'registration_notification': self._register_engine, 'unregistration_notification': self._unregister_engine, 'shutdown_notification': lambda msg: self.close(), } self._queue_handlers = { 'execute_reply': self._handle_execute_reply, 'apply_reply': self._handle_apply_reply } self._connect(sshserver, ssh_kwargs, timeout) def __del__(self): """cleanup sockets, but _not_ context.""" self.close() def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir): if ipython_dir is None: ipython_dir = get_ipython_dir() if cluster_dir is not None: try: self._cd = ClusterDir.find_cluster_dir(cluster_dir) return except ClusterDirError: pass elif profile is not None: try: self._cd = ClusterDir.find_cluster_dir_by_profile( ipython_dir, profile) return except ClusterDirError: pass self._cd = None def _update_engines(self, engines): """Update our engines dict and _ids from a dict of the form: {id:uuid}.""" for k, v in engines.iteritems(): eid = int(k) self._engines[eid] = bytes(v) # force not unicode self._ids.append(eid) self._ids = sorted(self._ids) if sorted(self._engines.keys()) != range(len(self._engines)) and \ self._task_scheme == 'pure' and self._task_socket: self._stop_scheduling_tasks() def _stop_scheduling_tasks(self): """Stop scheduling tasks because an engine has been unregistered from a pure ZMQ scheduler. """ self._task_socket.close() self._task_socket = None msg = "An engine has been unregistered, and we are using pure " +\ "ZMQ task scheduling. Task farming will be disabled." if self.outstanding: msg += " If you were running tasks when this happened, " +\ "some `outstanding` msg_ids may never resolve." warnings.warn(msg, RuntimeWarning) def _build_targets(self, targets): """Turn valid target IDs or 'all' into two lists: (int_ids, uuids). """ if targets is None: targets = self._ids elif isinstance(targets, str): if targets.lower() == 'all': targets = self._ids else: raise TypeError("%r not valid str target, must be 'all'" % (targets)) elif isinstance(targets, int): if targets < 0: targets = self.ids[targets] if targets not in self.ids: raise IndexError("No such engine: %i" % targets) targets = [targets] if isinstance(targets, slice): indices = range(len(self._ids))[targets] ids = self.ids targets = [ids[i] for i in indices] if not isinstance(targets, (tuple, list, xrange)): raise TypeError( "targets by int/slice/collection of ints only, not %s" % (type(targets))) return [self._engines[t] for t in targets], list(targets) def _connect(self, sshserver, ssh_kwargs, timeout): """setup all our socket connections to the cluster. This is called from __init__.""" # Maybe allow reconnecting? if self._connected: return self._connected = True def connect_socket(s, url): url = util.disambiguate_url(url, self._config['location']) if self._ssh: return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs) else: return s.connect(url) self.session.send(self._query_socket, 'connection_request') r, w, x = zmq.select([self._query_socket], [], [], timeout) if not r: raise error.TimeoutError("Hub connection request timed out") idents, msg = self.session.recv(self._query_socket, mode=0) if self.debug: pprint(msg) msg = ss.Message(msg) content = msg.content self._config['registration'] = dict(content) if content.status == 'ok': if content.mux: self._mux_socket = self._context.socket(zmq.XREQ) self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session) connect_socket(self._mux_socket, content.mux) if content.task: self._task_scheme, task_addr = content.task self._task_socket = self._context.socket(zmq.XREQ) self._task_socket.setsockopt(zmq.IDENTITY, self.session.session) connect_socket(self._task_socket, task_addr) if content.notification: self._notification_socket = self._context.socket(zmq.SUB) connect_socket(self._notification_socket, content.notification) self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'') # if content.query: # self._query_socket = self._context.socket(zmq.XREQ) # self._query_socket.setsockopt(zmq.IDENTITY, self.session.session) # connect_socket(self._query_socket, content.query) if content.control: self._control_socket = self._context.socket(zmq.XREQ) self._control_socket.setsockopt(zmq.IDENTITY, self.session.session) connect_socket(self._control_socket, content.control) if content.iopub: self._iopub_socket = self._context.socket(zmq.SUB) self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'') self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session) connect_socket(self._iopub_socket, content.iopub) self._update_engines(dict(content.engines)) else: self._connected = False raise Exception("Failed to connect!") #-------------------------------------------------------------------------- # handlers and callbacks for incoming messages #-------------------------------------------------------------------------- def _unwrap_exception(self, content): """unwrap exception, and remap engine_id to int.""" e = error.unwrap_exception(content) # print e.traceback if e.engine_info: e_uuid = e.engine_info['engine_uuid'] eid = self._engines[e_uuid] e.engine_info['engine_id'] = eid return e def _extract_metadata(self, header, parent, content): md = { 'msg_id': parent['msg_id'], 'received': datetime.now(), 'engine_uuid': header.get('engine', None), 'follow': parent.get('follow', []), 'after': parent.get('after', []), 'status': content['status'], } if md['engine_uuid'] is not None: md['engine_id'] = self._engines.get(md['engine_uuid'], None) if 'date' in parent: md['submitted'] = datetime.strptime(parent['date'], util.ISO8601) if 'started' in header: md['started'] = datetime.strptime(header['started'], util.ISO8601) if 'date' in header: md['completed'] = datetime.strptime(header['date'], util.ISO8601) return md def _register_engine(self, msg): """Register a new engine, and update our connection info.""" content = msg['content'] eid = content['id'] d = {eid: content['queue']} self._update_engines(d) def _unregister_engine(self, msg): """Unregister an engine that has died.""" content = msg['content'] eid = int(content['id']) if eid in self._ids: self._ids.remove(eid) uuid = self._engines.pop(eid) self._handle_stranded_msgs(eid, uuid) if self._task_socket and self._task_scheme == 'pure': self._stop_scheduling_tasks() def _handle_stranded_msgs(self, eid, uuid): """Handle messages known to be on an engine when the engine unregisters. It is possible that this will fire prematurely - that is, an engine will go down after completing a result, and the client will be notified of the unregistration and later receive the successful result. """ outstanding = self._outstanding_dict[uuid] for msg_id in list(outstanding): if msg_id in self.results: # we already continue try: raise error.EngineError( "Engine %r died while running task %r" % (eid, msg_id)) except: content = error.wrap_exception() # build a fake message: parent = {} header = {} parent['msg_id'] = msg_id header['engine'] = uuid header['date'] = datetime.now().strftime(util.ISO8601) msg = dict(parent_header=parent, header=header, content=content) self._handle_apply_reply(msg) def _handle_execute_reply(self, msg): """Save the reply to an execute_request into our results. execute messages are never actually used. apply is used instead. """ parent = msg['parent_header'] msg_id = parent['msg_id'] if msg_id not in self.outstanding: if msg_id in self.history: print("got stale result: %s" % msg_id) else: print("got unknown result: %s" % msg_id) else: self.outstanding.remove(msg_id) self.results[msg_id] = self._unwrap_exception(msg['content']) def _handle_apply_reply(self, msg): """Save the reply to an apply_request into our results.""" parent = msg['parent_header'] msg_id = parent['msg_id'] if msg_id not in self.outstanding: if msg_id in self.history: print("got stale result: %s" % msg_id) print self.results[msg_id] print msg else: print("got unknown result: %s" % msg_id) else: self.outstanding.remove(msg_id) content = msg['content'] header = msg['header'] # construct metadata: md = self.metadata[msg_id] md.update(self._extract_metadata(header, parent, content)) # is this redundant? self.metadata[msg_id] = md e_outstanding = self._outstanding_dict[md['engine_uuid']] if msg_id in e_outstanding: e_outstanding.remove(msg_id) # construct result: if content['status'] == 'ok': self.results[msg_id] = util.unserialize_object(msg['buffers'])[0] elif content['status'] == 'aborted': self.results[msg_id] = error.TaskAborted(msg_id) elif content['status'] == 'resubmitted': # TODO: handle resubmission pass else: self.results[msg_id] = self._unwrap_exception(content) def _flush_notifications(self): """Flush notifications of engine registrations waiting in ZMQ queue.""" msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK) while msg is not None: if self.debug: pprint(msg) msg = msg[-1] msg_type = msg['msg_type'] handler = self._notification_handlers.get(msg_type, None) if handler is None: raise Exception("Unhandled message type: %s" % msg.msg_type) else: handler(msg) msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK) def _flush_results(self, sock): """Flush task or queue results waiting in ZMQ queue.""" msg = self.session.recv(sock, mode=zmq.NOBLOCK) while msg is not None: if self.debug: pprint(msg) msg = msg[-1] msg_type = msg['msg_type'] handler = self._queue_handlers.get(msg_type, None) if handler is None: raise Exception("Unhandled message type: %s" % msg.msg_type) else: handler(msg) msg = self.session.recv(sock, mode=zmq.NOBLOCK) def _flush_control(self, sock): """Flush replies from the control channel waiting in the ZMQ queue. Currently: ignore them.""" if self._ignored_control_replies <= 0: return msg = self.session.recv(sock, mode=zmq.NOBLOCK) while msg is not None: self._ignored_control_replies -= 1 if self.debug: pprint(msg) msg = self.session.recv(sock, mode=zmq.NOBLOCK) def _flush_ignored_control(self): """flush ignored control replies""" while self._ignored_control_replies > 0: self.session.recv(self._control_socket) self._ignored_control_replies -= 1 def _flush_ignored_hub_replies(self): msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK) while msg is not None: msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK) def _flush_iopub(self, sock): """Flush replies from the iopub channel waiting in the ZMQ queue. """ msg = self.session.recv(sock, mode=zmq.NOBLOCK) while msg is not None: if self.debug: pprint(msg) msg = msg[-1] parent = msg['parent_header'] msg_id = parent['msg_id'] content = msg['content'] header = msg['header'] msg_type = msg['msg_type'] # init metadata: md = self.metadata[msg_id] if msg_type == 'stream': name = content['name'] s = md[name] or '' md[name] = s + content['data'] elif msg_type == 'pyerr': md.update({'pyerr': self._unwrap_exception(content)}) elif msg_type == 'pyin': md.update({'pyin': content['code']}) else: md.update({msg_type: content.get('data', '')}) # reduntant? self.metadata[msg_id] = md msg = self.session.recv(sock, mode=zmq.NOBLOCK) #-------------------------------------------------------------------------- # len, getitem #-------------------------------------------------------------------------- def __len__(self): """len(client) returns # of engines.""" return len(self.ids) def __getitem__(self, key): """index access returns DirectView multiplexer objects Must be int, slice, or list/tuple/xrange of ints""" if not isinstance(key, (int, slice, tuple, list, xrange)): raise TypeError("key by int/slice/iterable of ints only, not %s" % (type(key))) else: return self.direct_view(key) #-------------------------------------------------------------------------- # Begin public methods #-------------------------------------------------------------------------- @property def ids(self): """Always up-to-date ids property.""" self._flush_notifications() # always copy: return list(self._ids) def close(self): if self._closed: return snames = filter(lambda n: n.endswith('socket'), dir(self)) for socket in map(lambda name: getattr(self, name), snames): if isinstance(socket, zmq.Socket) and not socket.closed: socket.close() self._closed = True def spin(self): """Flush any registration notifications and execution results waiting in the ZMQ queue. """ if self._notification_socket: self._flush_notifications() if self._mux_socket: self._flush_results(self._mux_socket) if self._task_socket: self._flush_results(self._task_socket) if self._control_socket: self._flush_control(self._control_socket) if self._iopub_socket: self._flush_iopub(self._iopub_socket) if self._query_socket: self._flush_ignored_hub_replies() def wait(self, jobs=None, timeout=-1): """waits on one or more `jobs`, for up to `timeout` seconds. Parameters ---------- jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects ints are indices to self.history strs are msg_ids default: wait on all outstanding messages timeout : float a time in seconds, after which to give up. default is -1, which means no timeout Returns ------- True : when all msg_ids are done False : timeout reached, some msg_ids still outstanding """ tic = time.time() if jobs is None: theids = self.outstanding else: if isinstance(jobs, (int, str, AsyncResult)): jobs = [jobs] theids = set() for job in jobs: if isinstance(job, int): # index access job = self.history[job] elif isinstance(job, AsyncResult): map(theids.add, job.msg_ids) continue theids.add(job) if not theids.intersection(self.outstanding): return True self.spin() while theids.intersection(self.outstanding): if timeout >= 0 and (time.time() - tic) > timeout: break time.sleep(1e-3) self.spin() return len(theids.intersection(self.outstanding)) == 0 #-------------------------------------------------------------------------- # Control methods #-------------------------------------------------------------------------- @spin_first @default_block def clear(self, targets=None, block=None): """Clear the namespace in target(s).""" targets = self._build_targets(targets)[0] for t in targets: self.session.send(self._control_socket, 'clear_request', content={}, ident=t) error = False if self.block: self._flush_ignored_control() for i in range(len(targets)): idents, msg = self.session.recv(self._control_socket, 0) if self.debug: pprint(msg) if msg['content']['status'] != 'ok': error = self._unwrap_exception(msg['content']) else: self._ignored_control_replies += len(targets) if error: raise error @spin_first @default_block def abort(self, jobs=None, targets=None, block=None): """Abort specific jobs from the execution queues of target(s). This is a mechanism to prevent jobs that have already been submitted from executing. Parameters ---------- jobs : msg_id, list of msg_ids, or AsyncResult The jobs to be aborted """ targets = self._build_targets(targets)[0] msg_ids = [] if isinstance(jobs, (basestring, AsyncResult)): jobs = [jobs] bad_ids = filter( lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs) if bad_ids: raise TypeError( "Invalid msg_id type %r, expected str or AsyncResult" % bad_ids[0]) for j in jobs: if isinstance(j, AsyncResult): msg_ids.extend(j.msg_ids) else: msg_ids.append(j) content = dict(msg_ids=msg_ids) for t in targets: self.session.send(self._control_socket, 'abort_request', content=content, ident=t) error = False if self.block: self._flush_ignored_control() for i in range(len(targets)): idents, msg = self.session.recv(self._control_socket, 0) if self.debug: pprint(msg) if msg['content']['status'] != 'ok': error = self._unwrap_exception(msg['content']) else: self._ignored_control_replies += len(targets) if error: raise error @spin_first @default_block def shutdown(self, targets=None, restart=False, hub=False, block=None): """Terminates one or more engine processes, optionally including the hub.""" if hub: targets = 'all' targets = self._build_targets(targets)[0] for t in targets: self.session.send(self._control_socket, 'shutdown_request', content={'restart': restart}, ident=t) error = False if block or hub: self._flush_ignored_control() for i in range(len(targets)): idents, msg = self.session.recv(self._control_socket, 0) if self.debug: pprint(msg) if msg['content']['status'] != 'ok': error = self._unwrap_exception(msg['content']) else: self._ignored_control_replies += len(targets) if hub: time.sleep(0.25) self.session.send(self._query_socket, 'shutdown_request') idents, msg = self.session.recv(self._query_socket, 0) if self.debug: pprint(msg) if msg['content']['status'] != 'ok': error = self._unwrap_exception(msg['content']) if error: raise error #-------------------------------------------------------------------------- # Execution methods #-------------------------------------------------------------------------- @default_block def _execute(self, code, targets='all', block=None): """Executes `code` on `targets` in blocking or nonblocking manner. ``execute`` is always `bound` (affects engine namespace) Parameters ---------- code : str the code string to be executed targets : int/str/list of ints/strs the engines on which to execute default : all block : bool whether or not to wait until done to return default: self.block """ return self[targets].execute(code, block=block) def _maybe_raise(self, result): """wrapper for maybe raising an exception if apply failed.""" if isinstance(result, error.RemoteError): raise result return result def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False, ident=None): """construct and send an apply message via a socket. This is the principal method with which all engine execution is performed by views. """ assert not self._closed, "cannot use me anymore, I'm closed!" # defaults: args = args if args is not None else [] kwargs = kwargs if kwargs is not None else {} subheader = subheader if subheader is not None else {} # validate arguments if not callable(f): raise TypeError("f must be callable, not %s" % type(f)) if not isinstance(args, (tuple, list)): raise TypeError("args must be tuple or list, not %s" % type(args)) if not isinstance(kwargs, dict): raise TypeError("kwargs must be dict, not %s" % type(kwargs)) if not isinstance(subheader, dict): raise TypeError("subheader must be dict, not %s" % type(subheader)) if not self._ids: # flush notification socket if no engines yet any_ids = self.ids if not any_ids: raise error.NoEnginesRegistered( "Can't execute without any connected engines.") # enforce types of f,args,kwargs bufs = util.pack_apply_message(f, args, kwargs) msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident, subheader=subheader, track=track) msg_id = msg['msg_id'] self.outstanding.add(msg_id) if ident: # possibly routed to a specific engine if isinstance(ident, list): ident = ident[-1] if ident in self._engines.values(): # save for later, in case of engine death self._outstanding_dict[ident].add(msg_id) self.history.append(msg_id) self.metadata[msg_id]['submitted'] = datetime.now() return msg #-------------------------------------------------------------------------- # construct a View object #-------------------------------------------------------------------------- def load_balanced_view(self, targets=None): """construct a DirectView object. If no arguments are specified, create a LoadBalancedView using all engines. Parameters ---------- targets: list,slice,int,etc. [default: use all engines] The subset of engines across which to load-balance """ if targets is not None: targets = self._build_targets(targets)[1] return LoadBalancedView(client=self, socket=self._task_socket, targets=targets) def direct_view(self, targets='all'): """construct a DirectView object. If no targets are specified, create a DirectView using all engines. Parameters ---------- targets: list,slice,int,etc. [default: use all engines] The engines to use for the View """ single = isinstance(targets, int) targets = self._build_targets(targets)[1] if single: targets = targets[0] return DirectView(client=self, socket=self._mux_socket, targets=targets) #-------------------------------------------------------------------------- # Data movement (TO BE REMOVED) #-------------------------------------------------------------------------- @default_block def _push(self, ns, targets='all', block=None, track=False): """Push the contents of `ns` into the namespace on `target`""" if not isinstance(ns, dict): raise TypeError("Must be a dict, not %s" % type(ns)) result = self.apply(util._push, kwargs=ns, targets=targets, block=block, bound=True, balanced=False, track=track) if not block: return result @default_block def _pull(self, keys, targets='all', block=None): """Pull objects from `target`'s namespace by `keys`""" if isinstance(keys, basestring): pass elif isinstance(keys, (list, tuple, set)): for key in keys: if not isinstance(key, basestring): raise TypeError("keys must be str, not type %r" % type(key)) else: raise TypeError("keys must be strs, not %r" % keys) result = self.apply(util._pull, (keys, ), targets=targets, block=block, bound=True, balanced=False) return result #-------------------------------------------------------------------------- # Query methods #-------------------------------------------------------------------------- @spin_first @default_block def get_result(self, indices_or_msg_ids=None, block=None): """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object. If the client already has the results, no request to the Hub will be made. This is a convenient way to construct AsyncResult objects, which are wrappers that include metadata about execution, and allow for awaiting results that were not submitted by this Client. It can also be a convenient way to retrieve the metadata associated with blocking execution, since it always retrieves Examples -------- :: In [10]: r = client.apply() Parameters ---------- indices_or_msg_ids : integer history index, str msg_id, or list of either The indices or msg_ids of indices to be retrieved block : bool Whether to wait for the result to be done Returns ------- AsyncResult A single AsyncResult object will always be returned. AsyncHubResult A subclass of AsyncResult that retrieves results from the Hub """ if indices_or_msg_ids is None: indices_or_msg_ids = -1 if not isinstance(indices_or_msg_ids, (list, tuple)): indices_or_msg_ids = [indices_or_msg_ids] theids = [] for id in indices_or_msg_ids: if isinstance(id, int): id = self.history[id] if not isinstance(id, str): raise TypeError("indices must be str or int, not %r" % id) theids.append(id) local_ids = filter( lambda msg_id: msg_id in self.history or msg_id in self.results, theids) remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids) if remote_ids: ar = AsyncHubResult(self, msg_ids=theids) else: ar = AsyncResult(self, msg_ids=theids) if block: ar.wait() return ar @spin_first def result_status(self, msg_ids, status_only=True): """Check on the status of the result(s) of the apply request with `msg_ids`. If status_only is False, then the actual results will be retrieved, else only the status of the results will be checked. Parameters ---------- msg_ids : list of msg_ids if int: Passed as index to self.history for convenience. status_only : bool (default: True) if False: Retrieve the actual results of completed tasks. Returns ------- results : dict There will always be the keys 'pending' and 'completed', which will be lists of msg_ids that are incomplete or complete. If `status_only` is False, then completed results will be keyed by their `msg_id`. """ if not isinstance(msg_ids, (list, tuple)): msg_ids = [msg_ids] theids = [] for msg_id in msg_ids: if isinstance(msg_id, int): msg_id = self.history[msg_id] if not isinstance(msg_id, basestring): raise TypeError("msg_ids must be str, not %r" % msg_id) theids.append(msg_id) completed = [] local_results = {} # comment this block out to temporarily disable local shortcut: for msg_id in theids: if msg_id in self.results: completed.append(msg_id) local_results[msg_id] = self.results[msg_id] theids.remove(msg_id) if theids: # some not locally cached content = dict(msg_ids=theids, status_only=status_only) msg = self.session.send(self._query_socket, "result_request", content=content) zmq.select([self._query_socket], [], []) idents, msg = self.session.recv(self._query_socket, zmq.NOBLOCK) if self.debug: pprint(msg) content = msg['content'] if content['status'] != 'ok': raise self._unwrap_exception(content) buffers = msg['buffers'] else: content = dict(completed=[], pending=[]) content['completed'].extend(completed) if status_only: return content failures = [] # load cached results into result: content.update(local_results) # update cache with results: for msg_id in sorted(theids): if msg_id in content['completed']: rec = content[msg_id] parent = rec['header'] header = rec['result_header'] rcontent = rec['result_content'] iodict = rec['io'] if isinstance(rcontent, str): rcontent = self.session.unpack(rcontent) md = self.metadata[msg_id] md.update(self._extract_metadata(header, parent, rcontent)) md.update(iodict) if rcontent['status'] == 'ok': res, buffers = util.unserialize_object(buffers) else: print rcontent res = self._unwrap_exception(rcontent) failures.append(res) self.results[msg_id] = res content[msg_id] = res if len(theids) == 1 and failures: raise failures[0] error.collect_exceptions(failures, "result_status") return content @spin_first def queue_status(self, targets='all', verbose=False): """Fetch the status of engine queues. Parameters ---------- targets : int/str/list of ints/strs the engines whose states are to be queried. default : all verbose : bool Whether to return lengths only, or lists of ids for each element """ engine_ids = self._build_targets(targets)[1] content = dict(targets=engine_ids, verbose=verbose) self.session.send(self._query_socket, "queue_request", content=content) idents, msg = self.session.recv(self._query_socket, 0) if self.debug: pprint(msg) content = msg['content'] status = content.pop('status') if status != 'ok': raise self._unwrap_exception(content) content = util.rekey(content) if isinstance(targets, int): return content[targets] else: return content @spin_first def purge_results(self, jobs=[], targets=[]): """Tell the Hub to forget results. Individual results can be purged by msg_id, or the entire history of specific targets can be purged. Parameters ---------- jobs : str or list of str or AsyncResult objects the msg_ids whose results should be forgotten. targets : int/str/list of ints/strs The targets, by uuid or int_id, whose entire history is to be purged. Use `targets='all'` to scrub everything from the Hub's memory. default : None """ if not targets and not jobs: raise ValueError( "Must specify at least one of `targets` and `jobs`") if targets: targets = self._build_targets(targets)[1] # construct msg_ids from jobs msg_ids = [] if isinstance(jobs, (basestring, AsyncResult)): jobs = [jobs] bad_ids = filter( lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs) if bad_ids: raise TypeError( "Invalid msg_id type %r, expected str or AsyncResult" % bad_ids[0]) for j in jobs: if isinstance(j, AsyncResult): msg_ids.extend(j.msg_ids) else: msg_ids.append(j) content = dict(targets=targets, msg_ids=msg_ids) self.session.send(self._query_socket, "purge_request", content=content) idents, msg = self.session.recv(self._query_socket, 0) if self.debug: pprint(msg) content = msg['content'] if content['status'] != 'ok': raise self._unwrap_exception(content)
class ExtractOutputPreprocessor(Preprocessor): """ Extracts all of the outputs from the notebook file. The extracted outputs are returned in the 'resources' dictionary. """ output_filename_template = Unicode( "{unique_key}_{cell_index}_{index}{extension}", config=True) extract_output_types = Set({'png', 'jpeg', 'svg', 'application/pdf'}, config=True) def preprocess_cell(self, cell, resources, cell_index): """ Apply a transformation on each cell, Parameters ---------- cell : NotebookNode cell Notebook cell being processed resources : dictionary Additional resources used in the conversion process. Allows preprocessors to pass variables into the Jinja engine. cell_index : int Index of the cell being processed (see base.py) """ #Get the unique key from the resource dict if it exists. If it does not #exist, use 'output' as the default. Also, get files directory if it #has been specified unique_key = resources.get('unique_key', 'output') output_files_dir = resources.get('output_files_dir', None) #Make sure outputs key exists if not isinstance(resources['outputs'], dict): resources['outputs'] = {} #Loop through all of the outputs in the cell for index, out in enumerate(cell.get('outputs', [])): #Get the output in data formats that the template needs extracted for out_type in self.extract_output_types: if out_type in out: data = out[out_type] #Binary files are base64-encoded, SVG is already XML if out_type in {'png', 'jpeg', 'application/pdf'}: # data is b64-encoded as text (str, unicode) # decodestring only accepts bytes data = py3compat.cast_bytes(data) data = base64.decodestring(data) elif sys.platform == 'win32': data = data.replace('\n', '\r\n').encode("UTF-8") else: data = data.encode("UTF-8") # Build an output name # filthy hack while we have some mimetype output, and some not if '/' in out_type: ext = guess_extension(out_type) if ext is None: ext = '.' + out_type.rsplit('/')[-1] else: ext = '.' + out_type filename = self.output_filename_template.format( unique_key=unique_key, cell_index=cell_index, index=index, extension=ext) #On the cell, make the figure available via # cell.outputs[i].svg_filename ... etc (svg in example) # Where # cell.outputs[i].svg contains the data if output_files_dir is not None: filename = os.path.join(output_files_dir, filename) out[out_type + '_filename'] = filename #In the resources, make the figure available via # resources['outputs']['filename'] = data resources['outputs'][filename] = data return cell, resources
class Kernel(SessionFactory): #--------------------------------------------------------------------------- # Kernel interface #--------------------------------------------------------------------------- # kwargs: exec_lines = List(Unicode, config=True, help="List of lines to execute") # identities: int_id = Int(-1) bident = CBytes() ident = Unicode() def _ident_changed(self, name, old, new): self.bident = asbytes(new) user_ns = Dict(config=True, help="""Set the user's namespace of the Kernel""") control_stream = Instance(zmqstream.ZMQStream) task_stream = Instance(zmqstream.ZMQStream) iopub_stream = Instance(zmqstream.ZMQStream) client = Instance('IPython.parallel.Client') # internals shell_streams = List() compiler = Instance(CommandCompiler, (), {}) completer = Instance(KernelCompleter) aborted = Set() shell_handlers = Dict() control_handlers = Dict() def _set_prefix(self): self.prefix = "engine.%s" % self.int_id def _connect_completer(self): self.completer = KernelCompleter(self.user_ns) def __init__(self, **kwargs): super(Kernel, self).__init__(**kwargs) self._set_prefix() self._connect_completer() self.on_trait_change(self._set_prefix, 'id') self.on_trait_change(self._connect_completer, 'user_ns') # Build dict of handlers for message types for msg_type in [ 'execute_request', 'complete_request', 'apply_request', 'clear_request' ]: self.shell_handlers[msg_type] = getattr(self, msg_type) for msg_type in ['shutdown_request', 'abort_request' ] + self.shell_handlers.keys(): self.control_handlers[msg_type] = getattr(self, msg_type) self._initial_exec_lines() def _wrap_exception(self, method=None): e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method) content = wrap_exception(e_info) return content def _initial_exec_lines(self): s = _Passer() content = dict(silent=True, user_variable=[], user_expressions=[]) for line in self.exec_lines: self.log.debug("executing initialization: %s" % line) content.update({'code': line}) msg = self.session.msg('execute_request', content) self.execute_request(s, [], msg) #-------------------- control handlers ----------------------------- def abort_queues(self): for stream in self.shell_streams: if stream: self.abort_queue(stream) def abort_queue(self, stream): while True: idents, msg = self.session.recv(stream, zmq.NOBLOCK, content=True) if msg is None: return self.log.info("Aborting:") self.log.info(str(msg)) msg_type = msg['header']['msg_type'] reply_type = msg_type.split('_')[0] + '_reply' # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg) # self.reply_socket.send(ident,zmq.SNDMORE) # self.reply_socket.send_json(reply_msg) reply_msg = self.session.send(stream, reply_type, content={'status': 'aborted'}, parent=msg, ident=idents) self.log.debug(str(reply_msg)) # We need to wait a bit for requests to come in. This can probably # be set shorter for true asynchronous clients. time.sleep(0.05) def abort_request(self, stream, ident, parent): """abort a specifig msg by id""" msg_ids = parent['content'].get('msg_ids', None) if isinstance(msg_ids, basestring): msg_ids = [msg_ids] if not msg_ids: self.abort_queues() for mid in msg_ids: self.aborted.add(str(mid)) content = dict(status='ok') reply_msg = self.session.send(stream, 'abort_reply', content=content, parent=parent, ident=ident) self.log.debug(str(reply_msg)) def shutdown_request(self, stream, ident, parent): """kill ourself. This should really be handled in an external process""" try: self.abort_queues() except: content = self._wrap_exception('shutdown') else: content = dict(parent['content']) content['status'] = 'ok' msg = self.session.send(stream, 'shutdown_reply', content=content, parent=parent, ident=ident) self.log.debug(str(msg)) dc = ioloop.DelayedCallback(lambda: sys.exit(0), 1000, self.loop) dc.start() def dispatch_control(self, msg): idents, msg = self.session.feed_identities(msg, copy=False) try: msg = self.session.unserialize(msg, content=True, copy=False) except: self.log.error("Invalid Message", exc_info=True) return else: self.log.debug("Control received, %s", msg) header = msg['header'] msg_id = header['msg_id'] msg_type = header['msg_type'] handler = self.control_handlers.get(msg_type, None) if handler is None: self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r" % msg_type) else: handler(self.control_stream, idents, msg) #-------------------- queue helpers ------------------------------ def check_dependencies(self, dependencies): if not dependencies: return True if len(dependencies) == 2 and dependencies[0] in 'any all'.split(): anyorall = dependencies[0] dependencies = dependencies[1] else: anyorall = 'all' results = self.client.get_results(dependencies, status_only=True) if results['status'] != 'ok': return False if anyorall == 'any': if not results['completed']: return False else: if results['pending']: return False return True def check_aborted(self, msg_id): return msg_id in self.aborted #-------------------- queue handlers ----------------------------- def clear_request(self, stream, idents, parent): """Clear our namespace.""" self.user_ns = {} msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent, content=dict(status='ok')) self._initial_exec_lines() def execute_request(self, stream, ident, parent): self.log.debug('execute request %s' % parent) try: code = parent[u'content'][u'code'] except: self.log.error("Got bad msg: %s" % parent, exc_info=True) return self.session.send(self.iopub_stream, u'pyin', {u'code': code}, parent=parent, ident=asbytes('%s.pyin' % self.prefix)) started = datetime.now() try: comp_code = self.compiler(code, '<zmq-kernel>') # allow for not overriding displayhook if hasattr(sys.displayhook, 'set_parent'): sys.displayhook.set_parent(parent) sys.stdout.set_parent(parent) sys.stderr.set_parent(parent) exec comp_code in self.user_ns, self.user_ns except: exc_content = self._wrap_exception('execute') # exc_msg = self.session.msg(u'pyerr', exc_content, parent) self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent, ident=asbytes('%s.pyerr' % self.prefix)) reply_content = exc_content else: reply_content = {'status': 'ok'} reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent, ident=ident, subheader=dict(started=started)) self.log.debug(str(reply_msg)) if reply_msg['content']['status'] == u'error': self.abort_queues() def complete_request(self, stream, ident, parent): matches = {'matches': self.complete(parent), 'status': 'ok'} completion_msg = self.session.send(stream, 'complete_reply', matches, parent, ident) # print >> sys.__stdout__, completion_msg def complete(self, msg): return self.completer.complete(msg.content.line, msg.content.text) def apply_request(self, stream, ident, parent): # flush previous reply, so this request won't block it stream.flush(zmq.POLLOUT) try: content = parent[u'content'] bufs = parent[u'buffers'] msg_id = parent['header']['msg_id'] # bound = parent['header'].get('bound', False) except: self.log.error("Got bad msg: %s" % parent, exc_info=True) return # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent) # self.iopub_stream.send(pyin_msg) # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent) sub = { 'dependencies_met': True, 'engine': self.ident, 'started': datetime.now() } try: # allow for not overriding displayhook if hasattr(sys.displayhook, 'set_parent'): sys.displayhook.set_parent(parent) sys.stdout.set_parent(parent) sys.stderr.set_parent(parent) # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns working = self.user_ns # suffix = prefix = "_" + str(msg_id).replace("-", "") + "_" f, args, kwargs = unpack_apply_message(bufs, working, copy=False) # if bound: # bound_ns = Namespace(working) # args = [bound_ns]+list(args) fname = getattr(f, '__name__', 'f') fname = prefix + "f" argname = prefix + "args" kwargname = prefix + "kwargs" resultname = prefix + "result" ns = {fname: f, argname: args, kwargname: kwargs, resultname: None} # print ns working.update(ns) code = "%s=%s(*%s,**%s)" % (resultname, fname, argname, kwargname) try: exec code in working, working result = working.get(resultname) finally: for key in ns.iterkeys(): working.pop(key) # if bound: # working.update(bound_ns) packed_result, buf = serialize_object(result) result_buf = [packed_result] + buf except: exc_content = self._wrap_exception('apply') # exc_msg = self.session.msg(u'pyerr', exc_content, parent) self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent, ident=asbytes('%s.pyerr' % self.prefix)) reply_content = exc_content result_buf = [] if exc_content['ename'] == 'UnmetDependency': sub['dependencies_met'] = False else: reply_content = {'status': 'ok'} # put 'ok'/'error' status in header, for scheduler introspection: sub['status'] = reply_content['status'] reply_msg = self.session.send(stream, u'apply_reply', reply_content, parent=parent, ident=ident, buffers=result_buf, subheader=sub) # flush i/o # should this be before reply_msg is sent, like in the single-kernel code, # or should nothing get in the way of real results? sys.stdout.flush() sys.stderr.flush() def dispatch_queue(self, stream, msg): self.control_stream.flush() idents, msg = self.session.feed_identities(msg, copy=False) try: msg = self.session.unserialize(msg, content=True, copy=False) except: self.log.error("Invalid Message", exc_info=True) return else: self.log.debug("Message received, %s", msg) header = msg['header'] msg_id = header['msg_id'] msg_type = msg['header']['msg_type'] if self.check_aborted(msg_id): self.aborted.remove(msg_id) # is it safe to assume a msg_id will not be resubmitted? reply_type = msg_type.split('_')[0] + '_reply' status = {'status': 'aborted'} reply_msg = self.session.send(stream, reply_type, subheader=status, content=status, parent=msg, ident=idents) return handler = self.shell_handlers.get(msg_type, None) if handler is None: self.log.error("UNKNOWN MESSAGE TYPE: %r" % msg_type) else: handler(stream, idents, msg) def start(self): #### stream mode: if self.control_stream: self.control_stream.on_recv(self.dispatch_control, copy=False) self.control_stream.on_err(printer) def make_dispatcher(stream): def dispatcher(msg): return self.dispatch_queue(stream, msg) return dispatcher for s in self.shell_streams: s.on_recv(make_dispatcher(s), copy=False) s.on_err(printer) if self.iopub_stream: self.iopub_stream.on_err(printer)
class Hub(LoggingFactory): """The IPython Controller Hub with 0MQ connections Parameters ========== loop: zmq IOLoop instance session: StreamSession object <removed> context: zmq context for creating new connections (?) queue: ZMQStream for monitoring the command queue (SUB) query: ZMQStream for engine registration and client queries requests (XREP) heartbeat: HeartMonitor object checking the pulse of the engines notifier: ZMQStream for broadcasting engine registration changes (PUB) db: connection to db for out of memory logging of commands NotImplemented engine_info: dict of zmq connection information for engines to connect to the queues. client_info: dict of zmq connection information for engines to connect to the queues. """ # internal data structures: ids = Set() # engine IDs keytable = Dict() by_ident = Dict() engines = Dict() clients = Dict() hearts = Dict() pending = Set() queues = Dict() # pending msg_ids keyed by engine_id tasks = Dict() # pending msg_ids submitted as tasks, keyed by client_id completed = Dict() # completed msg_ids keyed by engine_id all_completed = Set() # completed msg_ids keyed by engine_id dead_engines = Set() # completed msg_ids keyed by engine_id unassigned = Set() # set of task msg_ds not yet assigned a destination incoming_registrations = Dict() registration_timeout = Int() _idcounter = Int(0) # objects from constructor: loop = Instance(ioloop.IOLoop) query = Instance(ZMQStream) monitor = Instance(ZMQStream) heartmonitor = Instance(HeartMonitor) notifier = Instance(ZMQStream) db = Instance(object) client_info = Dict() engine_info = Dict() def __init__(self, **kwargs): """ # universal: loop: IOLoop for creating future connections session: streamsession for sending serialized data # engine: queue: ZMQStream for monitoring queue messages query: ZMQStream for engine+client registration and client requests heartbeat: HeartMonitor object for tracking engines # extra: db: ZMQStream for db connection (NotImplemented) engine_info: zmq address/protocol dict for engine connections client_info: zmq address/protocol dict for client connections """ super(Hub, self).__init__(**kwargs) self.registration_timeout = max(5000, 2 * self.heartmonitor.period) # validate connection dicts: for k, v in self.client_info.iteritems(): if k == 'task': util.validate_url_container(v[1]) else: util.validate_url_container(v) # util.validate_url_container(self.client_info) util.validate_url_container(self.engine_info) # register our callbacks self.query.on_recv(self.dispatch_query) self.monitor.on_recv(self.dispatch_monitor_traffic) self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure) self.heartmonitor.add_new_heart_handler(self.handle_new_heart) self.monitor_handlers = { 'in': self.save_queue_request, 'out': self.save_queue_result, 'intask': self.save_task_request, 'outtask': self.save_task_result, 'tracktask': self.save_task_destination, 'incontrol': _passer, 'outcontrol': _passer, 'iopub': self.save_iopub_message, } self.query_handlers = { 'queue_request': self.queue_status, 'result_request': self.get_results, 'history_request': self.get_history, 'db_request': self.db_query, 'purge_request': self.purge_results, 'load_request': self.check_load, 'resubmit_request': self.resubmit_task, 'shutdown_request': self.shutdown_request, 'registration_request': self.register_engine, 'unregistration_request': self.unregister_engine, 'connection_request': self.connection_request, } self.log.info("hub::created hub") @property def _next_id(self): """gemerate a new ID. No longer reuse old ids, just count from 0.""" newid = self._idcounter self._idcounter += 1 return newid # newid = 0 # incoming = [id[0] for id in self.incoming_registrations.itervalues()] # # print newid, self.ids, self.incoming_registrations # while newid in self.ids or newid in incoming: # newid += 1 # return newid #----------------------------------------------------------------------------- # message validation #----------------------------------------------------------------------------- def _validate_targets(self, targets): """turn any valid targets argument into a list of integer ids""" if targets is None: # default to all targets = self.ids if isinstance(targets, (int, str, unicode)): # only one target specified targets = [targets] _targets = [] for t in targets: # map raw identities to ids if isinstance(t, (str, unicode)): t = self.by_ident.get(t, t) _targets.append(t) targets = _targets bad_targets = [t for t in targets if t not in self.ids] if bad_targets: raise IndexError("No Such Engine: %r" % bad_targets) if not targets: raise IndexError("No Engines Registered") return targets #----------------------------------------------------------------------------- # dispatch methods (1 per stream) #----------------------------------------------------------------------------- # def dispatch_registration_request(self, msg): # """""" # self.log.debug("registration::dispatch_register_request(%s)"%msg) # idents,msg = self.session.feed_identities(msg) # if not idents: # self.log.error("Bad Query Message: %s"%msg, exc_info=True) # return # try: # msg = self.session.unpack_message(msg,content=True) # except: # self.log.error("registration::got bad registration message: %s"%msg, exc_info=True) # return # # msg_type = msg['msg_type'] # content = msg['content'] # # handler = self.query_handlers.get(msg_type, None) # if handler is None: # self.log.error("registration::got bad registration message: %s"%msg) # else: # handler(idents, msg) def dispatch_monitor_traffic(self, msg): """all ME and Task queue messages come through here, as well as IOPub traffic.""" self.log.debug("monitor traffic: %s" % msg[:2]) switch = msg[0] idents, msg = self.session.feed_identities(msg[1:]) if not idents: self.log.error("Bad Monitor Message: %s" % msg) return handler = self.monitor_handlers.get(switch, None) if handler is not None: handler(idents, msg) else: self.log.error("Invalid monitor topic: %s" % switch) def dispatch_query(self, msg): """Route registration requests and queries from clients.""" idents, msg = self.session.feed_identities(msg) if not idents: self.log.error("Bad Query Message: %s" % msg) return client_id = idents[0] try: msg = self.session.unpack_message(msg, content=True) except: content = error.wrap_exception() self.log.error("Bad Query Message: %s" % msg, exc_info=True) self.session.send(self.query, "hub_error", ident=client_id, content=content) return # print client_id, header, parent, content #switch on message type: msg_type = msg['msg_type'] self.log.info("client::client %s requested %s" % (client_id, msg_type)) handler = self.query_handlers.get(msg_type, None) try: assert handler is not None, "Bad Message Type: %s" % msg_type except: content = error.wrap_exception() self.log.error("Bad Message Type: %s" % msg_type, exc_info=True) self.session.send(self.query, "hub_error", ident=client_id, content=content) return else: handler(idents, msg) def dispatch_db(self, msg): """""" raise NotImplementedError #--------------------------------------------------------------------------- # handler methods (1 per event) #--------------------------------------------------------------------------- #----------------------- Heartbeat -------------------------------------- def handle_new_heart(self, heart): """handler to attach to heartbeater. Called when a new heart starts to beat. Triggers completion of registration.""" self.log.debug("heartbeat::handle_new_heart(%r)" % heart) if heart not in self.incoming_registrations: self.log.info("heartbeat::ignoring new heart: %r" % heart) else: self.finish_registration(heart) def handle_heart_failure(self, heart): """handler to attach to heartbeater. called when a previously registered heart fails to respond to beat request. triggers unregistration""" self.log.debug("heartbeat::handle_heart_failure(%r)" % heart) eid = self.hearts.get(heart, None) queue = self.engines[eid].queue if eid is None: self.log.info("heartbeat::ignoring heart failure %r" % heart) else: self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue))) #----------------------- MUX Queue Traffic ------------------------------ def save_queue_request(self, idents, msg): if len(idents) < 2: self.log.error("invalid identity prefix: %s" % idents) return queue_id, client_id = idents[:2] try: msg = self.session.unpack_message(msg, content=False) except: self.log.error("queue::client %r sent invalid message to %r: %s" % (client_id, queue_id, msg), exc_info=True) return eid = self.by_ident.get(queue_id, None) if eid is None: self.log.error("queue::target %r not registered" % queue_id) self.log.debug("queue:: valid are: %s" % (self.by_ident.keys())) return header = msg['header'] msg_id = header['msg_id'] record = init_record(msg) record['engine_uuid'] = queue_id record['client_uuid'] = client_id record['queue'] = 'mux' try: # it's posible iopub arrived first: existing = self.db.get_record(msg_id) for key, evalue in existing.iteritems(): rvalue = record[key] if evalue and rvalue and evalue != rvalue: self.log.error( "conflicting initial state for record: %s:%s <> %s" % (msg_id, rvalue, evalue)) elif evalue and not rvalue: record[key] = evalue self.db.update_record(msg_id, record) except KeyError: self.db.add_record(msg_id, record) self.pending.add(msg_id) self.queues[eid].append(msg_id) def save_queue_result(self, idents, msg): if len(idents) < 2: self.log.error("invalid identity prefix: %s" % idents) return client_id, queue_id = idents[:2] try: msg = self.session.unpack_message(msg, content=False) except: self.log.error("queue::engine %r sent invalid message to %r: %s" % (queue_id, client_id, msg), exc_info=True) return eid = self.by_ident.get(queue_id, None) if eid is None: self.log.error("queue::unknown engine %r is sending a reply: " % queue_id) # self.log.debug("queue:: %s"%msg[2:]) return parent = msg['parent_header'] if not parent: return msg_id = parent['msg_id'] if msg_id in self.pending: self.pending.remove(msg_id) self.all_completed.add(msg_id) self.queues[eid].remove(msg_id) self.completed[eid].append(msg_id) elif msg_id not in self.all_completed: # it could be a result from a dead engine that died before delivering the # result self.log.warn("queue:: unknown msg finished %s" % msg_id) return # update record anyway, because the unregistration could have been premature rheader = msg['header'] completed = datetime.strptime(rheader['date'], util.ISO8601) started = rheader.get('started', None) if started is not None: started = datetime.strptime(started, util.ISO8601) result = { 'result_header': rheader, 'result_content': msg['content'], 'started': started, 'completed': completed } result['result_buffers'] = msg['buffers'] try: self.db.update_record(msg_id, result) except Exception: self.log.error("DB Error updating record %r" % msg_id, exc_info=True) #--------------------- Task Queue Traffic ------------------------------ def save_task_request(self, idents, msg): """Save the submission of a task.""" client_id = idents[0] try: msg = self.session.unpack_message(msg, content=False) except: self.log.error("task::client %r sent invalid task message: %s" % (client_id, msg), exc_info=True) return record = init_record(msg) record['client_uuid'] = client_id record['queue'] = 'task' header = msg['header'] msg_id = header['msg_id'] self.pending.add(msg_id) self.unassigned.add(msg_id) try: # it's posible iopub arrived first: existing = self.db.get_record(msg_id) for key, evalue in existing.iteritems(): rvalue = record[key] if evalue and rvalue and evalue != rvalue: self.log.error( "conflicting initial state for record: %s:%s <> %s" % (msg_id, rvalue, evalue)) elif evalue and not rvalue: record[key] = evalue self.db.update_record(msg_id, record) except KeyError: self.db.add_record(msg_id, record) except Exception: self.log.error("DB Error saving task request %r" % msg_id, exc_info=True) def save_task_result(self, idents, msg): """save the result of a completed task.""" client_id = idents[0] try: msg = self.session.unpack_message(msg, content=False) except: self.log.error("task::invalid task result message send to %r: %s" % (client_id, msg), exc_info=True) raise return parent = msg['parent_header'] if not parent: # print msg self.log.warn("Task %r had no parent!" % msg) return msg_id = parent['msg_id'] if msg_id in self.unassigned: self.unassigned.remove(msg_id) header = msg['header'] engine_uuid = header.get('engine', None) eid = self.by_ident.get(engine_uuid, None) if msg_id in self.pending: self.pending.remove(msg_id) self.all_completed.add(msg_id) if eid is not None: self.completed[eid].append(msg_id) if msg_id in self.tasks[eid]: self.tasks[eid].remove(msg_id) completed = datetime.strptime(header['date'], util.ISO8601) started = header.get('started', None) if started is not None: started = datetime.strptime(started, util.ISO8601) result = { 'result_header': header, 'result_content': msg['content'], 'started': started, 'completed': completed, 'engine_uuid': engine_uuid } result['result_buffers'] = msg['buffers'] try: self.db.update_record(msg_id, result) except Exception: self.log.error("DB Error saving task request %r" % msg_id, exc_info=True) else: self.log.debug("task::unknown task %s finished" % msg_id) def save_task_destination(self, idents, msg): try: msg = self.session.unpack_message(msg, content=True) except: self.log.error("task::invalid task tracking message", exc_info=True) return content = msg['content'] # print (content) msg_id = content['msg_id'] engine_uuid = content['engine_id'] eid = self.by_ident[engine_uuid] self.log.info("task::task %s arrived on %s" % (msg_id, eid)) if msg_id in self.unassigned: self.unassigned.remove(msg_id) # else: # self.log.debug("task::task %s not listed as MIA?!"%(msg_id)) self.tasks[eid].append(msg_id) # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid)) try: self.db.update_record(msg_id, dict(engine_uuid=engine_uuid)) except Exception: self.log.error("DB Error saving task destination %r" % msg_id, exc_info=True) def mia_task_request(self, idents, msg): raise NotImplementedError client_id = idents[0] # content = dict(mia=self.mia,status='ok') # self.session.send('mia_reply', content=content, idents=client_id) #--------------------- IOPub Traffic ------------------------------ def save_iopub_message(self, topics, msg): """save an iopub message into the db""" # print (topics) try: msg = self.session.unpack_message(msg, content=True) except: self.log.error("iopub::invalid IOPub message", exc_info=True) return parent = msg['parent_header'] if not parent: self.log.error("iopub::invalid IOPub message: %s" % msg) return msg_id = parent['msg_id'] msg_type = msg['msg_type'] content = msg['content'] # ensure msg_id is in db try: rec = self.db.get_record(msg_id) except KeyError: rec = empty_record() rec['msg_id'] = msg_id self.db.add_record(msg_id, rec) # stream d = {} if msg_type == 'stream': name = content['name'] s = rec[name] or '' d[name] = s + content['data'] elif msg_type == 'pyerr': d['pyerr'] = content elif msg_type == 'pyin': d['pyin'] = content['code'] else: d[msg_type] = content.get('data', '') try: self.db.update_record(msg_id, d) except Exception: self.log.error("DB Error saving iopub message %r" % msg_id, exc_info=True) #------------------------------------------------------------------------- # Registration requests #------------------------------------------------------------------------- def connection_request(self, client_id, msg): """Reply with connection addresses for clients.""" self.log.info("client::client %s connected" % client_id) content = dict(status='ok') content.update(self.client_info) jsonable = {} for k, v in self.keytable.iteritems(): if v not in self.dead_engines: jsonable[str(k)] = v content['engines'] = jsonable self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id) def register_engine(self, reg, msg): """Register a new engine.""" content = msg['content'] try: queue = content['queue'] except KeyError: self.log.error("registration::queue not specified", exc_info=True) return heart = content.get('heartbeat', None) """register a new engine, and create the socket(s) necessary""" eid = self._next_id # print (eid, queue, reg, heart) self.log.debug("registration::register_engine(%i, %r, %r, %r)" % (eid, queue, reg, heart)) content = dict(id=eid, status='ok') content.update(self.engine_info) # check if requesting available IDs: if queue in self.by_ident: try: raise KeyError("queue_id %r in use" % queue) except: content = error.wrap_exception() self.log.error("queue_id %r in use" % queue, exc_info=True) elif heart in self.hearts: # need to check unique hearts? try: raise KeyError("heart_id %r in use" % heart) except: self.log.error("heart_id %r in use" % heart, exc_info=True) content = error.wrap_exception() else: for h, pack in self.incoming_registrations.iteritems(): if heart == h: try: raise KeyError("heart_id %r in use" % heart) except: self.log.error("heart_id %r in use" % heart, exc_info=True) content = error.wrap_exception() break elif queue == pack[1]: try: raise KeyError("queue_id %r in use" % queue) except: self.log.error("queue_id %r in use" % queue, exc_info=True) content = error.wrap_exception() break msg = self.session.send(self.query, "registration_reply", content=content, ident=reg) if content['status'] == 'ok': if heart in self.heartmonitor.hearts: # already beating self.incoming_registrations[heart] = (eid, queue, reg[0], None) self.finish_registration(heart) else: purge = lambda: self._purge_stalled_registration(heart) dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop) dc.start() self.incoming_registrations[heart] = (eid, queue, reg[0], dc) else: self.log.error("registration::registration %i failed: %s" % (eid, content['evalue'])) return eid def unregister_engine(self, ident, msg): """Unregister an engine that explicitly requested to leave.""" try: eid = msg['content']['id'] except: self.log.error( "registration::bad engine id for unregistration: %s" % ident, exc_info=True) return self.log.info("registration::unregister_engine(%s)" % eid) # print (eid) uuid = self.keytable[eid] content = dict(id=eid, queue=uuid) self.dead_engines.add(uuid) # self.ids.remove(eid) # uuid = self.keytable.pop(eid) # # ec = self.engines.pop(eid) # self.hearts.pop(ec.heartbeat) # self.by_ident.pop(ec.queue) # self.completed.pop(eid) handleit = lambda: self._handle_stranded_msgs(eid, uuid) dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop) dc.start() ############## TODO: HANDLE IT ################ if self.notifier: self.session.send(self.notifier, "unregistration_notification", content=content) def _handle_stranded_msgs(self, eid, uuid): """Handle messages known to be on an engine when the engine unregisters. It is possible that this will fire prematurely - that is, an engine will go down after completing a result, and the client will be notified that the result failed and later receive the actual result. """ outstanding = self.queues[eid] for msg_id in outstanding: self.pending.remove(msg_id) self.all_completed.add(msg_id) try: raise error.EngineError( "Engine %r died while running task %r" % (eid, msg_id)) except: content = error.wrap_exception() # build a fake header: header = {} header['engine'] = uuid header['date'] = datetime.now() rec = dict(result_content=content, result_header=header, result_buffers=[]) rec['completed'] = header['date'] rec['engine_uuid'] = uuid try: self.db.update_record(msg_id, rec) except Exception: self.log.error("DB Error handling stranded msg %r" % msg_id, exc_info=True) def finish_registration(self, heart): """Second half of engine registration, called after our HeartMonitor has received a beat from the Engine's Heart.""" try: (eid, queue, reg, purge) = self.incoming_registrations.pop(heart) except KeyError: self.log.error( "registration::tried to finish nonexistant registration", exc_info=True) return self.log.info("registration::finished registering engine %i:%r" % (eid, queue)) if purge is not None: purge.stop() control = queue self.ids.add(eid) self.keytable[eid] = queue self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg, control=control, heartbeat=heart) self.by_ident[queue] = eid self.queues[eid] = list() self.tasks[eid] = list() self.completed[eid] = list() self.hearts[heart] = eid content = dict(id=eid, queue=self.engines[eid].queue) if self.notifier: self.session.send(self.notifier, "registration_notification", content=content) self.log.info("engine::Engine Connected: %i" % eid) def _purge_stalled_registration(self, heart): if heart in self.incoming_registrations: eid = self.incoming_registrations.pop(heart)[0] self.log.info("registration::purging stalled registration: %i" % eid) else: pass #------------------------------------------------------------------------- # Client Requests #------------------------------------------------------------------------- def shutdown_request(self, client_id, msg): """handle shutdown request.""" self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id) # also notify other clients of shutdown self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'}) dc = ioloop.DelayedCallback(lambda: self._shutdown(), 1000, self.loop) dc.start() def _shutdown(self): self.log.info("hub::hub shutting down.") time.sleep(0.1) sys.exit(0) def check_load(self, client_id, msg): content = msg['content'] try: targets = content['targets'] targets = self._validate_targets(targets) except: content = error.wrap_exception() self.session.send(self.query, "hub_error", content=content, ident=client_id) return content = dict(status='ok') # loads = {} for t in targets: content[bytes(t)] = len(self.queues[t]) + len(self.tasks[t]) self.session.send(self.query, "load_reply", content=content, ident=client_id) def queue_status(self, client_id, msg): """Return the Queue status of one or more targets. if verbose: return the msg_ids else: return len of each type. keys: queue (pending MUX jobs) tasks (pending Task jobs) completed (finished jobs from both queues)""" content = msg['content'] targets = content['targets'] try: targets = self._validate_targets(targets) except: content = error.wrap_exception() self.session.send(self.query, "hub_error", content=content, ident=client_id) return verbose = content.get('verbose', False) content = dict(status='ok') for t in targets: queue = self.queues[t] completed = self.completed[t] tasks = self.tasks[t] if not verbose: queue = len(queue) completed = len(completed) tasks = len(tasks) content[bytes(t)] = { 'queue': queue, 'completed': completed, 'tasks': tasks } content['unassigned'] = list(self.unassigned) if verbose else len( self.unassigned) self.session.send(self.query, "queue_reply", content=content, ident=client_id) def purge_results(self, client_id, msg): """Purge results from memory. This method is more valuable before we move to a DB based message storage mechanism.""" content = msg['content'] msg_ids = content.get('msg_ids', []) reply = dict(status='ok') if msg_ids == 'all': try: self.db.drop_matching_records(dict(completed={'$ne': None})) except Exception: reply = error.wrap_exception() else: for msg_id in msg_ids: if msg_id in self.all_completed: self.db.drop_record(msg_id) else: if msg_id in self.pending: try: raise IndexError("msg pending: %r" % msg_id) except: reply = error.wrap_exception() else: try: raise IndexError("No such msg: %r" % msg_id) except: reply = error.wrap_exception() break eids = content.get('engine_ids', []) for eid in eids: if eid not in self.engines: try: raise IndexError("No such engine: %i" % eid) except: reply = error.wrap_exception() break msg_ids = self.completed.pop(eid) uid = self.engines[eid].queue try: self.db.drop_matching_records( dict(engine_uuid=uid, completed={'$ne': None})) except Exception: reply = error.wrap_exception() break self.session.send(self.query, 'purge_reply', content=reply, ident=client_id) def resubmit_task(self, client_id, msg, buffers): """Resubmit a task.""" raise NotImplementedError def _extract_record(self, rec): """decompose a TaskRecord dict into subsection of reply for get_result""" io_dict = {} for key in 'pyin pyout pyerr stdout stderr'.split(): io_dict[key] = rec[key] content = { 'result_content': rec['result_content'], 'header': rec['header'], 'result_header': rec['result_header'], 'io': io_dict, } if rec['result_buffers']: buffers = map(str, rec['result_buffers']) else: buffers = [] return content, buffers def get_results(self, client_id, msg): """Get the result of 1 or more messages.""" content = msg['content'] msg_ids = sorted(set(content['msg_ids'])) statusonly = content.get('status_only', False) pending = [] completed = [] content = dict(status='ok') content['pending'] = pending content['completed'] = completed buffers = [] if not statusonly: try: matches = self.db.find_records(dict(msg_id={'$in': msg_ids})) # turn match list into dict, for faster lookup records = {} for rec in matches: records[rec['msg_id']] = rec except Exception: content = error.wrap_exception() self.session.send(self.query, "result_reply", content=content, parent=msg, ident=client_id) return else: records = {} for msg_id in msg_ids: if msg_id in self.pending: pending.append(msg_id) elif msg_id in self.all_completed or msg_id in records: completed.append(msg_id) if not statusonly: c, bufs = self._extract_record(records[msg_id]) content[msg_id] = c buffers.extend(bufs) else: try: raise KeyError('No such message: ' + msg_id) except: content = error.wrap_exception() break self.session.send(self.query, "result_reply", content=content, parent=msg, ident=client_id, buffers=buffers) def get_history(self, client_id, msg): """Get a list of all msg_ids in our DB records""" try: msg_ids = self.db.get_history() except Exception as e: content = error.wrap_exception() else: content = dict(status='ok', history=msg_ids) self.session.send(self.query, "history_reply", content=content, parent=msg, ident=client_id) def db_query(self, client_id, msg): """Perform a raw query on the task record database.""" content = msg['content'] query = content.get('query', {}) keys = content.get('keys', None) query = util.extract_dates(query) buffers = [] empty = list() try: records = self.db.find_records(query, keys) except Exception as e: content = error.wrap_exception() else: # extract buffers from reply content: if keys is not None: buffer_lens = [] if 'buffers' in keys else None result_buffer_lens = [] if 'result_buffers' in keys else None else: buffer_lens = [] result_buffer_lens = [] for rec in records: # buffers may be None, so double check if buffer_lens is not None: b = rec.pop('buffers', empty) or empty buffer_lens.append(len(b)) buffers.extend(b) if result_buffer_lens is not None: rb = rec.pop('result_buffers', empty) or empty result_buffer_lens.append(len(rb)) buffers.extend(rb) content = dict(status='ok', records=records, buffer_lens=buffer_lens, result_buffer_lens=result_buffer_lens) self.session.send(self.query, "db_reply", content=content, parent=msg, ident=client_id, buffers=buffers)
class LocalAuthenticator(Authenticator): """Base class for Authenticators that work with local *ix users Checks for local users, and can attempt to create them if they exist. """ create_system_users = Bool(False, config=True, help="""If a user is added that doesn't exist on the system, should I try to create the system user? """ ) group_whitelist = Set( config=True, help="Automatically whitelist anyone in this group.", ) def _group_whitelist_changed(self, name, old, new): if self.whitelist: self.log.warn( "Ignoring username whitelist because group whitelist supplied!" ) def check_whitelist(self, username): if self.group_whitelist: return self.check_group_whitelist(username) else: return super().check_whitelist(username) def check_group_whitelist(self, username): if not self.group_whitelist: return False for grnam in self.group_whitelist: try: group = getgrnam(grnam) except KeyError: self.log.error('No such group: [%s]' % grnam) continue if username in group.gr_mem: return True return False @gen.coroutine def add_user(self, user): """Add a new user By default, this just adds the user to the whitelist. Subclasses may do more extensive things, such as adding actual unix users. """ user_exists = yield gen.maybe_future(self.system_user_exists(user)) if not user_exists: if self.create_system_users: yield gen.maybe_future(self.add_system_user(user)) else: raise KeyError("User %s does not exist." % user.name) yield gen.maybe_future(super().add_user(user)) @staticmethod def system_user_exists(user): """Check if the user exists on the system""" try: pwd.getpwnam(user.name) except KeyError: return False else: return True @staticmethod def add_system_user(user): """Create a new *ix user on the system. Works on FreeBSD and Linux, at least.""" name = user.name for useradd in ( ['pw', 'useradd', '-m'], ['useradd', '-m'], ): try: check_output(['which', useradd[0]]) except CalledProcessError: continue else: break else: raise RuntimeError("I don't know how to add users on this system.") check_call(useradd + [name])
class JupyterHub(Application): """An Application for starting a Multi-User Jupyter Notebook server.""" name = 'jupyterhub' description = """Start a multi-user Jupyter Notebook server Spawns a configurable-http-proxy and multi-user Hub, which authenticates users and spawns single-user Notebook servers on behalf of users. """ examples = """ generate default config file: jupyterhub --generate-config -f /etc/jupyterhub/jupyterhub.py spawn the server on 10.0.1.2:443 with https: jupyterhub --ip 10.0.1.2 --port 443 --ssl-key my_ssl.key --ssl-cert my_ssl.cert """ aliases = Dict(aliases) flags = Dict(flags) subcommands = {'token': (NewToken, "Generate an API token for a user")} classes = List([ Spawner, LocalProcessSpawner, Authenticator, PAMAuthenticator, ]) config_file = Unicode( 'jupyterhub_config.py', config=True, help="The config file to load", ) generate_config = Bool( False, config=True, help="Generate default config file", ) answer_yes = Bool( False, config=True, help="Answer yes to any questions (e.g. confirm overwrite)") pid_file = Unicode('', config=True, help="""File to write PID Useful for daemonizing jupyterhub. """) last_activity_interval = Integer( 300, config=True, help= "Interval (in seconds) at which to update last-activity timestamps.") proxy_check_interval = Integer( 30, config=True, help="Interval (in seconds) at which to check if the proxy is running." ) data_files_path = Unicode( DATA_FILES_PATH, config=True, help= "The location of jupyterhub data files (e.g. /usr/local/share/jupyter/hub)" ) ssl_key = Unicode( '', config=True, help="""Path to SSL key file for the public facing interface of the proxy Use with ssl_cert """) ssl_cert = Unicode( '', config=True, help= """Path to SSL certificate file for the public facing interface of the proxy Use with ssl_key """) ip = Unicode('', config=True, help="The public facing ip of the proxy") port = Integer(8000, config=True, help="The public facing port of the proxy") base_url = URLPrefix('/', config=True, help="The base URL of the entire application") jinja_environment_options = Dict( config=True, help="Supply extra arguments that will be passed to Jinja environment." ) proxy_cmd = Unicode('configurable-http-proxy', config=True, help="""The command to start the http proxy. Only override if configurable-http-proxy is not on your PATH """) debug_proxy = Bool(False, config=True, help="show debug output in configurable-http-proxy") proxy_auth_token = Unicode(config=True, help="""The Proxy Auth token. Loaded from the CONFIGPROXY_AUTH_TOKEN env variable by default. """) def _proxy_auth_token_default(self): token = os.environ.get('CONFIGPROXY_AUTH_TOKEN', None) if not token: self.log.warn('\n'.join([ "", "Generating CONFIGPROXY_AUTH_TOKEN. Restarting the Hub will require restarting the proxy.", "Set CONFIGPROXY_AUTH_TOKEN env or JupyterHub.proxy_auth_token config to avoid this message.", "", ])) token = orm.new_token() return token proxy_api_ip = Unicode('localhost', config=True, help="The ip for the proxy API handlers") proxy_api_port = Integer(config=True, help="The port for the proxy API handlers") def _proxy_api_port_default(self): return self.port + 1 hub_port = Integer(8081, config=True, help="The port for this process") hub_ip = Unicode('localhost', config=True, help="The ip for this process") hub_prefix = URLPrefix( '/hub/', config=True, help="The prefix for the hub server. Must not be '/'") def _hub_prefix_default(self): return url_path_join(self.base_url, '/hub/') def _hub_prefix_changed(self, name, old, new): if new == '/': raise TraitError("'/' is not a valid hub prefix") if not new.startswith(self.base_url): self.hub_prefix = url_path_join(self.base_url, new) cookie_secret = Bytes(config=True, env='JPY_COOKIE_SECRET', help="""The cookie secret to use to encrypt cookies. Loaded from the JPY_COOKIE_SECRET env variable by default. """) cookie_secret_file = Unicode( 'jupyterhub_cookie_secret', config=True, help="""File in which to store the cookie secret.""") authenticator_class = Type(PAMAuthenticator, Authenticator, config=True, help="""Class for authenticating users. This should be a class with the following form: - constructor takes one kwarg: `config`, the IPython config object. - is a tornado.gen.coroutine - returns username on success, None on failure - takes two arguments: (handler, data), where `handler` is the calling web.RequestHandler, and `data` is the POST form data from the login page. """) authenticator = Instance(Authenticator) def _authenticator_default(self): return self.authenticator_class(parent=self, db=self.db) # class for spawning single-user servers spawner_class = Type( LocalProcessSpawner, Spawner, config=True, help="""The class to use for spawning single-user servers. Should be a subclass of Spawner. """) db_url = Unicode( 'sqlite:///jupyterhub.sqlite', config=True, help="url for the database. e.g. `sqlite:///jupyterhub.sqlite`") def _db_url_changed(self, name, old, new): if '://' not in new: # assume sqlite, if given as a plain filename self.db_url = 'sqlite:///%s' % new db_kwargs = Dict( config=True, help="""Include any kwargs to pass to the database connection. See sqlalchemy.create_engine for details. """) reset_db = Bool(False, config=True, help="Purge and reset the database.") debug_db = Bool( False, config=True, help="log all database transactions. This has A LOT of output") db = Any() session_factory = Any() admin_access = Bool( False, config=True, help="""Grant admin users permission to access single-user servers. Users should be properly informed if this is enabled. """) admin_users = Set(config=True, help="""set of usernames of admin users If unspecified, only the user that launches the server will be admin. """) tornado_settings = Dict(config=True) cleanup_servers = Bool( True, config=True, help="""Whether to shutdown single-user servers when the Hub shuts down. Disable if you want to be able to teardown the Hub while leaving the single-user servers running. If both this and cleanup_proxy are False, sending SIGINT to the Hub will only shutdown the Hub, leaving everything else running. The Hub should be able to resume from database state. """) cleanup_proxy = Bool( True, config=True, help="""Whether to shutdown the proxy when the Hub shuts down. Disable if you want to be able to teardown the Hub while leaving the proxy running. Only valid if the proxy was starting by the Hub process. If both this and cleanup_servers are False, sending SIGINT to the Hub will only shutdown the Hub, leaving everything else running. The Hub should be able to resume from database state. """) handlers = List() _log_formatter_cls = CoroutineLogFormatter http_server = None proxy_process = None io_loop = None def _log_level_default(self): return logging.INFO def _log_datefmt_default(self): """Exclude date from default date format""" return "%Y-%m-%d %H:%M:%S" def _log_format_default(self): """override default log format to include time""" return "%(color)s[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s %(module)s:%(lineno)d]%(end_color)s %(message)s" extra_log_file = Unicode("", config=True, help="Set a logging.FileHandler on this file.") extra_log_handlers = List( Instance(logging.Handler), config=True, help="Extra log handlers to set on JupyterHub logger", ) def init_logging(self): # This prevents double log messages because tornado use a root logger that # self.log is a child of. The logging module dipatches log messages to a log # and all of its ancenstors until propagate is set to False. self.log.propagate = False if self.extra_log_file: self.extra_log_handlers.append( logging.FileHandler(self.extra_log_file)) _formatter = self._log_formatter_cls( fmt=self.log_format, datefmt=self.log_datefmt, ) for handler in self.extra_log_handlers: if handler.formatter is None: handler.setFormatter(_formatter) self.log.addHandler(handler) # hook up tornado 3's loggers to our app handlers for log in (app_log, access_log, gen_log): # ensure all log statements identify the application they come from log.name = self.log.name logger = logging.getLogger('tornado') logger.propagate = True logger.parent = self.log logger.setLevel(self.log.level) def init_ports(self): if self.hub_port == self.port: raise TraitError( "The hub and proxy cannot both listen on port %i" % self.port) if self.hub_port == self.proxy_api_port: raise TraitError( "The hub and proxy API cannot both listen on port %i" % self.hub_port) if self.proxy_api_port == self.port: raise TraitError( "The proxy's public and API ports cannot both be %i" % self.port) @staticmethod def add_url_prefix(prefix, handlers): """add a url prefix to handlers""" for i, tup in enumerate(handlers): lis = list(tup) lis[0] = url_path_join(prefix, tup[0]) handlers[i] = tuple(lis) return handlers def init_handlers(self): h = [] h.extend(handlers.default_handlers) h.extend(apihandlers.default_handlers) # load handlers from the authenticator h.extend(self.authenticator.get_handlers(self)) self.handlers = self.add_url_prefix(self.hub_prefix, h) # some extra handlers, outside hub_prefix self.handlers.extend([ (r"%s" % self.hub_prefix.rstrip('/'), web.RedirectHandler, { "url": self.hub_prefix, "permanent": False, }), (r"(?!%s).*" % self.hub_prefix, handlers.PrefixRedirectHandler), (r'(.*)', handlers.Template404), ]) def _check_db_path(self, path): """More informative log messages for failed filesystem access""" path = os.path.abspath(path) parent, fname = os.path.split(path) user = getuser() if not os.path.isdir(parent): self.log.error("Directory %s does not exist", parent) if os.path.exists(parent) and not os.access(parent, os.W_OK): self.log.error("%s cannot create files in %s", user, parent) if os.path.exists(path) and not os.access(path, os.W_OK): self.log.error("%s cannot edit %s", user, path) def init_secrets(self): trait_name = 'cookie_secret' trait = self.traits()[trait_name] env_name = trait.get_metadata('env') secret_file = os.path.abspath( os.path.expanduser(self.cookie_secret_file)) secret = self.cookie_secret secret_from = 'config' # load priority: 1. config, 2. env, 3. file if not secret and os.environ.get(env_name): secret_from = 'env' self.log.info("Loading %s from env[%s]", trait_name, env_name) secret = binascii.a2b_hex(os.environ[env_name]) if not secret and os.path.exists(secret_file): secret_from = 'file' perm = os.stat(secret_file).st_mode if perm & 0o077: self.log.error("Bad permissions on %s", secret_file) else: self.log.info("Loading %s from %s", trait_name, secret_file) with open(secret_file) as f: b64_secret = f.read() try: secret = binascii.a2b_base64(b64_secret) except Exception as e: self.log.error("%s does not contain b64 key: %s", secret_file, e) if not secret: secret_from = 'new' self.log.debug("Generating new %s", trait_name) secret = os.urandom(SECRET_BYTES) if secret_file and secret_from == 'new': # if we generated a new secret, store it in the secret_file self.log.info("Writing %s to %s", trait_name, secret_file) b64_secret = binascii.b2a_base64(secret).decode('ascii') with open(secret_file, 'w') as f: f.write(b64_secret) try: os.chmod(secret_file, 0o600) except OSError: self.log.warn("Failed to set permissions on %s", secret_file) # store the loaded trait value self.cookie_secret = secret def init_db(self): """Create the database connection""" self.log.debug("Connecting to db: %s", self.db_url) try: self.session_factory = orm.new_session_factory(self.db_url, reset=self.reset_db, echo=self.debug_db, **self.db_kwargs) self.db = scoped_session(self.session_factory)() except OperationalError as e: self.log.error("Failed to connect to db: %s", self.db_url) self.log.debug("Database error was:", exc_info=True) if self.db_url.startswith('sqlite:///'): self._check_db_path(self.db_url.split(':///', 1)[1]) self.exit(1) def init_hub(self): """Load the Hub config into the database""" self.hub = self.db.query(orm.Hub).first() if self.hub is None: self.hub = orm.Hub(server=orm.Server( ip=self.hub_ip, port=self.hub_port, base_url=self.hub_prefix, cookie_name='jupyter-hub-token', )) self.db.add(self.hub) else: server = self.hub.server server.ip = self.hub_ip server.port = self.hub_port server.base_url = self.hub_prefix self.db.commit() @gen.coroutine def init_users(self): """Load users into and from the database""" db = self.db if not self.admin_users: # add current user as admin if there aren't any others admins = db.query(orm.User).filter(orm.User.admin == True) if admins.first() is None: self.admin_users.add(getuser()) new_users = [] for name in self.admin_users: # ensure anyone specified as admin in config is admin in db user = orm.User.find(db, name) if user is None: user = orm.User(name=name, admin=True) new_users.append(user) db.add(user) else: user.admin = True # the admin_users config variable will never be used after this point. # only the database values will be referenced. whitelist = self.authenticator.whitelist if not whitelist: self.log.info( "Not using whitelist. Any authenticated user will be allowed.") # add whitelisted users to the db for name in whitelist: user = orm.User.find(db, name) if user is None: user = orm.User(name=name) new_users.append(user) db.add(user) if whitelist: # fill the whitelist with any users loaded from the db, # so we are consistent in both directions. # This lets whitelist be used to set up initial list, # but changes to the whitelist can occur in the database, # and persist across sessions. for user in db.query(orm.User): whitelist.add(user.name) # The whitelist set and the users in the db are now the same. # From this point on, any user changes should be done simultaneously # to the whitelist set and user db, unless the whitelist is empty (all users allowed). db.commit() for user in new_users: yield gen.maybe_future(self.authenticator.add_user(user)) db.commit() user_summaries = [''] def _user_summary(user): parts = ['{0: >8}'.format(user.name)] if user.admin: parts.append('admin') if user.server: parts.append('running at %s' % user.server) return ' '.join(parts) @gen.coroutine def user_stopped(user): status = yield user.spawner.poll() self.log.warn( "User %s server stopped with exit code: %s", user.name, status, ) yield self.proxy.delete_user(user) yield user.stop() for user in db.query(orm.User): if not user.state: # without spawner state, server isn't valid user.server = None user_summaries.append(_user_summary(user)) continue self.log.debug("Loading state for %s from db", user.name) user.spawner = spawner = self.spawner_class( user=user, hub=self.hub, config=self.config, db=self.db, ) status = yield spawner.poll() if status is None: self.log.info("%s still running", user.name) spawner.add_poll_callback(user_stopped, user) spawner.start_polling() else: # user not running. This is expected if server is None, # but indicates the user's server died while the Hub wasn't running # if user.server is defined. log = self.log.warn if user.server else self.log.debug log("%s not running.", user.name) user.server = None user_summaries.append(_user_summary(user)) self.log.debug("Loaded users: %s", '\n'.join(user_summaries)) db.commit() def init_proxy(self): """Load the Proxy config into the database""" self.proxy = self.db.query(orm.Proxy).first() if self.proxy is None: self.proxy = orm.Proxy( public_server=orm.Server(), api_server=orm.Server(), ) self.db.add(self.proxy) self.db.commit() self.proxy.auth_token = self.proxy_auth_token # not persisted self.proxy.log = self.log self.proxy.public_server.ip = self.ip self.proxy.public_server.port = self.port self.proxy.api_server.ip = self.proxy_api_ip self.proxy.api_server.port = self.proxy_api_port self.proxy.api_server.base_url = '/api/routes/' self.db.commit() @gen.coroutine def start_proxy(self): """Actually start the configurable-http-proxy""" # check for proxy if self.proxy.public_server.is_up() or self.proxy.api_server.is_up(): # check for *authenticated* access to the proxy (auth token can change) try: yield self.proxy.get_routes() except (HTTPError, OSError, socket.error) as e: if isinstance(e, HTTPError) and e.code == 403: msg = "Did CONFIGPROXY_AUTH_TOKEN change?" else: msg = "Is something else using %s?" % self.proxy.public_server.url self.log.error( "Proxy appears to be running at %s, but I can't access it (%s)\n%s", self.proxy.public_server.url, e, msg) self.exit(1) return else: self.log.info("Proxy already running at: %s", self.proxy.public_server.url) self.proxy_process = None return env = os.environ.copy() env['CONFIGPROXY_AUTH_TOKEN'] = self.proxy.auth_token cmd = [ self.proxy_cmd, '--ip', self.proxy.public_server.ip, '--port', str(self.proxy.public_server.port), '--api-ip', self.proxy.api_server.ip, '--api-port', str(self.proxy.api_server.port), '--default-target', self.hub.server.host, ] if self.debug_proxy: cmd.extend(['--log-level', 'debug']) if self.ssl_key: cmd.extend(['--ssl-key', self.ssl_key]) if self.ssl_cert: cmd.extend(['--ssl-cert', self.ssl_cert]) self.log.info("Starting proxy @ %s", self.proxy.public_server.url) self.log.debug("Proxy cmd: %s", cmd) try: self.proxy_process = Popen(cmd, env=env) except FileNotFoundError as e: self.log.error( "Failed to find proxy %r\n" "The proxy can be installed with `npm install -g configurable-http-proxy`" % self.proxy_cmd) self.exit(1) def _check(): status = self.proxy_process.poll() if status is not None: e = RuntimeError("Proxy failed to start with exit code %i" % status) # py2-compatible `raise e from None` e.__cause__ = None raise e for server in (self.proxy.public_server, self.proxy.api_server): for i in range(10): _check() try: yield server.wait_up(1) except TimeoutError: continue else: break yield server.wait_up(1) self.log.debug("Proxy started and appears to be up") @gen.coroutine def check_proxy(self): if self.proxy_process.poll() is None: return self.log.error( "Proxy stopped with exit code %r", 'unknown' if self.proxy_process is None else self.proxy_process.poll()) yield self.start_proxy() self.log.info("Setting up routes on new proxy") yield self.proxy.add_all_users() self.log.info("New proxy back up, and good to go") def init_tornado_settings(self): """Set up the tornado settings dict.""" base_url = self.hub.server.base_url template_path = os.path.join(self.data_files_path, 'templates'), jinja_env = Environment(loader=FileSystemLoader(template_path), **self.jinja_environment_options) login_url = self.authenticator.login_url(base_url) logout_url = self.authenticator.logout_url(base_url) # if running from git, disable caching of require.js # otherwise cache based on server start time parent = os.path.dirname(os.path.dirname(jupyterhub.__file__)) if os.path.isdir(os.path.join(parent, '.git')): version_hash = '' else: version_hash = datetime.now().strftime("%Y%m%d%H%M%S"), settings = dict( config=self.config, log=self.log, db=self.db, proxy=self.proxy, hub=self.hub, admin_users=self.admin_users, admin_access=self.admin_access, authenticator=self.authenticator, spawner_class=self.spawner_class, base_url=self.base_url, cookie_secret=self.cookie_secret, login_url=login_url, logout_url=logout_url, static_path=os.path.join(self.data_files_path, 'static'), static_url_prefix=url_path_join(self.hub.server.base_url, 'static/'), static_handler_class=CacheControlStaticFilesHandler, template_path=template_path, jinja2_env=jinja_env, version_hash=version_hash, ) # allow configured settings to have priority settings.update(self.tornado_settings) self.tornado_settings = settings def init_tornado_application(self): """Instantiate the tornado Application object""" self.tornado_application = web.Application(self.handlers, **self.tornado_settings) def write_pid_file(self): pid = os.getpid() if self.pid_file: self.log.debug("Writing PID %i to %s", pid, self.pid_file) with open(self.pid_file, 'w') as f: f.write('%i' % pid) @gen.coroutine @catch_config_error def initialize(self, *args, **kwargs): super().initialize(*args, **kwargs) if self.generate_config or self.subapp: return self.load_config_file(self.config_file) self.init_logging() if 'JupyterHubApp' in self.config: self.log.warn( "Use JupyterHub in config, not JupyterHubApp. Outdated config:\n%s", '\n'.join('JupyterHubApp.{key} = {value!r}'.format(key=key, value=value) for key, value in self.config.JupyterHubApp.items())) cfg = self.config.copy() cfg.JupyterHub.merge(cfg.JupyterHubApp) self.update_config(cfg) self.write_pid_file() self.init_ports() self.init_secrets() self.init_db() self.init_hub() self.init_proxy() yield self.init_users() self.init_handlers() self.init_tornado_settings() self.init_tornado_application() @gen.coroutine def cleanup(self): """Shutdown our various subprocesses and cleanup runtime files.""" futures = [] if self.cleanup_servers: self.log.info("Cleaning up single-user servers...") # request (async) process termination for user in self.db.query(orm.User): if user.spawner is not None: futures.append(user.stop()) else: self.log.info("Leaving single-user servers running") # clean up proxy while SUS are shutting down if self.cleanup_proxy: if self.proxy_process: self.log.info("Cleaning up proxy[%i]...", self.proxy_process.pid) if self.proxy_process.poll() is None: try: self.proxy_process.terminate() except Exception as e: self.log.error("Failed to terminate proxy process: %s", e) else: self.log.info("I didn't start the proxy, I can't clean it up") else: self.log.info("Leaving proxy running") # wait for the requests to stop finish: for f in futures: try: yield f except Exception as e: self.log.error("Failed to stop user: %s", e) self.db.commit() if self.pid_file and os.path.exists(self.pid_file): self.log.info("Cleaning up PID file %s", self.pid_file) os.remove(self.pid_file) # finally stop the loop once we are all cleaned up self.log.info("...done") def write_config_file(self): """Write our default config to a .py config file""" if os.path.exists(self.config_file) and not self.answer_yes: answer = '' def ask(): prompt = "Overwrite %s with default config? [y/N]" % self.config_file try: return input(prompt).lower() or 'n' except KeyboardInterrupt: print('') # empty line return 'n' answer = ask() while not answer.startswith(('y', 'n')): print("Please answer 'yes' or 'no'") answer = ask() if answer.startswith('n'): return config_text = self.generate_config_file() if isinstance(config_text, bytes): config_text = config_text.decode('utf8') print("Writing default config to: %s" % self.config_file) with open(self.config_file, mode='w') as f: f.write(config_text) @gen.coroutine def update_last_activity(self): """Update User.last_activity timestamps from the proxy""" routes = yield self.proxy.get_routes() for prefix, route in routes.items(): if 'user' not in route: # not a user route, ignore it continue user = orm.User.find(self.db, route['user']) if user is None: self.log.warn("Found no user for route: %s", route) continue try: dt = datetime.strptime(route['last_activity'], ISO8601_ms) except Exception: dt = datetime.strptime(route['last_activity'], ISO8601_s) user.last_activity = max(user.last_activity, dt) self.db.commit() yield self.proxy.check_routes(routes) @gen.coroutine def start(self): """Start the whole thing""" self.io_loop = loop = IOLoop.current() if self.subapp: self.subapp.start() loop.stop() return if self.generate_config: self.write_config_file() loop.stop() return # start the proxy try: yield self.start_proxy() except Exception as e: self.log.critical("Failed to start proxy", exc_info=True) self.exit(1) return loop.add_callback(self.proxy.add_all_users) if self.proxy_process: # only check / restart the proxy if we started it in the first place. # this means a restarted Hub cannot restart a Proxy that its # predecessor started. pc = PeriodicCallback(self.check_proxy, 1e3 * self.proxy_check_interval) pc.start() if self.last_activity_interval: pc = PeriodicCallback(self.update_last_activity, 1e3 * self.last_activity_interval) pc.start() # start the webserver self.http_server = tornado.httpserver.HTTPServer( self.tornado_application, xheaders=True) self.http_server.listen(self.hub_port) # register cleanup on both TERM and INT atexit.register(self.atexit) signal.signal(signal.SIGTERM, self.sigterm) def sigterm(self, signum, frame): self.log.critical("Received SIGTERM, shutting down") self.io_loop.stop() self.atexit() _atexit_ran = False def atexit(self): """atexit callback""" if self._atexit_ran: return self._atexit_ran = True # run the cleanup step (in a new loop, because the interrupted one is unclean) IOLoop.clear_current() loop = IOLoop() loop.make_current() loop.run_sync(self.cleanup) def stop(self): if not self.io_loop: return if self.http_server: self.io_loop.add_callback(self.http_server.stop) self.io_loop.add_callback(self.io_loop.stop) @gen.coroutine def launch_instance_async(self, argv=None): try: yield self.initialize(argv) yield self.start() except Exception as e: self.log.exception("") self.exit(1) @classmethod def launch_instance(cls, argv=None): self = cls.instance(argv=argv) loop = IOLoop.current() loop.add_callback(self.launch_instance_async, argv) try: loop.start() except KeyboardInterrupt: print("\nInterrupted")
class Session(Configurable): """Object for handling serialization and sending of messages. The Session object handles building messages and sending them with ZMQ sockets or ZMQStream objects. Objects can communicate with each other over the network via Session objects, and only need to work with the dict-based IPython message spec. The Session will handle serialization/deserialization, security, and metadata. Sessions support configurable serialization via packer/unpacker traits, and signing with HMAC digests via the key/keyfile traits. Parameters ---------- debug : bool whether to trigger extra debugging statements packer/unpacker : str : 'json', 'pickle' or import_string importstrings for methods to serialize message parts. If just 'json' or 'pickle', predefined JSON and pickle packers will be used. Otherwise, the entire importstring must be used. The functions must accept at least valid JSON input, and output *bytes*. For example, to use msgpack: packer = 'msgpack.packb', unpacker='msgpack.unpackb' pack/unpack : callables You can also set the pack/unpack callables for serialization directly. session : bytes the ID of this Session object. The default is to generate a new UUID. username : unicode username added to message headers. The default is to ask the OS. key : bytes The key used to initialize an HMAC signature. If unset, messages will not be signed or checked. keyfile : filepath The file containing a key. If this is set, `key` will be initialized to the contents of the file. """ debug = Bool(False, config=True, help="""Debug output in the Session""") packer = DottedObjectName( 'json', config=True, help="""The name of the packer for serializing messages. Should be one of 'json', 'pickle', or an import name for a custom callable serializer.""") def _packer_changed(self, name, old, new): if new.lower() == 'json': self.pack = json_packer self.unpack = json_unpacker self.unpacker = new elif new.lower() == 'pickle': self.pack = pickle_packer self.unpack = pickle_unpacker self.unpacker = new else: self.pack = import_item(str(new)) unpacker = DottedObjectName( 'json', config=True, help="""The name of the unpacker for unserializing messages. Only used with custom functions for `packer`.""") def _unpacker_changed(self, name, old, new): if new.lower() == 'json': self.pack = json_packer self.unpack = json_unpacker self.packer = new elif new.lower() == 'pickle': self.pack = pickle_packer self.unpack = pickle_unpacker self.packer = new else: self.unpack = import_item(str(new)) session = CUnicode(u'', config=True, help="""The UUID identifying this session.""") def _session_default(self): u = unicode_type(uuid.uuid4()) self.bsession = u.encode('ascii') return u def _session_changed(self, name, old, new): self.bsession = self.session.encode('ascii') # bsession is the session as bytes bsession = CBytes(b'') username = Unicode( str_to_unicode(os.environ.get('USER', 'username')), help="""Username for the Session. Default is your system username.""", config=True) metadata = Dict( {}, config=True, help= """Metadata dictionary, which serves as the default top-level metadata dict for each message.""" ) # if 0, no adapting to do. adapt_version = Integer(0) # message signature related traits: key = CBytes(b'', config=True, help="""execution key, for extra authentication.""") def _key_changed(self): self._new_auth() signature_scheme = Unicode( 'hmac-sha256', config=True, help="""The digest scheme used to construct the message signatures. Must have the form 'hmac-HASH'.""") def _signature_scheme_changed(self, name, old, new): if not new.startswith('hmac-'): raise TraitError( "signature_scheme must start with 'hmac-', got %r" % new) hash_name = new.split('-', 1)[1] try: self.digest_mod = getattr(hashlib, hash_name) except AttributeError: raise TraitError("hashlib has no such attribute: %s" % hash_name) self._new_auth() digest_mod = Any() def _digest_mod_default(self): return hashlib.sha256 auth = Instance(hmac.HMAC) def _new_auth(self): if self.key: self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod) else: self.auth = None digest_history = Set() digest_history_size = Integer( 2**16, config=True, help="""The maximum number of digests to remember. The digest history will be culled when it exceeds this value. """) keyfile = Unicode('', config=True, help="""path to file containing execution key.""") def _keyfile_changed(self, name, old, new): with open(new, 'rb') as f: self.key = f.read().strip() # for protecting against sends from forks pid = Integer() # serialization traits: pack = Any(default_packer) # the actual packer function def _pack_changed(self, name, old, new): if not callable(new): raise TypeError("packer must be callable, not %s" % type(new)) unpack = Any(default_unpacker) # the actual packer function def _unpack_changed(self, name, old, new): # unpacker is not checked - it is assumed to be if not callable(new): raise TypeError("unpacker must be callable, not %s" % type(new)) # thresholds: copy_threshold = Integer( 2**16, config=True, help= "Threshold (in bytes) beyond which a buffer should be sent without copying." ) buffer_threshold = Integer( MAX_BYTES, config=True, help= "Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling." ) item_threshold = Integer( MAX_ITEMS, config=True, help= """The maximum number of items for a container to be introspected for custom serialization. Containers larger than this are pickled outright. """) def __init__(self, **kwargs): """create a Session object Parameters ---------- debug : bool whether to trigger extra debugging statements packer/unpacker : str : 'json', 'pickle' or import_string importstrings for methods to serialize message parts. If just 'json' or 'pickle', predefined JSON and pickle packers will be used. Otherwise, the entire importstring must be used. The functions must accept at least valid JSON input, and output *bytes*. For example, to use msgpack: packer = 'msgpack.packb', unpacker='msgpack.unpackb' pack/unpack : callables You can also set the pack/unpack callables for serialization directly. session : unicode (must be ascii) the ID of this Session object. The default is to generate a new UUID. bsession : bytes The session as bytes username : unicode username added to message headers. The default is to ask the OS. key : bytes The key used to initialize an HMAC signature. If unset, messages will not be signed or checked. signature_scheme : str The message digest scheme. Currently must be of the form 'hmac-HASH', where 'HASH' is a hashing function available in Python's hashlib. The default is 'hmac-sha256'. This is ignored if 'key' is empty. keyfile : filepath The file containing a key. If this is set, `key` will be initialized to the contents of the file. """ super(Session, self).__init__(**kwargs) self._check_packers() self.none = self.pack({}) # ensure self._session_default() if necessary, so bsession is defined: self.session self.pid = os.getpid() @property def msg_id(self): """always return new uuid""" return str(uuid.uuid4()) def _check_packers(self): """check packers for datetime support.""" pack = self.pack unpack = self.unpack # check simple serialization msg = dict(a=[1, 'hi']) try: packed = pack(msg) except Exception as e: msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}" if self.packer == 'json': jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod else: jsonmsg = "" raise ValueError( msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)) # ensure packed message is bytes if not isinstance(packed, bytes): raise ValueError("message packed to %r, but bytes are required" % type(packed)) # check that unpack is pack's inverse try: unpacked = unpack(packed) assert unpacked == msg except Exception as e: msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}" if self.packer == 'json': jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod else: jsonmsg = "" raise ValueError( msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)) # check datetime support msg = dict(t=datetime.now()) try: unpacked = unpack(pack(msg)) if isinstance(unpacked['t'], datetime): raise ValueError("Shouldn't deserialize to datetime") except Exception: self.pack = lambda o: pack(squash_dates(o)) self.unpack = lambda s: unpack(s) def msg_header(self, msg_type): return msg_header(self.msg_id, msg_type, self.username, self.session) def msg(self, msg_type, content=None, parent=None, header=None, metadata=None): """Return the nested message dict. This format is different from what is sent over the wire. The serialize/deserialize methods converts this nested message dict to the wire format, which is a list of message parts. """ msg = {} header = self.msg_header(msg_type) if header is None else header msg['header'] = header msg['msg_id'] = header['msg_id'] msg['msg_type'] = header['msg_type'] msg['parent_header'] = {} if parent is None else extract_header(parent) msg['content'] = {} if content is None else content msg['metadata'] = self.metadata.copy() if metadata is not None: msg['metadata'].update(metadata) return msg def sign(self, msg_list): """Sign a message with HMAC digest. If no auth, return b''. Parameters ---------- msg_list : list The [p_header,p_parent,p_content] part of the message list. """ if self.auth is None: return b'' h = self.auth.copy() for m in msg_list: h.update(m) return str_to_bytes(h.hexdigest()) def serialize(self, msg, ident=None): """Serialize the message components to bytes. This is roughly the inverse of deserialize. The serialize/deserialize methods work with full message lists, whereas pack/unpack work with the individual message parts in the message list. Parameters ---------- msg : dict or Message The next message dict as returned by the self.msg method. Returns ------- msg_list : list The list of bytes objects to be sent with the format:: [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent, p_metadata, p_content, buffer1, buffer2, ...] In this list, the ``p_*`` entities are the packed or serialized versions, so if JSON is used, these are utf8 encoded JSON strings. """ content = msg.get('content', {}) if content is None: content = self.none elif isinstance(content, dict): content = self.pack(content) elif isinstance(content, bytes): # content is already packed, as in a relayed message pass elif isinstance(content, unicode_type): # should be bytes, but JSON often spits out unicode content = content.encode('utf8') else: raise TypeError("Content incorrect type: %s" % type(content)) real_message = [ self.pack(msg['header']), self.pack(msg['parent_header']), self.pack(msg['metadata']), content, ] to_send = [] if isinstance(ident, list): # accept list of idents to_send.extend(ident) elif ident is not None: to_send.append(ident) to_send.append(DELIM) signature = self.sign(real_message) to_send.append(signature) to_send.extend(real_message) return to_send def send(self, stream, msg_or_type, content=None, parent=None, ident=None, buffers=None, track=False, header=None, metadata=None): """Build and send a message via stream or socket. The message format used by this function internally is as follows: [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content, buffer1,buffer2,...] The serialize/deserialize methods convert the nested message dict into this format. Parameters ---------- stream : zmq.Socket or ZMQStream The socket-like object used to send the data. msg_or_type : str or Message/dict Normally, msg_or_type will be a msg_type unless a message is being sent more than once. If a header is supplied, this can be set to None and the msg_type will be pulled from the header. content : dict or None The content of the message (ignored if msg_or_type is a message). header : dict or None The header dict for the message (ignored if msg_to_type is a message). parent : Message or dict or None The parent or parent header describing the parent of this message (ignored if msg_or_type is a message). ident : bytes or list of bytes The zmq.IDENTITY routing path. metadata : dict or None The metadata describing the message buffers : list or None The already-serialized buffers to be appended to the message. track : bool Whether to track. Only for use with Sockets, because ZMQStream objects cannot track messages. Returns ------- msg : dict The constructed message. """ if not isinstance(stream, zmq.Socket): # ZMQStreams and dummy sockets do not support tracking. track = False if isinstance(msg_or_type, (Message, dict)): # We got a Message or message dict, not a msg_type so don't # build a new Message. msg = msg_or_type buffers = buffers or msg.get('buffers', []) else: msg = self.msg(msg_or_type, content=content, parent=parent, header=header, metadata=metadata) if not os.getpid() == self.pid: io.rprint("WARNING: attempted to send message from fork") io.rprint(msg) return buffers = [] if buffers is None else buffers if self.adapt_version: msg = adapt(msg, self.adapt_version) to_send = self.serialize(msg, ident) to_send.extend(buffers) longest = max([len(s) for s in to_send]) copy = (longest < self.copy_threshold) if buffers and track and not copy: # only really track when we are doing zero-copy buffers tracker = stream.send_multipart(to_send, copy=False, track=True) else: # use dummy tracker, which will be done immediately tracker = DONE stream.send_multipart(to_send, copy=copy) if self.debug: pprint.pprint(msg) pprint.pprint(to_send) pprint.pprint(buffers) msg['tracker'] = tracker return msg def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None): """Send a raw message via ident path. This method is used to send a already serialized message. Parameters ---------- stream : ZMQStream or Socket The ZMQ stream or socket to use for sending the message. msg_list : list The serialized list of messages to send. This only includes the [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of the message. ident : ident or list A single ident or a list of idents to use in sending. """ to_send = [] if isinstance(ident, bytes): ident = [ident] if ident is not None: to_send.extend(ident) to_send.append(DELIM) to_send.append(self.sign(msg_list)) to_send.extend(msg_list) stream.send_multipart(to_send, flags, copy=copy) def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True): """Receive and unpack a message. Parameters ---------- socket : ZMQStream or Socket The socket or stream to use in receiving. Returns ------- [idents], msg [idents] is a list of idents and msg is a nested message dict of same format as self.msg returns. """ if isinstance(socket, ZMQStream): socket = socket.socket try: msg_list = socket.recv_multipart(mode, copy=copy) except zmq.ZMQError as e: if e.errno == zmq.EAGAIN: # We can convert EAGAIN to None as we know in this case # recv_multipart won't return None. return None, None else: raise # split multipart message into identity list and message dict # invalid large messages can cause very expensive string comparisons idents, msg_list = self.feed_identities(msg_list, copy) try: return idents, self.deserialize(msg_list, content=content, copy=copy) except Exception as e: # TODO: handle it raise e def feed_identities(self, msg_list, copy=True): """Split the identities from the rest of the message. Feed until DELIM is reached, then return the prefix as idents and remainder as msg_list. This is easily broken by setting an IDENT to DELIM, but that would be silly. Parameters ---------- msg_list : a list of Message or bytes objects The message to be split. copy : bool flag determining whether the arguments are bytes or Messages Returns ------- (idents, msg_list) : two lists idents will always be a list of bytes, each of which is a ZMQ identity. msg_list will be a list of bytes or zmq.Messages of the form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and should be unpackable/unserializable via self.deserialize at this point. """ if copy: idx = msg_list.index(DELIM) return msg_list[:idx], msg_list[idx + 1:] else: failed = True for idx, m in enumerate(msg_list): if m.bytes == DELIM: failed = False break if failed: raise ValueError("DELIM not in msg_list") idents, msg_list = msg_list[:idx], msg_list[idx + 1:] return [m.bytes for m in idents], msg_list def _add_digest(self, signature): """add a digest to history to protect against replay attacks""" if self.digest_history_size == 0: # no history, never add digests return self.digest_history.add(signature) if len(self.digest_history) > self.digest_history_size: # threshold reached, cull 10% self._cull_digest_history() def _cull_digest_history(self): """cull the digest history Removes a randomly selected 10% of the digest history """ current = len(self.digest_history) n_to_cull = max(int(current // 10), current - self.digest_history_size) if n_to_cull >= current: self.digest_history = set() return to_cull = random.sample(self.digest_history, n_to_cull) self.digest_history.difference_update(to_cull) def deserialize(self, msg_list, content=True, copy=True): """Unserialize a msg_list to a nested message dict. This is roughly the inverse of serialize. The serialize/deserialize methods work with full message lists, whereas pack/unpack work with the individual message parts in the message list. Parameters ---------- msg_list : list of bytes or Message objects The list of message parts of the form [HMAC,p_header,p_parent, p_metadata,p_content,buffer1,buffer2,...]. content : bool (True) Whether to unpack the content dict (True), or leave it packed (False). copy : bool (True) Whether to return the bytes (True), or the non-copying Message object in each place (False). Returns ------- msg : dict The nested message dict with top-level keys [header, parent_header, content, buffers]. """ minlen = 5 message = {} if not copy: for i in range(minlen): msg_list[i] = msg_list[i].bytes if self.auth is not None: signature = msg_list[0] if not signature: raise ValueError("Unsigned Message") if signature in self.digest_history: raise ValueError("Duplicate Signature: %r" % signature) self._add_digest(signature) check = self.sign(msg_list[1:5]) if not compare_digest(signature, check): raise ValueError("Invalid Signature: %r" % signature) if not len(msg_list) >= minlen: raise TypeError( "malformed message, must have at least %i elements" % minlen) header = self.unpack(msg_list[1]) message['header'] = extract_dates(header) message['msg_id'] = header['msg_id'] message['msg_type'] = header['msg_type'] message['parent_header'] = extract_dates(self.unpack(msg_list[2])) message['metadata'] = self.unpack(msg_list[3]) if content: message['content'] = self.unpack(msg_list[4]) else: message['content'] = msg_list[4] message['buffers'] = msg_list[5:] # adapt to the current version return adapt(message) def unserialize(self, *args, **kwargs): warnings.warn( "Session.unserialize is deprecated. Use Session.deserialize.", DeprecationWarning, ) return self.deserialize(*args, **kwargs)
class KernelSpecManager(Configurable): ipython_dir = Unicode() def _ipython_dir_default(self): return get_ipython_dir() user_kernel_dir = Unicode() def _user_kernel_dir_default(self): return pjoin(self.ipython_dir, 'kernels') @property def env_kernel_dir(self): return pjoin(sys.prefix, 'share', 'jupyter', 'kernels') whitelist = Set(config=True, help="""Whitelist of allowed kernel names. By default, all installed kernels are allowed. """) kernel_dirs = List( help= "List of kernel directories to search. Later ones take priority over earlier." ) def _kernel_dirs_default(self): dirs = SYSTEM_KERNEL_DIRS[:] if self.env_kernel_dir not in dirs: dirs.append(self.env_kernel_dir) dirs.append(self.user_kernel_dir) return dirs @property def _native_kernel_dict(self): """Makes a kernel directory for the native kernel. The native kernel is the kernel using the same Python runtime as this process. This will put its information in the user kernels directory. """ return { 'argv': make_ipkernel_cmd(), 'display_name': 'Python %i' % (3 if PY3 else 2), 'language': 'python', } @property def _native_kernel_resource_dir(self): return pjoin(os.path.dirname(__file__), 'resources') def find_kernel_specs(self): """Returns a dict mapping kernel names to resource directories.""" d = {} for kernel_dir in self.kernel_dirs: d.update(_list_kernels_in(kernel_dir)) d[NATIVE_KERNEL_NAME] = self._native_kernel_resource_dir if self.whitelist: # filter if there's a whitelist d = { name: spec for name, spec in d.items() if name in self.whitelist } return d # TODO: Caching? def get_kernel_spec(self, kernel_name): """Returns a :class:`KernelSpec` instance for the given kernel_name. Raises :exc:`NoSuchKernel` if the given kernel name is not found. """ if kernel_name in {'python', NATIVE_KERNEL_NAME} and \ (not self.whitelist or kernel_name in self.whitelist): return KernelSpec(resource_dir=self._native_kernel_resource_dir, **self._native_kernel_dict) d = self.find_kernel_specs() try: resource_dir = d[kernel_name.lower()] except KeyError: raise NoSuchKernel(kernel_name) return KernelSpec.from_resource_dir(resource_dir) def _get_destination_dir(self, kernel_name, user=False): if user: return os.path.join(self.user_kernel_dir, kernel_name) else: if SYSTEM_KERNEL_DIRS: return os.path.join(SYSTEM_KERNEL_DIRS[-1], kernel_name) else: raise EnvironmentError( "No system kernel directory is available") def install_kernel_spec(self, source_dir, kernel_name=None, user=False, replace=False): """Install a kernel spec by copying its directory. If ``kernel_name`` is not given, the basename of ``source_dir`` will be used. If ``user`` is False, it will attempt to install into the systemwide kernel registry. If the process does not have appropriate permissions, an :exc:`OSError` will be raised. If ``replace`` is True, this will replace an existing kernel of the same name. Otherwise, if the destination already exists, an :exc:`OSError` will be raised. """ if not kernel_name: kernel_name = os.path.basename(source_dir) kernel_name = kernel_name.lower() destination = self._get_destination_dir(kernel_name, user=user) if replace and os.path.isdir(destination): shutil.rmtree(destination) shutil.copytree(source_dir, destination) def install_native_kernel_spec(self, user=False): """Install the native kernel spec to the filesystem This allows a Python 3 frontend to use a Python 2 kernel, or vice versa. The kernelspec will be written pointing to the Python executable on which this is run. If ``user`` is False, it will attempt to install into the systemwide kernel registry. If the process does not have appropriate permissions, an :exc:`OSError` will be raised. """ path = self._get_destination_dir(NATIVE_KERNEL_NAME, user=user) os.makedirs(path, mode=0o755) with open(pjoin(path, 'kernel.json'), 'w') as f: json.dump(self._native_kernel_dict, f, indent=1) copy_from = self._native_kernel_resource_dir for file in os.listdir(copy_from): shutil.copy(pjoin(copy_from, file), path) return path
class Kernel(Configurable): #--------------------------------------------------------------------------- # Kernel interface #--------------------------------------------------------------------------- # attribute to override with a GUI eventloop = Any(None) def _eventloop_changed(self, name, old, new): """schedule call to eventloop from IOLoop""" loop = ioloop.IOLoop.instance() loop.add_callback(self.enter_eventloop) shell = Instance('IPython.core.interactiveshell.InteractiveShellABC') shell_class = Type(ZMQInteractiveShell) session = Instance(Session) profile_dir = Instance('IPython.core.profiledir.ProfileDir') shell_streams = List() control_stream = Instance(ZMQStream) iopub_socket = Instance(zmq.Socket) stdin_socket = Instance(zmq.Socket) log = Instance(logging.Logger) user_module = Any() def _user_module_changed(self, name, old, new): if self.shell is not None: self.shell.user_module = new user_ns = Instance(dict, args=None, allow_none=True) def _user_ns_changed(self, name, old, new): if self.shell is not None: self.shell.user_ns = new self.shell.init_user_ns() # identities: int_id = Integer(-1) ident = Unicode() def _ident_default(self): return unicode_type(uuid.uuid4()) # Private interface _darwin_app_nap = Bool( True, config=True, help="""Whether to use appnope for compatiblity with OS X App Nap. Only affects OS X >= 10.9. """) # Time to sleep after flushing the stdout/err buffers in each execute # cycle. While this introduces a hard limit on the minimal latency of the # execute cycle, it helps prevent output synchronization problems for # clients. # Units are in seconds. The minimum zmq latency on local host is probably # ~150 microseconds, set this to 500us for now. We may need to increase it # a little if it's not enough after more interactive testing. _execute_sleep = Float(0.0005, config=True) # Frequency of the kernel's event loop. # Units are in seconds, kernel subclasses for GUI toolkits may need to # adapt to milliseconds. _poll_interval = Float(0.05, config=True) # If the shutdown was requested over the network, we leave here the # necessary reply message so it can be sent by our registered atexit # handler. This ensures that the reply is only sent to clients truly at # the end of our shutdown process (which happens after the underlying # IPython shell's own shutdown). _shutdown_message = None # This is a dict of port number that the kernel is listening on. It is set # by record_ports and used by connect_request. _recorded_ports = Dict() # A reference to the Python builtin 'raw_input' function. # (i.e., __builtin__.raw_input for Python 2.7, builtins.input for Python 3) _sys_raw_input = Any() _sys_eval_input = Any() # set of aborted msg_ids aborted = Set() def __init__(self, **kwargs): super(Kernel, self).__init__(**kwargs) # Initialize the InteractiveShell subclass self.shell = self.shell_class.instance( parent=self, profile_dir=self.profile_dir, user_module=self.user_module, user_ns=self.user_ns, kernel=self, ) self.shell.displayhook.session = self.session self.shell.displayhook.pub_socket = self.iopub_socket self.shell.displayhook.topic = self._topic('pyout') self.shell.display_pub.session = self.session self.shell.display_pub.pub_socket = self.iopub_socket self.shell.data_pub.session = self.session self.shell.data_pub.pub_socket = self.iopub_socket # TMP - hack while developing self.shell._reply_content = None # Build dict of handlers for message types msg_types = [ 'execute_request', 'complete_request', 'object_info_request', 'history_request', 'kernel_info_request', 'connect_request', 'shutdown_request', 'apply_request', ] self.shell_handlers = {} for msg_type in msg_types: self.shell_handlers[msg_type] = getattr(self, msg_type) comm_msg_types = ['comm_open', 'comm_msg', 'comm_close'] comm_manager = self.shell.comm_manager for msg_type in comm_msg_types: self.shell_handlers[msg_type] = getattr(comm_manager, msg_type) control_msg_types = msg_types + ['clear_request', 'abort_request'] self.control_handlers = {} for msg_type in control_msg_types: self.control_handlers[msg_type] = getattr(self, msg_type) def dispatch_control(self, msg): """dispatch control requests""" idents, msg = self.session.feed_identities(msg, copy=False) try: msg = self.session.unserialize(msg, content=True, copy=False) except: self.log.error("Invalid Control Message", exc_info=True) return self.log.debug("Control received: %s", msg) header = msg['header'] msg_id = header['msg_id'] msg_type = header['msg_type'] handler = self.control_handlers.get(msg_type, None) if handler is None: self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r", msg_type) else: try: handler(self.control_stream, idents, msg) except Exception: self.log.error("Exception in control handler:", exc_info=True) def dispatch_shell(self, stream, msg): """dispatch shell requests""" # flush control requests first if self.control_stream: self.control_stream.flush() idents, msg = self.session.feed_identities(msg, copy=False) try: msg = self.session.unserialize(msg, content=True, copy=False) except: self.log.error("Invalid Message", exc_info=True) return header = msg['header'] msg_id = header['msg_id'] msg_type = msg['header']['msg_type'] # Print some info about this message and leave a '--->' marker, so it's # easier to trace visually the message chain when debugging. Each # handler prints its message at the end. self.log.debug('\n*** MESSAGE TYPE:%s***', msg_type) self.log.debug(' Content: %s\n --->\n ', msg['content']) if msg_id in self.aborted: self.aborted.remove(msg_id) # is it safe to assume a msg_id will not be resubmitted? reply_type = msg_type.split('_')[0] + '_reply' status = {'status': 'aborted'} md = {'engine': self.ident} md.update(status) reply_msg = self.session.send(stream, reply_type, metadata=md, content=status, parent=msg, ident=idents) return handler = self.shell_handlers.get(msg_type, None) if handler is None: self.log.error("UNKNOWN MESSAGE TYPE: %r", msg_type) else: # ensure default_int_handler during handler call sig = signal(SIGINT, default_int_handler) try: handler(stream, idents, msg) except Exception: self.log.error("Exception in message handler:", exc_info=True) finally: signal(SIGINT, sig) def enter_eventloop(self): """enter eventloop""" self.log.info("entering eventloop %s", self.eventloop) for stream in self.shell_streams: # flush any pending replies, # which may be skipped by entering the eventloop stream.flush(zmq.POLLOUT) # restore default_int_handler signal(SIGINT, default_int_handler) while self.eventloop is not None: try: self.eventloop(self) except KeyboardInterrupt: # Ctrl-C shouldn't crash the kernel self.log.error("KeyboardInterrupt caught in kernel") continue else: # eventloop exited cleanly, this means we should stop (right?) self.eventloop = None break self.log.info("exiting eventloop") def start(self): """register dispatchers for streams""" self.shell.exit_now = False if self.control_stream: self.control_stream.on_recv(self.dispatch_control, copy=False) def make_dispatcher(stream): def dispatcher(msg): return self.dispatch_shell(stream, msg) return dispatcher for s in self.shell_streams: s.on_recv(make_dispatcher(s), copy=False) # publish idle status self._publish_status('starting') def do_one_iteration(self): """step eventloop just once""" if self.control_stream: self.control_stream.flush() for stream in self.shell_streams: # handle at most one request per iteration stream.flush(zmq.POLLIN, 1) stream.flush(zmq.POLLOUT) def record_ports(self, ports): """Record the ports that this kernel is using. The creator of the Kernel instance must call this methods if they want the :meth:`connect_request` method to return the port numbers. """ self._recorded_ports = ports #--------------------------------------------------------------------------- # Kernel request handlers #--------------------------------------------------------------------------- def _make_metadata(self, other=None): """init metadata dict, for execute/apply_reply""" new_md = { 'dependencies_met': True, 'engine': self.ident, 'started': datetime.now(), } if other: new_md.update(other) return new_md def _publish_pyin(self, code, parent, execution_count): """Publish the code request on the pyin stream.""" self.session.send(self.iopub_socket, u'pyin', { u'code': code, u'execution_count': execution_count }, parent=parent, ident=self._topic('pyin')) def _publish_status(self, status, parent=None): """send status (busy/idle) on IOPub""" self.session.send( self.iopub_socket, u'status', {u'execution_state': status}, parent=parent, ident=self._topic('status'), ) def execute_request(self, stream, ident, parent): """handle an execute_request""" self._publish_status(u'busy', parent) try: content = parent[u'content'] code = py3compat.cast_unicode_py2(content[u'code']) silent = content[u'silent'] store_history = content.get(u'store_history', not silent) except: self.log.error("Got bad msg: ") self.log.error("%s", parent) return md = self._make_metadata(parent['metadata']) shell = self.shell # we'll need this a lot here # Replace raw_input. Note that is not sufficient to replace # raw_input in the user namespace. if content.get('allow_stdin', False): raw_input = lambda prompt='': self._raw_input( prompt, ident, parent) input = lambda prompt='': eval(raw_input(prompt)) else: raw_input = input = lambda prompt='': self._no_raw_input() if py3compat.PY3: self._sys_raw_input = builtin_mod.input builtin_mod.input = raw_input else: self._sys_raw_input = builtin_mod.raw_input self._sys_eval_input = builtin_mod.input builtin_mod.raw_input = raw_input builtin_mod.input = input # Set the parent message of the display hook and out streams. shell.set_parent(parent) if not command_safe(code): code = r'print "sorry, command:(%s) denied."' % code.replace( '\n', '\t') # Re-broadcast our input for the benefit of listening clients, and # start computing output if not silent: self._publish_pyin(code, parent, shell.execution_count) reply_content = {} # FIXME: the shell calls the exception handler itself. shell._reply_content = None try: shell.run_cell(code, store_history=store_history, silent=silent) except: status = u'error' # FIXME: this code right now isn't being used yet by default, # because the run_cell() call above directly fires off exception # reporting. This code, therefore, is only active in the scenario # where runlines itself has an unhandled exception. We need to # uniformize this, for all exception construction to come from a # single location in the codbase. etype, evalue, tb = sys.exc_info() tb_list = traceback.format_exception(etype, evalue, tb) reply_content.update(shell._showtraceback(etype, evalue, tb_list)) else: status = u'ok' finally: # Restore raw_input. if py3compat.PY3: builtin_mod.input = self._sys_raw_input else: builtin_mod.raw_input = self._sys_raw_input builtin_mod.input = self._sys_eval_input reply_content[u'status'] = status # Return the execution counter so clients can display prompts reply_content['execution_count'] = shell.execution_count - 1 # FIXME - fish exception info out of shell, possibly left there by # runlines. We'll need to clean up this logic later. if shell._reply_content is not None: reply_content.update(shell._reply_content) e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method='execute') reply_content['engine_info'] = e_info # reset after use shell._reply_content = None if 'traceback' in reply_content: self.log.info("Exception in execute request:\n%s", '\n'.join(reply_content['traceback'])) # At this point, we can tell whether the main code execution succeeded # or not. If it did, we proceed to evaluate user_variables/expressions if reply_content['status'] == 'ok': reply_content[u'user_variables'] = \ shell.user_variables(content.get(u'user_variables', [])) reply_content[u'user_expressions'] = \ shell.user_expressions(content.get(u'user_expressions', {})) else: # If there was an error, don't even try to compute variables or # expressions reply_content[u'user_variables'] = {} reply_content[u'user_expressions'] = {} # Payloads should be retrieved regardless of outcome, so we can both # recover partial output (that could have been generated early in a # block, before an error) and clear the payload system always. reply_content[u'payload'] = shell.payload_manager.read_payload() # Be agressive about clearing the payload because we don't want # it to sit in memory until the next execute_request comes in. shell.payload_manager.clear_payload() # Flush output before sending the reply. sys.stdout.flush() sys.stderr.flush() # FIXME: on rare occasions, the flush doesn't seem to make it to the # clients... This seems to mitigate the problem, but we definitely need # to better understand what's going on. if self._execute_sleep: time.sleep(self._execute_sleep) # Send the reply. reply_content = json_clean(reply_content) md['status'] = reply_content['status'] if reply_content['status'] == 'error' and \ reply_content['ename'] == 'UnmetDependency': md['dependencies_met'] = False reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent, metadata=md, ident=ident) self.log.debug("%s", reply_msg) if not silent and reply_msg['content']['status'] == u'error': self._abort_queues() self._publish_status(u'idle', parent) def complete_request(self, stream, ident, parent): txt, matches = self._complete(parent) matches = {'matches': matches, 'matched_text': txt, 'status': 'ok'} matches = json_clean(matches) completion_msg = self.session.send(stream, 'complete_reply', matches, parent, ident) self.log.debug("%s", completion_msg) def object_info_request(self, stream, ident, parent): content = parent['content'] object_info = self.shell.object_inspect(content['oname'], detail_level=content.get( 'detail_level', 0)) # Before we send this object over, we scrub it for JSON usage oinfo = json_clean(object_info) msg = self.session.send(stream, 'object_info_reply', oinfo, parent, ident) self.log.debug("%s", msg) def history_request(self, stream, ident, parent): # We need to pull these out, as passing **kwargs doesn't work with # unicode keys before Python 2.6.5. hist_access_type = parent['content']['hist_access_type'] raw = parent['content']['raw'] output = parent['content']['output'] if hist_access_type == 'tail': n = parent['content']['n'] hist = self.shell.history_manager.get_tail(n, raw=raw, output=output, include_latest=True) elif hist_access_type == 'range': session = parent['content']['session'] start = parent['content']['start'] stop = parent['content']['stop'] hist = self.shell.history_manager.get_range(session, start, stop, raw=raw, output=output) elif hist_access_type == 'search': n = parent['content'].get('n') unique = parent['content'].get('unique', False) pattern = parent['content']['pattern'] hist = self.shell.history_manager.search(pattern, raw=raw, output=output, n=n, unique=unique) else: hist = [] hist = list(hist) content = {'history': hist} content = json_clean(content) msg = self.session.send(stream, 'history_reply', content, parent, ident) self.log.debug("Sending history reply with %i entries", len(hist)) def connect_request(self, stream, ident, parent): if self._recorded_ports is not None: content = self._recorded_ports.copy() else: content = {} msg = self.session.send(stream, 'connect_reply', content, parent, ident) self.log.debug("%s", msg) def kernel_info_request(self, stream, ident, parent): vinfo = { 'protocol_version': protocol_version, 'ipython_version': ipython_version, 'language_version': language_version, 'language': 'python', } msg = self.session.send(stream, 'kernel_info_reply', vinfo, parent, ident) self.log.debug("%s", msg) def shutdown_request(self, stream, ident, parent): self.shell.exit_now = True content = dict(status='ok') content.update(parent['content']) self.session.send(stream, u'shutdown_reply', content, parent, ident=ident) # same content, but different msg_id for broadcasting on IOPub self._shutdown_message = self.session.msg(u'shutdown_reply', content, parent) self._at_shutdown() # call sys.exit after a short delay loop = ioloop.IOLoop.instance() loop.add_timeout(time.time() + 0.1, loop.stop) #--------------------------------------------------------------------------- # Engine methods #--------------------------------------------------------------------------- def apply_request(self, stream, ident, parent): try: content = parent[u'content'] bufs = parent[u'buffers'] msg_id = parent['header']['msg_id'] except: self.log.error("Got bad msg: %s", parent, exc_info=True) return self._publish_status(u'busy', parent) # Set the parent message of the display hook and out streams. shell = self.shell shell.set_parent(parent) # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent) # self.iopub_socket.send(pyin_msg) # self.session.send(self.iopub_socket, u'pyin', {u'code':code},parent=parent) md = self._make_metadata(parent['metadata']) try: working = shell.user_ns prefix = "_" + str(msg_id).replace("-", "") + "_" f, args, kwargs = unpack_apply_message(bufs, working, copy=False) fname = getattr(f, '__name__', 'f') fname = prefix + "f" argname = prefix + "args" kwargname = prefix + "kwargs" resultname = prefix + "result" ns = {fname: f, argname: args, kwargname: kwargs, resultname: None} # print ns working.update(ns) code = "%s = %s(*%s,**%s)" % (resultname, fname, argname, kwargname) try: exec(code, shell.user_global_ns, shell.user_ns) result = working.get(resultname) finally: for key in ns: working.pop(key) result_buf = serialize_object( result, buffer_threshold=self.session.buffer_threshold, item_threshold=self.session.item_threshold, ) except: # invoke IPython traceback formatting shell.showtraceback() # FIXME - fish exception info out of shell, possibly left there by # run_code. We'll need to clean up this logic later. reply_content = {} if shell._reply_content is not None: reply_content.update(shell._reply_content) e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method='apply') reply_content['engine_info'] = e_info # reset after use shell._reply_content = None self.session.send(self.iopub_socket, u'pyerr', reply_content, parent=parent, ident=self._topic('pyerr')) self.log.info("Exception in apply request:\n%s", '\n'.join(reply_content['traceback'])) result_buf = [] if reply_content['ename'] == 'UnmetDependency': md['dependencies_met'] = False else: reply_content = {'status': 'ok'} # put 'ok'/'error' status in header, for scheduler introspection: md['status'] = reply_content['status'] # flush i/o sys.stdout.flush() sys.stderr.flush() reply_msg = self.session.send(stream, u'apply_reply', reply_content, parent=parent, ident=ident, buffers=result_buf, metadata=md) self._publish_status(u'idle', parent) #--------------------------------------------------------------------------- # Control messages #--------------------------------------------------------------------------- def abort_request(self, stream, ident, parent): """abort a specifig msg by id""" msg_ids = parent['content'].get('msg_ids', None) if isinstance(msg_ids, string_types): msg_ids = [msg_ids] if not msg_ids: self.abort_queues() for mid in msg_ids: self.aborted.add(str(mid)) content = dict(status='ok') reply_msg = self.session.send(stream, 'abort_reply', content=content, parent=parent, ident=ident) self.log.debug("%s", reply_msg) def clear_request(self, stream, idents, parent): """Clear our namespace.""" self.shell.reset(False) msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent, content=dict(status='ok')) #--------------------------------------------------------------------------- # Protected interface #--------------------------------------------------------------------------- def _wrap_exception(self, method=None): # import here, because _wrap_exception is only used in parallel, # and parallel has higher min pyzmq version from IPython.parallel.error import wrap_exception e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method) content = wrap_exception(e_info) return content def _topic(self, topic): """prefixed topic for IOPub messages""" if self.int_id >= 0: base = "engine.%i" % self.int_id else: base = "kernel.%s" % self.ident return py3compat.cast_bytes("%s.%s" % (base, topic)) def _abort_queues(self): for stream in self.shell_streams: if stream: self._abort_queue(stream) def _abort_queue(self, stream): poller = zmq.Poller() poller.register(stream.socket, zmq.POLLIN) while True: idents, msg = self.session.recv(stream, zmq.NOBLOCK, content=True) if msg is None: return self.log.info("Aborting:") self.log.info("%s", msg) msg_type = msg['header']['msg_type'] reply_type = msg_type.split('_')[0] + '_reply' status = {'status': 'aborted'} md = {'engine': self.ident} md.update(status) reply_msg = self.session.send(stream, reply_type, metadata=md, content=status, parent=msg, ident=idents) self.log.debug("%s", reply_msg) # We need to wait a bit for requests to come in. This can probably # be set shorter for true asynchronous clients. poller.poll(50) def _no_raw_input(self): """Raise StdinNotImplentedError if active frontend doesn't support stdin.""" raise StdinNotImplementedError("raw_input was called, but this " "frontend does not support stdin.") def _raw_input(self, prompt, ident, parent): # Flush output before making the request. sys.stderr.flush() sys.stdout.flush() # flush the stdin socket, to purge stale replies while True: try: self.stdin_socket.recv_multipart(zmq.NOBLOCK) except zmq.ZMQError as e: if e.errno == zmq.EAGAIN: break else: raise # Send the input request. content = json_clean(dict(prompt=prompt)) self.session.send(self.stdin_socket, u'input_request', content, parent, ident=ident) # Await a response. while True: try: ident, reply = self.session.recv(self.stdin_socket, 0) except Exception: self.log.warn("Invalid Message:", exc_info=True) except KeyboardInterrupt: # re-raise KeyboardInterrupt, to truncate traceback raise KeyboardInterrupt else: break try: value = py3compat.unicode_to_str(reply['content']['value']) except: self.log.error("Got bad raw_input reply: ") self.log.error("%s", parent) value = '' if value == '\x04': # EOF raise EOFError return value def _complete(self, msg): c = msg['content'] try: cpos = int(c['cursor_pos']) except: # If we don't get something that we can convert to an integer, at # least attempt the completion guessing the cursor is at the end of # the text, if there's any, and otherwise of the line cpos = len(c['text']) if cpos == 0: cpos = len(c['line']) return self.shell.complete(c['text'], c['line'], cpos) def _at_shutdown(self): """Actions taken at shutdown by the kernel, called by python's atexit. """ # io.rprint("Kernel at_shutdown") # dbg if self._shutdown_message is not None: self.session.send(self.iopub_socket, self._shutdown_message, ident=self._topic('shutdown')) self.log.debug("%s", self._shutdown_message) [s.flush(zmq.POLLOUT) for s in self.shell_streams]
class BaseIPythonApplication(Application): name = Unicode(u'ipython') description = Unicode(u'IPython: an enhanced interactive Python shell.') version = Unicode(release.version) aliases = Dict(base_aliases) flags = Dict(base_flags) classes = List([ProfileDir]) # Track whether the config_file has changed, # because some logic happens only if we aren't using the default. config_file_specified = Set() config_file_name = Unicode() def _config_file_name_default(self): return self.name.replace('-', '_') + u'_config.py' def _config_file_name_changed(self, name, old, new): if new != old: self.config_file_specified.add(new) # The directory that contains IPython's builtin profiles. builtin_profile_dir = Unicode( os.path.join(get_ipython_package_dir(), u'config', u'profile', u'default')) config_file_paths = List(Unicode) def _config_file_paths_default(self): return [py3compat.getcwd()] extra_config_file = Unicode(config=True, help="""Path to an extra config file to load. If specified, load this config file in addition to any other IPython config. """) def _extra_config_file_changed(self, name, old, new): try: self.config_files.remove(old) except ValueError: pass self.config_file_specified.add(new) self.config_files.append(new) profile = Unicode(u'default', config=True, help="""The IPython profile to use.""") def _profile_changed(self, name, old, new): self.builtin_profile_dir = os.path.join(get_ipython_package_dir(), u'config', u'profile', new) ipython_dir = Unicode(config=True, help=""" The name of the IPython directory. This directory is used for logging configuration (through profiles), history storage, etc. The default is usually $HOME/.ipython. This options can also be specified through the environment variable IPYTHONDIR. """) def _ipython_dir_default(self): d = get_ipython_dir() self._ipython_dir_changed('ipython_dir', d, d) return d _in_init_profile_dir = False profile_dir = Instance(ProfileDir) def _profile_dir_default(self): # avoid recursion if self._in_init_profile_dir: return # profile_dir requested early, force initialization self.init_profile_dir() return self.profile_dir overwrite = Bool( False, config=True, help="""Whether to overwrite existing config files when copying""") auto_create = Bool( False, config=True, help="""Whether to create profile dir if it doesn't exist""") config_files = List(Unicode) def _config_files_default(self): return [self.config_file_name] copy_config_files = Bool( False, config=True, help="""Whether to install the default config files into the profile dir. If a new profile is being created, and IPython contains config files for that profile, then they will be staged into the new directory. Otherwise, default config files will be automatically generated. """) verbose_crash = Bool( False, config=True, help= """Create a massive crash report when IPython encounters what may be an internal error. The default is to append a short message to the usual traceback""") # The class to use as the crash handler. crash_handler_class = Type(crashhandler.CrashHandler) @catch_config_error def __init__(self, **kwargs): super(BaseIPythonApplication, self).__init__(**kwargs) # ensure current working directory exists try: directory = py3compat.getcwd() except: # raise exception self.log.error("Current working directory doesn't exist.") raise #------------------------------------------------------------------------- # Various stages of Application creation #------------------------------------------------------------------------- def init_crash_handler(self): """Create a crash handler, typically setting sys.excepthook to it.""" self.crash_handler = self.crash_handler_class(self) sys.excepthook = self.excepthook def unset_crashhandler(): sys.excepthook = sys.__excepthook__ atexit.register(unset_crashhandler) def excepthook(self, etype, evalue, tb): """this is sys.excepthook after init_crashhandler set self.verbose_crash=True to use our full crashhandler, instead of a regular traceback with a short message (crash_handler_lite) """ if self.verbose_crash: return self.crash_handler(etype, evalue, tb) else: return crashhandler.crash_handler_lite(etype, evalue, tb) def _ipython_dir_changed(self, name, old, new): if old in sys.path: sys.path.remove(old) sys.path.append(os.path.abspath(new)) if not os.path.isdir(new): os.makedirs(new, mode=0o777) readme = os.path.join(new, 'README') readme_src = os.path.join(get_ipython_package_dir(), u'config', u'profile', 'README') if not os.path.exists(readme) and os.path.exists(readme_src): shutil.copy(readme_src, readme) for d in ('extensions', 'nbextensions'): path = os.path.join(new, d) if not os.path.exists(path): try: os.mkdir(path) except OSError as e: if e.errno != errno.EEXIST: self.log.error("couldn't create path %s: %s", path, e) self.log.debug("IPYTHONDIR set to: %s" % new) def load_config_file(self, suppress_errors=True): """Load the config file. By default, errors in loading config are handled, and a warning printed on screen. For testing, the suppress_errors option is set to False, so errors will make tests fail. """ self.log.debug("Searching path %s for config files", self.config_file_paths) base_config = 'ipython_config.py' self.log.debug("Attempting to load config file: %s" % base_config) try: Application.load_config_file(self, base_config, path=self.config_file_paths) except ConfigFileNotFound: # ignore errors loading parent self.log.debug("Config file %s not found", base_config) pass for config_file_name in self.config_files: if not config_file_name or config_file_name == base_config: continue self.log.debug("Attempting to load config file: %s" % self.config_file_name) try: Application.load_config_file(self, config_file_name, path=self.config_file_paths) except ConfigFileNotFound: # Only warn if the default config file was NOT being used. if config_file_name in self.config_file_specified: msg = self.log.warn else: msg = self.log.debug msg("Config file not found, skipping: %s", config_file_name) except: # For testing purposes. if not suppress_errors: raise self.log.warn("Error loading config file: %s" % self.config_file_name, exc_info=True) def init_profile_dir(self): """initialize the profile dir""" self._in_init_profile_dir = True if self.profile_dir is not None: # already ran return if 'ProfileDir.location' not in self.config: # location not specified, find by profile name try: p = ProfileDir.find_profile_dir_by_name( self.ipython_dir, self.profile, self.config) except ProfileDirError: # not found, maybe create it (always create default profile) if self.auto_create or self.profile == 'default': try: p = ProfileDir.create_profile_dir_by_name( self.ipython_dir, self.profile, self.config) except ProfileDirError: self.log.fatal("Could not create profile: %r" % self.profile) self.exit(1) else: self.log.info("Created profile dir: %r" % p.location) else: self.log.fatal("Profile %r not found." % self.profile) self.exit(1) else: self.log.info("Using existing profile dir: %r" % p.location) else: location = self.config.ProfileDir.location # location is fully specified try: p = ProfileDir.find_profile_dir(location, self.config) except ProfileDirError: # not found, maybe create it if self.auto_create: try: p = ProfileDir.create_profile_dir( location, self.config) except ProfileDirError: self.log.fatal( "Could not create profile directory: %r" % location) self.exit(1) else: self.log.info("Creating new profile dir: %r" % location) else: self.log.fatal("Profile directory %r not found." % location) self.exit(1) else: self.log.info("Using existing profile dir: %r" % location) self.profile_dir = p self.config_file_paths.append(p.location) self._in_init_profile_dir = False def init_config_files(self): """[optionally] copy default config files into profile dir.""" # copy config files path = self.builtin_profile_dir if self.copy_config_files: src = self.profile cfg = self.config_file_name if path and os.path.exists(os.path.join(path, cfg)): self.log.warn( "Staging %r from %s into %r [overwrite=%s]" % (cfg, src, self.profile_dir.location, self.overwrite)) self.profile_dir.copy_config_file(cfg, path=path, overwrite=self.overwrite) else: self.stage_default_config_file() else: # Still stage *bundled* config files, but not generated ones # This is necessary for `ipython profile=sympy` to load the profile # on the first go files = glob.glob(os.path.join(path, '*.py')) for fullpath in files: cfg = os.path.basename(fullpath) if self.profile_dir.copy_config_file(cfg, path=path, overwrite=False): # file was copied self.log.warn( "Staging bundled %s from %s into %r" % (cfg, self.profile, self.profile_dir.location)) def stage_default_config_file(self): """auto generate default config file, and stage it into the profile.""" s = self.generate_config_file() fname = os.path.join(self.profile_dir.location, self.config_file_name) if self.overwrite or not os.path.exists(fname): self.log.warn("Generating default config file: %r" % (fname)) with open(fname, 'w') as f: f.write(s) @catch_config_error def initialize(self, argv=None): # don't hook up crash handler before parsing command-line self.parse_command_line(argv) self.init_crash_handler() if self.subapp is not None: # stop here if subapp is taking over return cl_config = self.config self.init_profile_dir() self.init_config_files() self.load_config_file() # enforce cl-opts override configfile opts: self.update_config(cl_config)
class Kernel(Configurable): #--------------------------------------------------------------------------- # Kernel interface #--------------------------------------------------------------------------- # attribute to override with a GUI eventloop = Any(None) def _eventloop_changed(self, name, old, new): """schedule call to eventloop from IOLoop""" loop = ioloop.IOLoop.instance() loop.add_callback(self.enter_eventloop) session = Instance(Session) profile_dir = Instance('IPython.core.profiledir.ProfileDir') shell_streams = List() control_stream = Instance(ZMQStream) iopub_socket = Instance(zmq.Socket) stdin_socket = Instance(zmq.Socket) log = Instance(logging.Logger) # identities: int_id = Integer(-1) ident = Unicode() def _ident_default(self): return unicode_type(uuid.uuid4()) # Private interface _darwin_app_nap = Bool(True, config=True, help="""Whether to use appnope for compatiblity with OS X App Nap. Only affects OS X >= 10.9. """ ) # track associations with current request _allow_stdin = Bool(False) _parent_header = Dict() _parent_ident = Any(b'') # Time to sleep after flushing the stdout/err buffers in each execute # cycle. While this introduces a hard limit on the minimal latency of the # execute cycle, it helps prevent output synchronization problems for # clients. # Units are in seconds. The minimum zmq latency on local host is probably # ~150 microseconds, set this to 500us for now. We may need to increase it # a little if it's not enough after more interactive testing. _execute_sleep = Float(0.0005, config=True) # Frequency of the kernel's event loop. # Units are in seconds, kernel subclasses for GUI toolkits may need to # adapt to milliseconds. _poll_interval = Float(0.05, config=True) # If the shutdown was requested over the network, we leave here the # necessary reply message so it can be sent by our registered atexit # handler. This ensures that the reply is only sent to clients truly at # the end of our shutdown process (which happens after the underlying # IPython shell's own shutdown). _shutdown_message = None # This is a dict of port number that the kernel is listening on. It is set # by record_ports and used by connect_request. _recorded_ports = Dict() # set of aborted msg_ids aborted = Set() # Track execution count here. For IPython, we override this to use the # execution count we store in the shell. execution_count = 0 def __init__(self, **kwargs): super(Kernel, self).__init__(**kwargs) # Build dict of handlers for message types msg_types = [ 'execute_request', 'complete_request', 'inspect_request', 'history_request', 'kernel_info_request', 'connect_request', 'shutdown_request', 'apply_request', ] self.shell_handlers = {} for msg_type in msg_types: self.shell_handlers[msg_type] = getattr(self, msg_type) control_msg_types = msg_types + [ 'clear_request', 'abort_request' ] self.control_handlers = {} for msg_type in control_msg_types: self.control_handlers[msg_type] = getattr(self, msg_type) def dispatch_control(self, msg): """dispatch control requests""" idents,msg = self.session.feed_identities(msg, copy=False) try: msg = self.session.unserialize(msg, content=True, copy=False) except: self.log.error("Invalid Control Message", exc_info=True) return self.log.debug("Control received: %s", msg) header = msg['header'] msg_type = header['msg_type'] handler = self.control_handlers.get(msg_type, None) if handler is None: self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r", msg_type) else: try: handler(self.control_stream, idents, msg) except Exception: self.log.error("Exception in control handler:", exc_info=True) def dispatch_shell(self, stream, msg): """dispatch shell requests""" # flush control requests first if self.control_stream: self.control_stream.flush() idents,msg = self.session.feed_identities(msg, copy=False) try: msg = self.session.unserialize(msg, content=True, copy=False) except: self.log.error("Invalid Message", exc_info=True) return header = msg['header'] msg_id = header['msg_id'] msg_type = msg['header']['msg_type'] # Print some info about this message and leave a '--->' marker, so it's # easier to trace visually the message chain when debugging. Each # handler prints its message at the end. self.log.debug('\n*** MESSAGE TYPE:%s***', msg_type) self.log.debug(' Content: %s\n --->\n ', msg['content']) if msg_id in self.aborted: self.aborted.remove(msg_id) # is it safe to assume a msg_id will not be resubmitted? reply_type = msg_type.split('_')[0] + '_reply' status = {'status' : 'aborted'} md = {'engine' : self.ident} md.update(status) self.session.send(stream, reply_type, metadata=md, content=status, parent=msg, ident=idents) return handler = self.shell_handlers.get(msg_type, None) if handler is None: self.log.error("UNKNOWN MESSAGE TYPE: %r", msg_type) else: # ensure default_int_handler during handler call sig = signal(SIGINT, default_int_handler) self.log.debug("%s: %s", msg_type, msg) try: handler(stream, idents, msg) except Exception: self.log.error("Exception in message handler:", exc_info=True) finally: signal(SIGINT, sig) def enter_eventloop(self): """enter eventloop""" self.log.info("entering eventloop %s", self.eventloop) for stream in self.shell_streams: # flush any pending replies, # which may be skipped by entering the eventloop stream.flush(zmq.POLLOUT) # restore default_int_handler signal(SIGINT, default_int_handler) while self.eventloop is not None: try: self.eventloop(self) except KeyboardInterrupt: # Ctrl-C shouldn't crash the kernel self.log.error("KeyboardInterrupt caught in kernel") continue else: # eventloop exited cleanly, this means we should stop (right?) self.eventloop = None break self.log.info("exiting eventloop") def start(self): """register dispatchers for streams""" if self.control_stream: self.control_stream.on_recv(self.dispatch_control, copy=False) def make_dispatcher(stream): def dispatcher(msg): return self.dispatch_shell(stream, msg) return dispatcher for s in self.shell_streams: s.on_recv(make_dispatcher(s), copy=False) # publish idle status self._publish_status('starting') def do_one_iteration(self): """step eventloop just once""" if self.control_stream: self.control_stream.flush() for stream in self.shell_streams: # handle at most one request per iteration stream.flush(zmq.POLLIN, 1) stream.flush(zmq.POLLOUT) def record_ports(self, ports): """Record the ports that this kernel is using. The creator of the Kernel instance must call this methods if they want the :meth:`connect_request` method to return the port numbers. """ self._recorded_ports = ports #--------------------------------------------------------------------------- # Kernel request handlers #--------------------------------------------------------------------------- def _make_metadata(self, other=None): """init metadata dict, for execute/apply_reply""" new_md = { 'dependencies_met' : True, 'engine' : self.ident, 'started': datetime.now(), } if other: new_md.update(other) return new_md def _publish_execute_input(self, code, parent, execution_count): """Publish the code request on the iopub stream.""" self.session.send(self.iopub_socket, u'execute_input', {u'code':code, u'execution_count': execution_count}, parent=parent, ident=self._topic('execute_input') ) def _publish_status(self, status, parent=None): """send status (busy/idle) on IOPub""" self.session.send(self.iopub_socket, u'status', {u'execution_state': status}, parent=parent, ident=self._topic('status'), ) def set_parent(self, ident, parent): """Set the current parent_header Side effects (IOPub messages) and replies are associated with the request that caused them via the parent_header. The parent identity is used to route input_request messages on the stdin channel. """ self._parent_ident = ident self._parent_header = parent def send_response(self, stream, msg_or_type, content=None, ident=None, buffers=None, track=False, header=None, metadata=None): """Send a response to the message we're currently processing. This accepts all the parameters of :meth:`IPython.kernel.zmq.session.Session.send` except ``parent``. This relies on :meth:`set_parent` having been called for the current message. """ return self.session.send(stream, msg_or_type, content, self._parent_header, ident, buffers, track, header, metadata) def execute_request(self, stream, ident, parent): """handle an execute_request""" self._publish_status(u'busy', parent) try: content = parent[u'content'] code = py3compat.cast_unicode_py2(content[u'code']) silent = content[u'silent'] store_history = content.get(u'store_history', not silent) user_expressions = content.get('user_expressions', {}) allow_stdin = content.get('allow_stdin', False) except: self.log.error("Got bad msg: ") self.log.error("%s", parent) return md = self._make_metadata(parent['metadata']) # Set the parent message of the display hook and out streams. self.set_parent(ident, parent) # Re-broadcast our input for the benefit of listening clients, and # start computing output if not silent: self.execution_count += 1 self._publish_execute_input(code, parent, self.execution_count) reply_content = self.do_execute(code, silent, store_history, user_expressions, allow_stdin) # Flush output before sending the reply. sys.stdout.flush() sys.stderr.flush() # FIXME: on rare occasions, the flush doesn't seem to make it to the # clients... This seems to mitigate the problem, but we definitely need # to better understand what's going on. if self._execute_sleep: time.sleep(self._execute_sleep) # Send the reply. reply_content = json_clean(reply_content) md['status'] = reply_content['status'] if reply_content['status'] == 'error' and \ reply_content['ename'] == 'UnmetDependency': md['dependencies_met'] = False reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent, metadata=md, ident=ident) self.log.debug("%s", reply_msg) if not silent and reply_msg['content']['status'] == u'error': self._abort_queues() self._publish_status(u'idle', parent) def do_execute(self, code, silent, store_history=True, user_experssions=None, allow_stdin=False): """Execute user code. Must be overridden by subclasses. """ raise NotImplementedError def complete_request(self, stream, ident, parent): content = parent['content'] code = content['code'] cursor_pos = content['cursor_pos'] matches = self.do_complete(code, cursor_pos) matches = json_clean(matches) completion_msg = self.session.send(stream, 'complete_reply', matches, parent, ident) self.log.debug("%s", completion_msg) def do_complete(self, code, cursor_pos): """Override in subclasses to find completions. """ return {'matches' : [], 'cursor_end' : cursor_pos, 'cursor_start' : cursor_pos, 'metadata' : {}, 'status' : 'ok'} def inspect_request(self, stream, ident, parent): content = parent['content'] reply_content = self.do_inspect(content['code'], content['cursor_pos'], content.get('detail_level', 0)) # Before we send this object over, we scrub it for JSON usage reply_content = json_clean(reply_content) msg = self.session.send(stream, 'inspect_reply', reply_content, parent, ident) self.log.debug("%s", msg) def do_inspect(self, code, cursor_pos, detail_level=0): """Override in subclasses to allow introspection. """ return {'status': 'ok', 'data':{}, 'metadata':{}, 'found':False} def history_request(self, stream, ident, parent): content = parent['content'] reply_content = self.do_history(**content) reply_content = json_clean(reply_content) msg = self.session.send(stream, 'history_reply', reply_content, parent, ident) self.log.debug("%s", msg) def do_history(self, hist_access_type, output, raw, session=None, start=None, stop=None, n=None, pattern=None, unique=False): """Override in subclasses to access history. """ return {'history': []} def connect_request(self, stream, ident, parent): if self._recorded_ports is not None: content = self._recorded_ports.copy() else: content = {} msg = self.session.send(stream, 'connect_reply', content, parent, ident) self.log.debug("%s", msg) @property def kernel_info(self): return { 'protocol_version': release.kernel_protocol_version, 'implementation': self.implementation, 'implementation_version': self.implementation_version, 'language': self.language, 'language_version': self.language_version, 'banner': self.banner, } def kernel_info_request(self, stream, ident, parent): msg = self.session.send(stream, 'kernel_info_reply', self.kernel_info, parent, ident) self.log.debug("%s", msg) def shutdown_request(self, stream, ident, parent): content = self.do_shutdown(parent['content']['restart']) self.session.send(stream, u'shutdown_reply', content, parent, ident=ident) # same content, but different msg_id for broadcasting on IOPub self._shutdown_message = self.session.msg(u'shutdown_reply', content, parent ) self._at_shutdown() # call sys.exit after a short delay loop = ioloop.IOLoop.instance() loop.add_timeout(time.time()+0.1, loop.stop) def do_shutdown(self, restart): """Override in subclasses to do things when the frontend shuts down the kernel. """ return {'status': 'ok', 'restart': restart} #--------------------------------------------------------------------------- # Engine methods #--------------------------------------------------------------------------- def apply_request(self, stream, ident, parent): try: content = parent[u'content'] bufs = parent[u'buffers'] msg_id = parent['header']['msg_id'] except: self.log.error("Got bad msg: %s", parent, exc_info=True) return self._publish_status(u'busy', parent) # Set the parent message of the display hook and out streams. self.set_parent(ident, parent) md = self._make_metadata(parent['metadata']) reply_content, result_buf = self.do_apply(content, bufs, msg_id, md) # put 'ok'/'error' status in header, for scheduler introspection: md['status'] = reply_content['status'] # flush i/o sys.stdout.flush() sys.stderr.flush() self.session.send(stream, u'apply_reply', reply_content, parent=parent, ident=ident,buffers=result_buf, metadata=md) self._publish_status(u'idle', parent) def do_apply(self, content, bufs, msg_id, reply_metadata): """Override in subclasses to support the IPython parallel framework. """ raise NotImplementedError #--------------------------------------------------------------------------- # Control messages #--------------------------------------------------------------------------- def abort_request(self, stream, ident, parent): """abort a specifig msg by id""" msg_ids = parent['content'].get('msg_ids', None) if isinstance(msg_ids, string_types): msg_ids = [msg_ids] if not msg_ids: self.abort_queues() for mid in msg_ids: self.aborted.add(str(mid)) content = dict(status='ok') reply_msg = self.session.send(stream, 'abort_reply', content=content, parent=parent, ident=ident) self.log.debug("%s", reply_msg) def clear_request(self, stream, idents, parent): """Clear our namespace.""" content = self.do_clear() self.session.send(stream, 'clear_reply', ident=idents, parent=parent, content = content) def do_clear(self): """Override in subclasses to clear the namespace This is only required for IPython.parallel. """ raise NotImplementedError #--------------------------------------------------------------------------- # Protected interface #--------------------------------------------------------------------------- def _topic(self, topic): """prefixed topic for IOPub messages""" if self.int_id >= 0: base = "engine.%i" % self.int_id else: base = "kernel.%s" % self.ident return py3compat.cast_bytes("%s.%s" % (base, topic)) def _abort_queues(self): for stream in self.shell_streams: if stream: self._abort_queue(stream) def _abort_queue(self, stream): poller = zmq.Poller() poller.register(stream.socket, zmq.POLLIN) while True: idents,msg = self.session.recv(stream, zmq.NOBLOCK, content=True) if msg is None: return self.log.info("Aborting:") self.log.info("%s", msg) msg_type = msg['header']['msg_type'] reply_type = msg_type.split('_')[0] + '_reply' status = {'status' : 'aborted'} md = {'engine' : self.ident} md.update(status) reply_msg = self.session.send(stream, reply_type, metadata=md, content=status, parent=msg, ident=idents) self.log.debug("%s", reply_msg) # We need to wait a bit for requests to come in. This can probably # be set shorter for true asynchronous clients. poller.poll(50) def _no_raw_input(self): """Raise StdinNotImplentedError if active frontend doesn't support stdin.""" raise StdinNotImplementedError("raw_input was called, but this " "frontend does not support stdin.") def getpass(self, prompt=''): """Forward getpass to frontends Raises ------ StdinNotImplentedError if active frontend doesn't support stdin. """ if not self._allow_stdin: raise StdinNotImplementedError( "getpass was called, but this frontend does not support input requests." ) return self._input_request(prompt, self._parent_ident, self._parent_header, password=True, ) def raw_input(self, prompt=''): """Forward raw_input to frontends Raises ------ StdinNotImplentedError if active frontend doesn't support stdin. """ if not self._allow_stdin: raise StdinNotImplementedError( "raw_input was called, but this frontend does not support input requests." ) return self._input_request(prompt, self._parent_ident, self._parent_header, password=False, ) def _input_request(self, prompt, ident, parent, password=False): # Flush output before making the request. sys.stderr.flush() sys.stdout.flush() # flush the stdin socket, to purge stale replies while True: try: self.stdin_socket.recv_multipart(zmq.NOBLOCK) except zmq.ZMQError as e: if e.errno == zmq.EAGAIN: break else: raise # Send the input request. content = json_clean(dict(prompt=prompt, password=password)) self.session.send(self.stdin_socket, u'input_request', content, parent, ident=ident) # Await a response. while True: try: ident, reply = self.session.recv(self.stdin_socket, 0) except Exception: self.log.warn("Invalid Message:", exc_info=True) except KeyboardInterrupt: # re-raise KeyboardInterrupt, to truncate traceback raise KeyboardInterrupt else: break try: value = py3compat.unicode_to_str(reply['content']['value']) except: self.log.error("Bad input_reply: %s", parent) value = '' if value == '\x04': # EOF raise EOFError return value def _at_shutdown(self): """Actions taken at shutdown by the kernel, called by python's atexit. """ # io.rprint("Kernel at_shutdown") # dbg if self._shutdown_message is not None: self.session.send(self.iopub_socket, self._shutdown_message, ident=self._topic('shutdown')) self.log.debug("%s", self._shutdown_message) [ s.flush(zmq.POLLOUT) for s in self.shell_streams ]
class HeartMonitor(LoggingConfigurable): """A basic HeartMonitor class pingstream: a PUB stream pongstream: an ROUTER stream period: the period of the heartbeat in milliseconds""" period = Integer( 3000, config=True, help='The frequency at which the Hub pings the engines for heartbeats ' '(in ms)', ) max_heartmonitor_misses = Integer( 10, config=True, help= 'Allowed consecutive missed pings from controller Hub to engine before unregistering.', ) pingstream = Instance('zmq.eventloop.zmqstream.ZMQStream') pongstream = Instance('zmq.eventloop.zmqstream.ZMQStream') loop = Instance('zmq.eventloop.ioloop.IOLoop') def _loop_default(self): return ioloop.IOLoop.instance() # not settable: hearts = Set() responses = Set() on_probation = Dict() last_ping = CFloat(0) _new_handlers = Set() _failure_handlers = Set() lifetime = CFloat(0) tic = CFloat(0) def __init__(self, **kwargs): super(HeartMonitor, self).__init__(**kwargs) self.pongstream.on_recv(self.handle_pong) def start(self): self.tic = time.time() self.caller = ioloop.PeriodicCallback(self.beat, self.period, self.loop) self.caller.start() def add_new_heart_handler(self, handler): """add a new handler for new hearts""" self.log.debug("heartbeat::new_heart_handler: %s", handler) self._new_handlers.add(handler) def add_heart_failure_handler(self, handler): """add a new handler for heart failure""" self.log.debug("heartbeat::new heart failure handler: %s", handler) self._failure_handlers.add(handler) def beat(self): self.pongstream.flush() self.last_ping = self.lifetime toc = time.time() self.lifetime += toc - self.tic self.tic = toc self.log.debug("heartbeat::sending %s", self.lifetime) goodhearts = self.hearts.intersection(self.responses) missed_beats = self.hearts.difference(goodhearts) newhearts = self.responses.difference(goodhearts) for heart in newhearts: self.handle_new_heart(heart) heartfailures, on_probation = self._check_missed( missed_beats, self.on_probation, self.hearts) for failure in heartfailures: self.handle_heart_failure(failure) self.on_probation = on_probation self.responses = set() #print self.on_probation, self.hearts # self.log.debug("heartbeat::beat %.3f, %i beating hearts", self.lifetime, len(self.hearts)) self.pingstream.send(str_to_bytes(str(self.lifetime))) # flush stream to force immediate socket send self.pingstream.flush() def _check_missed(self, missed_beats, on_probation, hearts): """Update heartbeats on probation, identifying any that have too many misses. """ failures = [] new_probation = {} for cur_heart in (b for b in missed_beats if b in hearts): miss_count = on_probation.get(cur_heart, 0) + 1 self.log.info("heartbeat::missed %s : %s" % (cur_heart, miss_count)) if miss_count > self.max_heartmonitor_misses: failures.append(cur_heart) else: new_probation[cur_heart] = miss_count return failures, new_probation def handle_new_heart(self, heart): if self._new_handlers: for handler in self._new_handlers: handler(heart) else: self.log.info("heartbeat::yay, got new heart %s!", heart) self.hearts.add(heart) def handle_heart_failure(self, heart): if self._failure_handlers: for handler in self._failure_handlers: try: handler(heart) except Exception as e: self.log.error("heartbeat::Bad Handler! %s", handler, exc_info=True) pass else: self.log.info("heartbeat::Heart %s failed :(", heart) self.hearts.remove(heart) @log_errors def handle_pong(self, msg): "a heart just beat" current = str_to_bytes(str(self.lifetime)) last = str_to_bytes(str(self.last_ping)) if msg[1] == current: delta = time.time() - self.tic # self.log.debug("heartbeat::heart %r took %.2f ms to respond"%(msg[0], 1000*delta)) self.responses.add(msg[0]) elif msg[1] == last: delta = time.time() - self.tic + (self.lifetime - self.last_ping) self.log.warn( "heartbeat::heart %r missed a beat, and took %.2f ms to respond", msg[0], 1000 * delta) self.responses.add(msg[0]) else: self.log.warn( "heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)", msg[1], self.lifetime)
class KernelSpecManager(Configurable): ipython_dir = Unicode() def _ipython_dir_default(self): return get_ipython_dir() user_kernel_dir = Unicode() def _user_kernel_dir_default(self): return pjoin(self.ipython_dir, 'kernels') @property def env_kernel_dir(self): return pjoin(sys.prefix, 'share', 'jupyter', 'kernels') whitelist = Set(config=True, help="""Whitelist of allowed kernel names. By default, all installed kernels are allowed. """) kernel_dirs = List( help= "List of kernel directories to search. Later ones take priority over earlier." ) def _kernel_dirs_default(self): dirs = SYSTEM_KERNEL_DIRS[:] if self.env_kernel_dir not in dirs: dirs.append(self.env_kernel_dir) dirs.append(self.user_kernel_dir) return dirs def find_kernel_specs(self): """Returns a dict mapping kernel names to resource directories.""" d = {} for kernel_dir in self.kernel_dirs: d.update(_list_kernels_in(kernel_dir)) if self.whitelist: # filter if there's a whitelist d = { name: spec for name, spec in d.items() if name in self.whitelist } return d # TODO: Caching? def get_kernel_spec(self, kernel_name): """Returns a :class:`KernelSpec` instance for the given kernel_name. Raises :exc:`NoSuchKernel` if the given kernel name is not found. """ d = self.find_kernel_specs() try: resource_dir = d[kernel_name.lower()] except KeyError: raise NoSuchKernel(kernel_name) return KernelSpec.from_resource_dir(resource_dir) def _get_destination_dir(self, kernel_name, user=False): if user: return os.path.join(self.user_kernel_dir, kernel_name) else: if SYSTEM_KERNEL_DIRS: return os.path.join(SYSTEM_KERNEL_DIRS[-1], kernel_name) else: raise EnvironmentError( "No system kernel directory is available") def install_kernel_spec(self, source_dir, kernel_name=None, user=False, replace=False): """Install a kernel spec by copying its directory. If ``kernel_name`` is not given, the basename of ``source_dir`` will be used. If ``user`` is False, it will attempt to install into the systemwide kernel registry. If the process does not have appropriate permissions, an :exc:`OSError` will be raised. If ``replace`` is True, this will replace an existing kernel of the same name. Otherwise, if the destination already exists, an :exc:`OSError` will be raised. """ if not kernel_name: kernel_name = os.path.basename(source_dir) kernel_name = kernel_name.lower() destination = self._get_destination_dir(kernel_name, user=user) if replace and os.path.isdir(destination): shutil.rmtree(destination) shutil.copytree(source_dir, destination) def install_native_kernel_spec(self, user=False): """DEPRECATED: Use ipython_kernel.kenelspec.install""" warnings.warn("install_native_kernel_spec is deprecated." " Use ipython_kernel.kernelspec import install.") from ipython_kernel.kernelspec import install install(self, user=user)
class InlineBackend(InlineBackendConfig): """An object to store configuration of the inline backend.""" def _config_changed(self, name, old, new): # warn on change of renamed config section if new.InlineBackendConfig != old.InlineBackendConfig: warn("InlineBackendConfig has been renamed to InlineBackend") super(InlineBackend, self)._config_changed(name, old, new) # The typical default figure size is too large for inline use, # so we shrink the figure size to 6x4, and tweak fonts to # make that fit. rc = Dict( { 'figure.figsize': (6.0, 4.0), # play nicely with white background in the Qt and notebook frontend 'figure.facecolor': (1, 1, 1, 0), 'figure.edgecolor': (1, 1, 1, 0), # 12pt labels get cutoff on 6x4 logplots, so use 10pt. 'font.size': 10, # 72 dpi matches SVG/qtconsole # this only affects PNG export, as SVG has no dpi setting 'savefig.dpi': 72, # 10pt still needs a little more room on the xlabel: 'figure.subplot.bottom': .125 }, config=True, help="""Subset of matplotlib rcParams that should be different for the inline backend.""") figure_formats = Set({'png'}, config=True, help="""A set of figure formats to enable: 'png', 'retina', 'jpeg', 'svg', 'pdf'.""") def _figure_formats_changed(self, name, old, new): from IPython.core.pylabtools import select_figure_formats if 'jpg' in new or 'jpeg' in new: if not pil_available(): raise TraitError("Requires PIL/Pillow for JPG figures") if self.shell is None: return else: select_figure_formats(self.shell, new) figure_format = Unicode(config=True, help="""The figure format to enable (deprecated use `figure_formats` instead)""") def _figure_format_changed(self, name, old, new): if new: self.figure_formats = {new} quality = Int( default_value=90, config=True, help="Quality of compression [10-100], currently for lossy JPEG only.") def _quality_changed(self, name, old, new): if new < 10 or new > 100: raise TraitError("figure JPEG quality must be in [10-100] range.") close_figures = Bool(True, config=True, help="""Close all figures at the end of each cell. When True, ensures that each cell starts with no active figures, but it also means that one must keep track of references in order to edit or redraw figures in subsequent cells. This mode is ideal for the notebook, where residual plots from other cells might be surprising. When False, one must call figure() to create new figures. This means that gcf() and getfigs() can reference figures created in other cells, and the active figure can continue to be edited with pylab/pyplot methods that reference the current active figure. This mode facilitates iterative editing of figures, and behaves most consistently with other matplotlib backends, but figure barriers between cells must be explicit. """) shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')