Example #1
class Authenticator(LoggingConfigurable):
    """A class for authentication.
    
    The API is one method, `authenticate`, a tornado gen.coroutine.
    """
    
    db = Any()
    whitelist = Set(config=True,
        help="""Username whitelist.
        
        Use this to restrict which users can login.
        If empty, allow any user to attempt login.
        """
    )
    custom_html = Unicode('')
    
    @gen.coroutine
    def authenticate(self, handler, data):
        """Authenticate a user with login form data.
        
        This must be a tornado gen.coroutine.
        It must return the username on successful authentication,
        and return None on failed authentication.
        """
    
    def add_user(self, user):
        """Add a new user
        
        By default, this just adds the user to the whitelist.
        
        Subclasses may do more extensive things,
        such as adding actual unix users.
        """
        if self.whitelist:
            self.whitelist.add(user.name)
    
    def delete_user(self, user):
        """Triggered when a user is deleted.
        
        Removes the user from the whitelist.
        """
        if user.name in self.whitelist:
            self.whitelist.remove(user.name)
    
    def login_url(self, base_url):
        """Override to register a custom login handler"""
        return url_path_join(base_url, 'login')
    
    def logout_url(self, base_url):
        """Override to register a custom logout handler"""
        return url_path_join(base_url, 'logout')
    
    def get_handlers(self, app):
        """Return any custom handlers the authenticator needs to register
        
        (e.g. for OAuth)
        """
        return [
            ('/login', LoginHandler),
        ]
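
For context, a minimal custom authenticator built on this API might look like the sketch below. DictionaryAuthenticator, its passwords trait, and the Dict import are invented for illustration; only the authenticate() contract comes from the class above.

class DictionaryAuthenticator(Authenticator):
    """Hypothetical authenticator backed by an in-memory dict (illustration only)."""
    passwords = Dict(config=True,
        help="map of username to password -- never store real credentials this way")

    @gen.coroutine
    def authenticate(self, handler, data):
        # data is the login form dict with 'username' and 'password' keys
        if self.passwords.get(data['username']) == data['password']:
            # old-style tornado coroutines return values via gen.Return
            raise gen.Return(data['username'])
        # falling through returns None, signalling failed authentication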
Example #2
class Containers(Configurable):
    lis = List(config=True)

    def _lis_default(self):
        return [-1]

    s = Set(config=True)

    def _s_default(self):
        return {'a'}

    d = Dict(config=True)

    def _d_default(self):
        return {'a': 'b'}
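
A quick sketch of how the dynamic defaults and the config=True flags interact. It assumes the usual Configurable config mechanism; the IPython.config import path is a guess for this era of the codebase.

c = Containers()
print(c.lis)   # [-1]        from _lis_default
print(c.s)     # set(['a'])  from _s_default
print(c.d)     # {'a': 'b'}  from _d_default

from IPython.config import Config
cfg = Config()
cfg.Containers.lis = [1, 2, 3]   # config values override the dynamic default
c2 = Containers(config=cfg)
print(c2.lis)  # [1, 2, 3]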
Example #3
File: hub.py Project: sherjilozair/ipython
class EngineConnector(HasTraits):
    """A simple object for accessing the various zmq connections of an object.
    Attributes are:
    id (int): engine ID
    uuid (str): uuid (unused?)
    queue (str): identity of queue's XREQ socket
    registration (str): identity of registration XREQ socket
    heartbeat (str): identity of heartbeat XREQ socket
    """
    id = Int(0)
    queue = Str()
    control = Str()
    registration = Str()
    heartbeat = Str()
    pending = Set()
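
Constructing one is plain trait assignment (sketch; the identity strings are placeholders):

ec = EngineConnector(id=3, queue='q-abc', registration='r-abc', heartbeat='h-abc')
ec.pending.add('msg-123')   # pending starts out as an empty set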
Example #4
class HeartMonitor(LoggingConfigurable):
    """A basic HeartMonitor class
    pingstream: a PUB stream
    pongstream: an XREP stream
    period: the period of the heartbeat in milliseconds"""

    period = Integer(3000, config=True,
        help='The period (in ms) at which the Hub pings the engines for '
        'heartbeats',
    )

    pingstream = Instance('zmq.eventloop.zmqstream.ZMQStream')
    pongstream = Instance('zmq.eventloop.zmqstream.ZMQStream')
    loop = Instance('zmq.eventloop.ioloop.IOLoop')
    def _loop_default(self):
        return ioloop.IOLoop.instance()

    # not settable:
    hearts = Set()
    responses = Set()
    on_probation = Set()
    last_ping = CFloat(0)
    _new_handlers = Set()
    _failure_handlers = Set()
    lifetime = CFloat(0)
    tic = CFloat(0)

    def __init__(self, **kwargs):
        super(HeartMonitor, self).__init__(**kwargs)

        self.pongstream.on_recv(self.handle_pong)

    def start(self):
        self.tic = time.time()
        self.caller = ioloop.PeriodicCallback(self.beat, self.period, self.loop)
        self.caller.start()

    def add_new_heart_handler(self, handler):
        """add a new handler for new hearts"""
        self.log.debug("heartbeat::new_heart_handler: %s", handler)
        self._new_handlers.add(handler)

    def add_heart_failure_handler(self, handler):
        """add a new handler for heart failure"""
        self.log.debug("heartbeat::new heart failure handler: %s", handler)
        self._failure_handlers.add(handler)

    def beat(self):
        self.pongstream.flush()
        self.last_ping = self.lifetime

        toc = time.time()
        self.lifetime += toc-self.tic
        self.tic = toc
        self.log.debug("heartbeat::sending %s", self.lifetime)
        goodhearts = self.hearts.intersection(self.responses)
        missed_beats = self.hearts.difference(goodhearts)
        heartfailures = self.on_probation.intersection(missed_beats)
        newhearts = self.responses.difference(goodhearts)
        # explicit loops rather than map(): under Python 3 map() is lazy,
        # so the handlers would never run
        for heart in newhearts:
            self.handle_new_heart(heart)
        for heart in heartfailures:
            self.handle_heart_failure(heart)
        self.on_probation = missed_beats.intersection(self.hearts)
        self.responses = set()
        # print self.on_probation, self.hearts
        # self.log.debug("heartbeat::beat %.3f, %i beating hearts", self.lifetime, len(self.hearts))
        self.pingstream.send(asbytes(str(self.lifetime)))
        # flush stream to force immediate socket send
        self.pingstream.flush()

    def handle_new_heart(self, heart):
        if self._new_handlers:
            for handler in self._new_handlers:
                handler(heart)
        else:
            self.log.info("heartbeat::yay, got new heart %s!", heart)
        self.hearts.add(heart)

    def handle_heart_failure(self, heart):
        if self._failure_handlers:
            for handler in self._failure_handlers:
                try:
                    handler(heart)
                except Exception:
                    self.log.error("heartbeat::Bad Handler! %s", handler, exc_info=True)
        else:
            self.log.info("heartbeat::Heart %s failed :(", heart)
        self.hearts.remove(heart)


    @log_errors
    def handle_pong(self, msg):
        "a heart just beat"
        current = asbytes(str(self.lifetime))
        last = asbytes(str(self.last_ping))
        if msg[1] == current:
            delta = time.time()-self.tic
            # self.log.debug("heartbeat::heart %r took %.2f ms to respond"%(msg[0], 1000*delta))
            self.responses.add(msg[0])
        elif msg[1] == last:
            delta = time.time()-self.tic + (self.lifetime-self.last_ping)
            self.log.warn("heartbeat::heart %r missed a beat, and took %.2f ms to respond", msg[0], 1000*delta)
            self.responses.add(msg[0])
        else:
            self.log.warn("heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)", msg[1], self.lifetime)
Example #5
class TaskScheduler(SessionFactory):
    """Python TaskScheduler object.

    This is the simplest object that supports msg_id based
    DAG dependencies. *Only* task msg_ids are checked, not
    msg_ids of jobs submitted via the MUX queue.

    """

    hwm = Integer(1,
                  config=True,
                  help="""specify the High Water Mark (HWM) for the downstream
        socket in the Task scheduler. This is the maximum number
        of allowed outstanding tasks on each engine.
        
        The default (1) means that only one task can be outstanding on each
        engine.  Setting TaskScheduler.hwm=0 means there is no limit, and the
        engines continue to be assigned tasks while they are working,
        effectively hiding network latency behind computation, but can result
        in an imbalance of work when submitting many heterogeneous tasks all at
        once.  Any positive value greater than one is a compromise between the
        two.

        """)
    scheme_name = Enum(
        ('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
        'leastload',
        config=True,
        allow_none=False,
        help="""select the task scheduler scheme  [default: Python LRU]
        Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
    )

    def _scheme_name_changed(self, old, new):
        self.log.debug("Using scheme %r" % new)
        self.scheme = globals()[new]

    # input arguments:
    scheme = Instance(FunctionType)  # function for determining the destination

    def _scheme_default(self):
        return leastload

    client_stream = Instance(zmqstream.ZMQStream)  # client-facing stream
    engine_stream = Instance(zmqstream.ZMQStream)  # engine-facing stream
    notifier_stream = Instance(zmqstream.ZMQStream)  # hub-facing sub stream
    mon_stream = Instance(zmqstream.ZMQStream)  # hub-facing pub stream
    query_stream = Instance(zmqstream.ZMQStream)  # hub-facing DEALER stream

    # internals:
    graph = Dict()  # dict by msg_id of [ msg_ids that depend on key ]
    retries = Dict()  # dict by msg_id of retries remaining (non-neg ints)
    # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
    depending = Dict()  # dict by msg_id of Jobs
    pending = Dict()  # dict by engine_uuid of submitted tasks
    completed = Dict()  # dict by engine_uuid of completed tasks
    failed = Dict()  # dict by engine_uuid of failed tasks
    destinations = Dict(
    )  # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
    clients = Dict()  # dict by msg_id for who submitted the task
    targets = List()  # list of target IDENTs
    loads = List()  # list of engine loads
    # full = Set() # set of IDENTs that have HWM outstanding tasks
    all_completed = Set()  # set of all completed tasks
    all_failed = Set()  # set of all failed tasks
    all_done = Set()  # set of all finished tasks=union(completed,failed)
    all_ids = Set()  # set of all submitted task IDs

    auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')

    ident = CBytes()  # ZMQ identity. This should just be self.session.session

    # but ensure Bytes
    def _ident_default(self):
        return self.session.bsession

    def start(self):
        self.query_stream.on_recv(self.dispatch_query_reply)
        self.session.send(self.query_stream, "connection_request", {})

        self.engine_stream.on_recv(self.dispatch_result, copy=False)
        self.client_stream.on_recv(self.dispatch_submission, copy=False)

        self._notification_handlers = dict(
            registration_notification=self._register_engine,
            unregistration_notification=self._unregister_engine)
        self.notifier_stream.on_recv(self.dispatch_notification)
        self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3,
                                               self.loop)  # 1 Hz
        self.auditor.start()
        self.log.info("Scheduler started [%s]" % self.scheme_name)

    def resume_receiving(self):
        """Resume accepting jobs."""
        self.client_stream.on_recv(self.dispatch_submission, copy=False)

    def stop_receiving(self):
        """Stop accepting jobs while there are no engines.
        Leave them in the ZMQ queue."""
        self.client_stream.on_recv(None)

    #-----------------------------------------------------------------------
    # [Un]Registration Handling
    #-----------------------------------------------------------------------

    def dispatch_query_reply(self, msg):
        """handle reply to our initial connection request"""
        try:
            idents, msg = self.session.feed_identities(msg)
        except ValueError:
            self.log.warn("task::Invalid Message: %r", msg)
            return
        try:
            msg = self.session.unserialize(msg)
        except ValueError:
            self.log.warn("task::Unauthorized message from: %r" % idents)
            return

        content = msg['content']
        for uuid in content.get('engines', {}).values():
            self._register_engine(cast_bytes(uuid))

    @util.log_errors
    def dispatch_notification(self, msg):
        """dispatch register/unregister events."""
        try:
            idents, msg = self.session.feed_identities(msg)
        except ValueError:
            self.log.warn("task::Invalid Message: %r", msg)
            return
        try:
            msg = self.session.unserialize(msg)
        except ValueError:
            self.log.warn("task::Unauthorized message from: %r" % idents)
            return

        msg_type = msg['header']['msg_type']

        handler = self._notification_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("Unhandled message type: %r" % msg_type)
        else:
            try:
                handler(cast_bytes(msg['content']['uuid']))
            except Exception:
                self.log.error("task::Invalid notification msg: %r",
                               msg,
                               exc_info=True)

    def _register_engine(self, uid):
        """New engine with ident `uid` became available."""
        # head of the line:
        self.targets.insert(0, uid)
        self.loads.insert(0, 0)

        # initialize sets
        self.completed[uid] = set()
        self.failed[uid] = set()
        self.pending[uid] = {}

        # rescan the graph:
        self.update_graph(None)

    def _unregister_engine(self, uid):
        """Existing engine with ident `uid` became unavailable."""
        if len(self.targets) == 1:
            # this was our only engine
            pass

        # handle any potentially finished tasks:
        self.engine_stream.flush()

        # don't pop destinations, because they might be used later
        # map(self.destinations.pop, self.completed.pop(uid))
        # map(self.destinations.pop, self.failed.pop(uid))

        # prevent this engine from receiving work
        idx = self.targets.index(uid)
        self.targets.pop(idx)
        self.loads.pop(idx)

        # wait 5 seconds before cleaning up pending jobs, since the results might
        # still be incoming
        if self.pending[uid]:
            dc = ioloop.DelayedCallback(
                lambda: self.handle_stranded_tasks(uid), 5000, self.loop)
            dc.start()
        else:
            self.completed.pop(uid)
            self.failed.pop(uid)

    def handle_stranded_tasks(self, engine):
        """Deal with jobs resident in an engine that died."""
        lost = self.pending[engine]
        for msg_id in lost.keys():
            if msg_id not in self.pending[engine]:
                # prevent double-handling of messages
                continue

            raw_msg = lost[msg_id].raw_msg
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            parent = self.session.unpack(msg[1].bytes)
            idents = [engine, idents[0]]

            # build fake error reply
            try:
                raise error.EngineError(
                    "Engine %r died while running task %r" % (engine, msg_id))
            except:
                content = error.wrap_exception()
            # build fake metadata
            md = dict(
                status=u'error',
                engine=engine,
                date=datetime.now(),
            )
            msg = self.session.msg('apply_reply',
                                   content,
                                   parent=parent,
                                   metadata=md)
            raw_reply = [zmq.Message(part)
                         for part in self.session.serialize(msg, ident=idents)]
            # and dispatch it
            self.dispatch_result(raw_reply)

        # finally scrub completed/failed lists
        self.completed.pop(engine)
        self.failed.pop(engine)

    #-----------------------------------------------------------------------
    # Job Submission
    #-----------------------------------------------------------------------

    @util.log_errors
    def dispatch_submission(self, raw_msg):
        """Dispatch job submission to appropriate handlers."""
        # ensure targets up to date:
        self.notifier_stream.flush()
        try:
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.unserialize(msg, content=False, copy=False)
        except Exception:
            self.log.error("task::Invaid task msg: %r" % raw_msg,
                           exc_info=True)
            return

        # send to monitor
        self.mon_stream.send_multipart([b'intask'] + raw_msg, copy=False)

        header = msg['header']
        md = msg['metadata']
        msg_id = header['msg_id']
        self.all_ids.add(msg_id)

        # get targets as a set of bytes objects
        # from a list of unicode objects
        targets = md.get('targets', [])
        targets = map(cast_bytes, targets)
        targets = set(targets)

        retries = md.get('retries', 0)
        self.retries[msg_id] = retries

        # time dependencies
        after = md.get('after', None)
        if after:
            after = Dependency(after)
            if after.all:
                if after.success:
                    after = Dependency(
                        after.difference(self.all_completed),
                        success=after.success,
                        failure=after.failure,
                        all=after.all,
                    )
                if after.failure:
                    after = Dependency(
                        after.difference(self.all_failed),
                        success=after.success,
                        failure=after.failure,
                        all=after.all,
                    )
            if after.check(self.all_completed, self.all_failed):
                # recast as empty set, if `after` already met,
                # to prevent unnecessary set comparisons
                after = MET
        else:
            after = MET

        # location dependencies
        follow = Dependency(md.get('follow', []))

        # turn timeouts into datetime objects:
        timeout = md.get('timeout', None)
        if timeout:
            # cast to float, because jsonlib returns floats as decimal.Decimal,
            # which timedelta does not accept
            timeout = datetime.now() + timedelta(0, float(timeout), 0)

        job = Job(
            msg_id=msg_id,
            raw_msg=raw_msg,
            idents=idents,
            msg=msg,
            header=header,
            targets=targets,
            after=after,
            follow=follow,
            timeout=timeout,
            metadata=md,
        )

        # validate and reduce dependencies:
        for dep in after, follow:
            if not dep:  # empty dependency
                continue
            # check valid:
            if msg_id in dep or dep.difference(self.all_ids):
                self.depending[msg_id] = job
                return self.fail_unreachable(msg_id, error.InvalidDependency)
            # check if unreachable:
            if dep.unreachable(self.all_completed, self.all_failed):
                self.depending[msg_id] = job
                return self.fail_unreachable(msg_id)

        if after.check(self.all_completed, self.all_failed):
            # time deps already met, try to run
            if not self.maybe_run(job):
                # can't run yet
                if msg_id not in self.all_failed:
                    # could have failed as unreachable
                    self.save_unmet(job)
        else:
            self.save_unmet(job)

    def audit_timeouts(self):
        """Audit all waiting tasks for expired timeouts."""
        now = datetime.now()
        for msg_id in self.depending.keys():
            # must recheck, in case one failure cascaded to another:
            if msg_id in self.depending:
                job = self.depending[msg_id]
                if job.timeout and job.timeout < now:
                    self.fail_unreachable(msg_id, error.TaskTimeout)

    def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
        """a task has become unreachable, send a reply with an ImpossibleDependency
        error."""
        if msg_id not in self.depending:
            self.log.error("msg %r already failed!", msg_id)
            return
        job = self.depending.pop(msg_id)
        for mid in job.dependents:
            if mid in self.graph:
                self.graph[mid].remove(msg_id)

        try:
            raise why()
        except:
            content = error.wrap_exception()

        self.all_done.add(msg_id)
        self.all_failed.add(msg_id)

        msg = self.session.send(self.client_stream,
                                'apply_reply',
                                content,
                                parent=job.header,
                                ident=job.idents)
        self.session.send(self.mon_stream,
                          msg,
                          ident=[b'outtask'] + job.idents)

        self.update_graph(msg_id, success=False)

    def maybe_run(self, job):
        """check location dependencies, and run if they are met."""
        msg_id = job.msg_id
        self.log.debug("Attempting to assign task %s", msg_id)
        if not self.targets:
            # no engines, definitely can't run
            return False

        if job.follow or job.targets or job.blacklist or self.hwm:
            # we need a can_run filter
            def can_run(idx):
                # check hwm
                if self.hwm and self.loads[idx] == self.hwm:
                    return False
                target = self.targets[idx]
                # check blacklist
                if target in job.blacklist:
                    return False
                # check targets
                if job.targets and target not in job.targets:
                    return False
                # check follow
                return job.follow.check(self.completed[target],
                                        self.failed[target])

            # list comprehension (not filter()) so the truthiness test below works on Py3
            indices = [i for i in range(len(self.targets)) if can_run(i)]

            if not indices:
                # couldn't run
                if job.follow.all:
                    # check follow for impossibility
                    dests = set()
                    relevant = set()
                    if job.follow.success:
                        relevant = self.all_completed
                    if job.follow.failure:
                        relevant = relevant.union(self.all_failed)
                    for m in job.follow.intersection(relevant):
                        dests.add(self.destinations[m])
                    if len(dests) > 1:
                        self.depending[msg_id] = job
                        self.fail_unreachable(msg_id)
                        return False
                if job.targets:
                    # check blacklist+targets for impossibility
                    job.targets.difference_update(job.blacklist)
                    if not job.targets or not job.targets.intersection(
                            self.targets):
                        self.depending[msg_id] = job
                        self.fail_unreachable(msg_id)
                        return False
                return False
        else:
            indices = None

        self.submit_task(job, indices)
        return True

    def save_unmet(self, job):
        """Save a message for later submission when its dependencies are met."""
        msg_id = job.msg_id
        self.depending[msg_id] = job
        # track the ids in follow or after, but not those already finished
        for dep_id in job.after.union(job.follow).difference(self.all_done):
            if dep_id not in self.graph:
                self.graph[dep_id] = set()
            self.graph[dep_id].add(msg_id)

    def submit_task(self, job, indices=None):
        """Submit a task to any of a subset of our targets."""
        if indices:
            loads = [self.loads[i] for i in indices]
        else:
            loads = self.loads
        idx = self.scheme(loads)
        if indices:
            idx = indices[idx]
        target = self.targets[idx]
        # print (target, map(str, msg[:3]))
        # send job to the engine
        self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
        self.engine_stream.send_multipart(job.raw_msg, copy=False)
        # update load
        self.add_job(idx)
        self.pending[target][job.msg_id] = job
        # notify Hub
        content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
        self.session.send(self.mon_stream,
                          'task_destination',
                          content=content,
                          ident=[b'tracktask', self.ident])

    #-----------------------------------------------------------------------
    # Result Handling
    #-----------------------------------------------------------------------

    @util.log_errors
    def dispatch_result(self, raw_msg):
        """dispatch method for result replies"""
        try:
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.unserialize(msg, content=False, copy=False)
            engine = idents[0]
            try:
                idx = self.targets.index(engine)
            except ValueError:
                pass  # skip load-update for dead engines
            else:
                self.finish_job(idx)
        except Exception:
            self.log.error("task::Invaid result: %r", raw_msg, exc_info=True)
            return

        md = msg['metadata']
        parent = msg['parent_header']
        if md.get('dependencies_met', True):
            success = (md['status'] == 'ok')
            msg_id = parent['msg_id']
            retries = self.retries[msg_id]
            if not success and retries > 0:
                # failed
                self.retries[msg_id] = retries - 1
                self.handle_unmet_dependency(idents, parent)
            else:
                del self.retries[msg_id]
                # relay to client and update graph
                self.handle_result(idents, parent, raw_msg, success)
                # send to Hub monitor
                self.mon_stream.send_multipart([b'outtask'] + raw_msg,
                                               copy=False)
        else:
            self.handle_unmet_dependency(idents, parent)

    def handle_result(self, idents, parent, raw_msg, success=True):
        """handle a real task result, either success or failure"""
        # first, relay result to client
        engine = idents[0]
        client = idents[1]
        # swap_ids for ROUTER-ROUTER mirror
        raw_msg[:2] = [client, engine]
        # print (map(str, raw_msg[:4]))
        self.client_stream.send_multipart(raw_msg, copy=False)
        # now, update our data structures
        msg_id = parent['msg_id']
        self.pending[engine].pop(msg_id)
        if success:
            self.completed[engine].add(msg_id)
            self.all_completed.add(msg_id)
        else:
            self.failed[engine].add(msg_id)
            self.all_failed.add(msg_id)
        self.all_done.add(msg_id)
        self.destinations[msg_id] = engine

        self.update_graph(msg_id, success)

    def handle_unmet_dependency(self, idents, parent):
        """handle an unmet dependency"""
        engine = idents[0]
        msg_id = parent['msg_id']

        job = self.pending[engine].pop(msg_id)
        job.blacklist.add(engine)

        if job.blacklist == job.targets:
            self.depending[msg_id] = job
            self.fail_unreachable(msg_id)
        elif not self.maybe_run(job):
            # resubmit failed
            if msg_id not in self.all_failed:
                # put it back in our dependency tree
                self.save_unmet(job)

        if self.hwm:
            try:
                idx = self.targets.index(engine)
            except ValueError:
                pass  # skip load-update for dead engines
            else:
                if self.loads[idx] == self.hwm - 1:
                    self.update_graph(None)

    def update_graph(self, dep_id=None, success=True):
        """dep_id just finished. Update our dependency
        graph and submit any jobs that just became runnable.

        Called with dep_id=None to update entire graph for hwm, but without finishing
        a task.
        """
        # print ("\n\n***********")
        # pprint (dep_id)
        # pprint (self.graph)
        # pprint (self.depending)
        # pprint (self.all_completed)
        # pprint (self.all_failed)
        # print ("\n\n***********\n\n")
        # update any jobs that depended on the dependency
        jobs = self.graph.pop(dep_id, [])

        # recheck *all* jobs if
        # a) we have HWM and an engine just become no longer full
        # or b) dep_id was given as None

        if dep_id is None or (self.hwm and
                any(load == self.hwm - 1 for load in self.loads)):
            jobs = self.depending.keys()

        for msg_id in sorted(
                jobs, key=lambda msg_id: self.depending[msg_id].timestamp):
            job = self.depending[msg_id]

            if job.after.unreachable(self.all_completed, self.all_failed)\
                    or job.follow.unreachable(self.all_completed, self.all_failed):
                self.fail_unreachable(msg_id)

            elif job.after.check(self.all_completed,
                                 self.all_failed):  # time deps met, maybe run
                if self.maybe_run(job):

                    self.depending.pop(msg_id)
                    for mid in job.dependents:
                        if mid in self.graph:
                            self.graph[mid].remove(msg_id)

    #----------------------------------------------------------------------
    # methods to be overridden by subclasses
    #----------------------------------------------------------------------

    def add_job(self, idx):
        """Called after self.targets[idx] just got the job with header.
        Override with subclasses.  The default ordering is simple LRU.
        The default loads are the number of outstanding jobs."""
        self.loads[idx] += 1
        for lis in (self.targets, self.loads):
            lis.append(lis.pop(idx))

    def finish_job(self, idx):
        """Called after self.targets[idx] just finished a job.
        Override with subclasses."""
        self.loads[idx] -= 1
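
The scheme trait above is just a function from the current load list to an engine index; submit_task calls self.scheme(loads). A leastload-style scheme could look like this sketch, inferred from that call site rather than copied from the project:

def leastload(loads):
    """Pick the index of the engine with the fewest outstanding tasks."""
    return loads.index(min(loads))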
Example #6
class Authenticator(LoggingConfigurable):
    """A class for authentication.
    
    The API is one method, `authenticate`, a tornado gen.coroutine.
    """
    
    db = Any()
    whitelist = Set(config=True,
        help="""Username whitelist.
        
        Use this to restrict which users can login.
        If empty, allow any user to attempt login.
        """
    )
    custom_html = Unicode('',
        help="""HTML login form for custom handlers.
        Override in form-based custom authenticators
        that don't use username+password,
        or need custom branding.
        """
    )
    login_service = Unicode('',
        help="""Name of the login service for external
        login services (e.g. 'GitHub').
        """
    )
    
    @gen.coroutine
    def authenticate(self, handler, data):
        """Authenticate a user with login form data.
        
        This must be a tornado gen.coroutine.
        It must return the username on successful authentication,
        and return None on failed authentication.
        """

    def check_whitelist(self, user):
        """
        Return True if the whitelist is empty or user is in the whitelist.
        """
        # Parens aren't necessary here, but they make this easier to parse.
        return (not self.whitelist) or (user in self.whitelist)

    def add_user(self, user):
        """Add a new user
        
        By default, this just adds the user to the whitelist.
        
        Subclasses may do more extensive things,
        such as adding actual unix users.
        """
        if self.whitelist:
            self.whitelist.add(user.name)
    
    def delete_user(self, user):
        """Triggered when a user is deleted.
        
        Removes the user from the whitelist.
        """
        self.whitelist.discard(user.name)
    
    def login_url(self, base_url):
        """Override to register a custom login handler"""
        return url_path_join(base_url, 'login')
    
    def logout_url(self, base_url):
        """Override to register a custom logout handler"""
        return url_path_join(base_url, 'logout')
    
    def get_handlers(self, app):
        """Return any custom handlers the authenticator needs to register
        
        (e.g. for OAuth)
        """
        return [
            ('/login', LoginHandler),
        ]
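
The check_whitelist semantics in a few lines (sketch):

a = Authenticator()
a.check_whitelist('anyone')   # True: an empty whitelist admits everybody
a.whitelist = {'alice'}
a.check_whitelist('alice')    # True
a.check_whitelist('bob')      # False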
Example #7
class Session(Configurable):
    """Object for handling serialization and sending of messages.
    
    The Session object handles building messages and sending them
    with ZMQ sockets or ZMQStream objects.  Objects can communicate with each
    other over the network via Session objects, and only need to work with the
    dict-based IPython message spec. The Session will handle 
    serialization/deserialization, security, and metadata.
    
    Sessions support configurable serialization via packer/unpacker traits,
    and signing with HMAC digests via the key/keyfile traits.
    
    Parameters
    ----------
    
    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        importstrings for methods to serialize message parts.  If just
        'json' or 'pickle', predefined JSON and pickle packers will be used.
        Otherwise, the entire importstring must be used.
        
        The functions must accept at least valid JSON input, and output *bytes*.
        
        For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        You can also set the pack/unpack callables for serialization directly.
    session : bytes
        the ID of this Session object.  The default is to generate a new UUID.
    username : unicode
        username added to message headers.  The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature.  If unset, messages
        will not be signed or checked.
    keyfile : filepath
        The file containing a key.  If this is set, `key` will be initialized
        to the contents of the file.
    
    """

    debug = Bool(False, config=True, help="""Debug output in the Session""")

    packer = DottedObjectName(
        'json',
        config=True,
        help="""The name of the packer for serializing messages.
            Should be one of 'json', 'pickle', or an import name
            for a custom callable serializer.""")

    def _packer_changed(self, name, old, new):
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
        else:
            self.pack = import_item(str(new))

    unpacker = DottedObjectName(
        'json',
        config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""")

    def _unpacker_changed(self, name, old, new):
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
        else:
            self.unpack = import_item(str(new))

    session = CBytes(b'',
                     config=True,
                     help="""The UUID identifying this session.""")

    def _session_default(self):
        return bytes(uuid.uuid4())

    username = Unicode(
        os.environ.get('USER', 'username'),
        config=True,
        help="""Username for the Session. Default is your system username.""")

    # message signature related traits:
    key = CBytes(b'',
                 config=True,
                 help="""execution key, for extra authentication.""")

    def _key_changed(self, name, old, new):
        if new:
            self.auth = hmac.HMAC(new)
        else:
            self.auth = None

    auth = Instance(hmac.HMAC)
    digest_history = Set()

    keyfile = Unicode('',
                      config=True,
                      help="""path to file containing execution key.""")

    def _keyfile_changed(self, name, old, new):
        with open(new, 'rb') as f:
            self.key = f.read().strip()

    pack = Any(default_packer)  # the actual packer function

    def _pack_changed(self, name, old, new):
        if not callable(new):
            raise TypeError("packer must be callable, not %s" % type(new))

    unpack = Any(default_unpacker)  # the actual unpacker function

    def _unpack_changed(self, name, old, new):
        # unpacker is only checked for callability; it is assumed to be pack's inverse
        if not callable(new):
            raise TypeError("unpacker must be callable, not %s" % type(new))

    def __init__(self, **kwargs):
        """create a Session object
        
        Parameters
        ----------

        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts.  If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : bytes
            the ID of this Session object.  The default is to generate a new 
            UUID.
        username : unicode
            username added to message headers.  The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature.  If unset, messages
            will not be signed or checked.
        keyfile : filepath
            The file containing a key.  If this is set, `key` will be 
            initialized to the contents of the file.
        """
        super(Session, self).__init__(**kwargs)
        self._check_packers()
        self.none = self.pack({})

    @property
    def msg_id(self):
        """always return new uuid"""
        return str(uuid.uuid4())

    def _check_packers(self):
        """check packers for binary data and datetime support."""
        pack = self.pack
        unpack = self.unpack

        # check simple serialization
        msg = dict(a=[1, 'hi'])
        try:
            packed = pack(msg)
        except Exception:
            raise ValueError("packer could not serialize a simple message")

        # ensure packed message is bytes
        if not isinstance(packed, bytes):
            raise ValueError("message packed to %r, but bytes are required" %
                             type(packed))

        # check that unpack is pack's inverse
        try:
            unpacked = unpack(packed)
        except Exception:
            raise ValueError("unpacker could not handle the packer's output")

        # check datetime support
        msg = dict(t=datetime.now())
        try:
            unpacked = unpack(pack(msg))
        except Exception:
            self.pack = lambda o: pack(squash_dates(o))
            self.unpack = lambda s: extract_dates(unpack(s))

    def msg_header(self, msg_type):
        return msg_header(self.msg_id, msg_type, self.username, self.session)

    def msg(self, msg_type, content=None, parent=None, subheader=None):
        """Return the nested message dict.

        This format is different from what is sent over the wire. The
        self.serialize method converts this nested message dict to the wire
        format, which uses a message list.
        """
        msg = {}
        msg['header'] = self.msg_header(msg_type)
        msg['msg_id'] = msg['header']['msg_id']
        msg['parent_header'] = {} if parent is None else extract_header(parent)
        msg['msg_type'] = msg_type
        msg['content'] = {} if content is None else content
        sub = {} if subheader is None else subheader
        msg['header'].update(sub)
        return msg

    def sign(self, msg_list):
        """Sign a message with HMAC digest. If no auth, return b''.

        Parameters
        ----------
        msg_list : list
            The [p_header,p_parent,p_content] part of the message list. 
        """
        if self.auth is None:
            return b''
        h = self.auth.copy()
        for m in msg_list:
            h.update(m)
        return h.hexdigest()

    def serialize(self, msg, ident=None):
        """Serialize the message components to bytes.

        Parameters
        ----------
        msg : dict or Message
            The nested message dict as returned by the self.msg method.

        Returns
        -------
        msg_list : list
            The list of bytes objects to be sent with the format:
            [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
             buffer1,buffer2,...]. In this list, the p_* entities are
            the packed or serialized versions, so if JSON is used, these
            are utf8 encoded JSON strings.
        """
        content = msg.get('content', {})
        if content is None:
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, bytes):
            # content is already packed, as in a relayed message
            pass
        elif isinstance(content, unicode):
            # should be bytes, but JSON often spits out unicode
            content = content.encode('utf8')
        else:
            raise TypeError("Content incorrect type: %s" % type(content))

        real_message = [
            self.pack(msg['header']),
            self.pack(msg['parent_header']), content
        ]

        to_send = []

        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)

        signature = self.sign(real_message)
        to_send.append(signature)

        to_send.extend(real_message)

        return to_send

    def send(self,
             stream,
             msg_or_type,
             content=None,
             parent=None,
             ident=None,
             buffers=None,
             subheader=None,
             track=False):
        """Build and send a message via stream or socket.

        The message format used by this function internally is as follows:

        [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
         buffer1,buffer2,...] 

        The self.serialize method converts the nested message dict into this
        format.

        Parameters
        ----------
        
        stream : zmq.Socket or ZMQStream
            the socket-like object used to send the data
        msg_or_type : str or Message/dict
            Normally, msg_or_type will be a msg_type unless a message is being 
            sent more than once.
        
        content : dict or None
            the content of the message (ignored if msg_or_type is a message)
        parent : Message or dict or None
            the parent or parent header describing the parent of this message
        ident : bytes or list of bytes
            the zmq.IDENTITY routing path
        subheader : dict or None
            extra header keys for this message's header
        buffers : list or None
            the already-serialized buffers to be appended to the message
        track : bool
            whether to track.  Only for use with Sockets, 
            because ZMQStream objects cannot track messages.
        
        Returns
        -------
        msg : message dict
            the constructed message
        (msg,tracker) : (message dict, MessageTracker)
            if track=True, then a 2-tuple will be returned, 
            the first element being the constructed
            message, and the second being the MessageTracker
            
        """

        if not isinstance(stream, (zmq.Socket, ZMQStream)):
            raise TypeError("stream must be Socket or ZMQStream, not %r" %
                            type(stream))
        elif track and isinstance(stream, ZMQStream):
            raise TypeError("ZMQStream cannot track messages")

        if isinstance(msg_or_type, (Message, dict)):
            # we got a Message, not a msg_type
            # don't build a new Message
            msg = msg_or_type
        else:
            msg = self.msg(msg_or_type, content, parent, subheader)

        buffers = [] if buffers is None else buffers
        to_send = self.serialize(msg, ident)
        flag = 0
        if buffers:
            flag = zmq.SNDMORE
            _track = False
        else:
            _track = track
        if track:
            tracker = stream.send_multipart(to_send,
                                            flag,
                                            copy=False,
                                            track=_track)
        else:
            tracker = stream.send_multipart(to_send, flag, copy=False)
        for b in buffers[:-1]:
            stream.send(b, flag, copy=False)
        if buffers:
            if track:
                tracker = stream.send(buffers[-1], copy=False, track=track)
            else:
                tracker = stream.send(buffers[-1], copy=False)

        # omsg = Message(msg)
        if self.debug:
            pprint.pprint(msg)
            pprint.pprint(to_send)
            pprint.pprint(buffers)

        msg['tracker'] = tracker

        return msg

    def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
        """Send a raw message via ident path.

        This method is used to send a already serialized message.

        Parameters
        ----------
        stream : ZMQStream or Socket
            The ZMQ stream or socket to use for sending the message.
        msg_list : list
            The serialized list of messages to send. This only includes the
            [p_header,p_parent,p_content,buffer1,buffer2,...] portion of
            the message.
        ident : ident or list
            A single ident or a list of idents to use in sending.
        """
        to_send = []
        if isinstance(ident, bytes):
            ident = [ident]
        if ident is not None:
            to_send.extend(ident)

        to_send.append(DELIM)
        to_send.append(self.sign(msg_list))
        to_send.extend(msg_list)
        # send the assembled list (idents + DELIM + signature + msg_list),
        # not the bare msg_list, or the routing prefix and HMAC would be lost
        stream.send_multipart(to_send, flags, copy=copy)

    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
        """Receive and unpack a message.

        Parameters
        ----------
        socket : ZMQStream or Socket
            The socket or stream to use in receiving.

        Returns
        -------
        [idents], msg
            [idents] is a list of idents and msg is a nested message dict of
            same format as self.msg returns.
        """
        if isinstance(socket, ZMQStream):
            socket = socket.socket
        try:
            msg_list = socket.recv_multipart(mode)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_multipart won't return None.
                return None, None
            else:
                raise
        # split multipart message into identity list and message dict
        # invalid large messages can cause very expensive string comparisons
        idents, msg_list = self.feed_identities(msg_list, copy)
        try:
            return idents, self.unpack_message(msg_list,
                                               content=content,
                                               copy=copy)
        except Exception as e:
            print(idents, msg_list)
            # TODO: handle it
            raise e

    def feed_identities(self, msg_list, copy=True):
        """Split the identities from the rest of the message.

        Feed until DELIM is reached, then return the prefix as idents and
        remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
        but that would be silly.
        
        Parameters
        ----------
        msg_list : a list of Message or bytes objects
            The message to be split.
        copy : bool
            flag determining whether the arguments are bytes or Messages
        
        Returns
        -------
        (idents,msg_list) : two lists
            idents will always be a list of bytes - the identity prefix
            msg_list will be a list of bytes or Messages, unchanged from input
            msg_list should be unpackable via self.unpack_message at this point.
        """
        if copy:
            idx = msg_list.index(DELIM)
            return msg_list[:idx], msg_list[idx + 1:]
        else:
            failed = True
            for idx, m in enumerate(msg_list):
                if m.bytes == DELIM:
                    failed = False
                    break
            if failed:
                raise ValueError("DELIM not in msg_list")
            idents, msg_list = msg_list[:idx], msg_list[idx + 1:]
            return [m.bytes for m in idents], msg_list

    def unpack_message(self, msg_list, content=True, copy=True):
        """Return a message object from the format
        sent by self.send.
        
        Parameters
        ----------
        
        content : bool (True)
            whether to unpack the content dict (True), 
            or leave it serialized (False)
        
        copy : bool (True)
            whether to return the bytes (True), 
            or the non-copying Message object in each place (False)
        
        """
        minlen = 4
        message = {}
        # check the length first, so the indexing below cannot raise IndexError
        if not len(msg_list) >= minlen:
            raise TypeError(
                "malformed message, must have at least %i elements" % minlen)
        if not copy:
            for i in range(minlen):
                msg_list[i] = msg_list[i].bytes
        if self.auth is not None:
            signature = msg_list[0]
            if signature in self.digest_history:
                raise ValueError("Duplicate Signature: %r" % signature)
            self.digest_history.add(signature)
            check = self.sign(msg_list[1:4])
            if signature != check:
                raise ValueError("Invalid Signature: %r" % signature)
        message['header'] = self.unpack(msg_list[1])
        message['msg_type'] = message['header']['msg_type']
        message['parent_header'] = self.unpack(msg_list[2])
        if content:
            message['content'] = self.unpack(msg_list[3])
        else:
            message['content'] = msg_list[3]

        message['buffers'] = msg_list[4:]
        return message
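
A self-contained round trip through serialize, feed_identities, and unpack_message, sketched assuming the default json packer is in use:

session = Session(key=b'secret-key')   # a key enables HMAC signing
msg = session.msg('apply_request', content={'a': 1})
wire = session.serialize(msg)          # [DELIM, HMAC, p_header, p_parent, p_content]
idents, parts = session.feed_identities(wire)
unpacked = session.unpack_message(parts)
assert unpacked['content'] == {'a': 1}
# replaying the same wire message raises ValueError, because the HMAC
# signature is already in session.digest_history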
Example #8
class View(HasTraits):
    """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.

    Don't use this class, use subclasses.

    Methods
    -------

    spin
        flushes incoming results and registration state changes;
        control methods spin, and requesting `ids` also keeps the view up to date

    wait
        wait on one or more msg_ids

    execution methods
        apply
        legacy: execute, run

    data movement
        push, pull, scatter, gather

    query methods
        get_result, queue_status, purge_results, result_status

    control methods
        abort, shutdown

    """
    # flags
    block = Bool(False)
    track = Bool(True)
    targets = Any()

    history = List()
    outstanding = Set()
    results = Dict()
    client = Instance('IPython.parallel.Client')

    _socket = Instance('zmq.Socket')
    _flag_names = List(['targets', 'block', 'track'])
    _targets = Any()
    _idents = Any()

    def __init__(self, client=None, socket=None, **flags):
        super(View, self).__init__(client=client, _socket=socket)
        self.results = client.results
        self.block = client.block

        self.set_flags(**flags)

        assert self.__class__ is not View, "Don't use base View objects, use subclasses"

    def __repr__(self):
        strtargets = str(self.targets)
        if len(strtargets) > 16:
            strtargets = strtargets[:12] + '...]'
        return "<%s %s>" % (self.__class__.__name__, strtargets)

    def __len__(self):
        if isinstance(self.targets, list):
            return len(self.targets)
        elif isinstance(self.targets, int):
            return 1
        else:
            return len(self.client)

    def set_flags(self, **kwargs):
        """set my attribute flags by keyword.

        Views determine behavior with a few attributes (`block`, `track`, etc.).
        These attributes can be set all at once by name with this method.

        Parameters
        ----------

        block : bool
            whether to wait for results
        track : bool
            whether to create a MessageTracker to allow the user to
            safely edit arrays and buffers after non-copying
            sends.
        """
        for name, value in kwargs.iteritems():
            if name not in self._flag_names:
                raise KeyError("Invalid name: %r" % name)
            else:
                setattr(self, name, value)

    @contextmanager
    def temp_flags(self, **kwargs):
        """temporarily set flags, for use in `with` statements.

        See set_flags for permanent setting of flags

        Examples
        --------

        >>> view.track=False
        ...
        >>> with view.temp_flags(track=True):
        ...    ar = view.apply(dostuff, my_big_array)
        ...    ar.tracker.wait() # wait for send to finish
        >>> view.track
        False

        """
        # preflight: save flags, and set temporaries
        saved_flags = {}
        for f in self._flag_names:
            saved_flags[f] = getattr(self, f)
        self.set_flags(**kwargs)
        # yield to the with-statement block
        try:
            yield
        finally:
            # postflight: restore saved flags
            self.set_flags(**saved_flags)

    #----------------------------------------------------------------
    # apply
    #----------------------------------------------------------------

    @sync_results
    @save_ids
    def _really_apply(self, f, args, kwargs, block=None, **options):
        """wrapper for client.send_apply_request"""
        raise NotImplementedError("Implement in subclasses")

    def apply(self, f, *args, **kwargs):
        """calls f(*args, **kwargs) on remote engines, returning the result.

        This method sets all apply flags via this View's attributes.

        if self.block is False:
            returns AsyncResult
        else:
            returns actual result of f(*args, **kwargs)
        """
        return self._really_apply(f, args, kwargs)

    def apply_async(self, f, *args, **kwargs):
        """calls f(*args, **kwargs) on remote engines in a nonblocking manner.

        returns AsyncResult
        """
        return self._really_apply(f, args, kwargs, block=False)

    @spin_after
    def apply_sync(self, f, *args, **kwargs):
        """calls f(*args, **kwargs) on remote engines in a blocking manner,
         returning the result.

        returns: actual result of f(*args, **kwargs)
        """
        return self._really_apply(f, args, kwargs, block=True)

    #----------------------------------------------------------------
    # wrappers for client and control methods
    #----------------------------------------------------------------
    @sync_results
    def spin(self):
        """spin the client, and sync"""
        self.client.spin()

    @sync_results
    def wait(self, jobs=None, timeout=-1):
        """waits on one or more `jobs`, for up to `timeout` seconds.

        Parameters
        ----------

        jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
                ints are indices to self.history
                strs are msg_ids
                default: wait on all outstanding messages
        timeout : float
                a time in seconds, after which to give up.
                default is -1, which means no timeout

        Returns
        -------

        True : when all msg_ids are done
        False : timeout reached, some msg_ids still outstanding
        """
        if jobs is None:
            jobs = self.history
        return self.client.wait(jobs, timeout)

    def abort(self, jobs=None, targets=None, block=None):
        """Abort jobs on my engines.

        Parameters
        ----------

        jobs : None, str, list of strs, optional
            if None: abort all outstanding jobs.
            else: abort specific msg_id(s).
        targets : target spec, optional
            which engines to signal. Default: self.targets.
        block : bool, optional
            whether to wait for the reply. Default: self.block.
        """
        block = block if block is not None else self.block
        targets = targets if targets is not None else self.targets
        jobs = jobs if jobs is not None else list(self.outstanding)

        return self.client.abort(jobs=jobs, targets=targets, block=block)

    def queue_status(self, targets=None, verbose=False):
        """Fetch the Queue status of my engines"""
        targets = targets if targets is not None else self.targets
        return self.client.queue_status(targets=targets, verbose=verbose)

    def purge_results(self, jobs=[], targets=[]):
        """Instruct the controller to forget specific results."""
        if targets is None or targets == 'all':
            targets = self.targets
        return self.client.purge_results(jobs=jobs, targets=targets)

    def shutdown(self, targets=None, restart=False, hub=False, block=None):
        """Terminates one or more engine processes, optionally including the hub.
        """
        block = self.block if block is None else block
        if targets is None or targets == 'all':
            targets = self.targets
        return self.client.shutdown(targets=targets,
                                    restart=restart,
                                    hub=hub,
                                    block=block)

    @spin_after
    def get_result(self, indices_or_msg_ids=None):
        """return one or more results, specified by history index or msg_id.

        See client.get_result for details.

        """

        if indices_or_msg_ids is None:
            indices_or_msg_ids = -1
        if isinstance(indices_or_msg_ids, int):
            indices_or_msg_ids = self.history[indices_or_msg_ids]
        elif isinstance(indices_or_msg_ids, (list, tuple, set)):
            indices_or_msg_ids = list(indices_or_msg_ids)
            for i, index in enumerate(indices_or_msg_ids):
                if isinstance(index, int):
                    indices_or_msg_ids[i] = self.history[index]
        return self.client.get_result(indices_or_msg_ids)

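    # A hedged usage sketch:
    #
    #     ar = view.get_result(-1)                 # most recent, by index
    #     ars = view.get_result(view.history[:3])  # first three, by msg_id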
    #-------------------------------------------------------------------
    # Map
    #-------------------------------------------------------------------

    def map(self, f, *sequences, **kwargs):
        """override in subclasses"""
        raise NotImplementedError

    def map_async(self, f, *sequences, **kwargs):
        """Parallel version of builtin `map`, using this view's engines.

        This is equivalent to map(..., block=False)

        See `self.map` for details.
        """
        if 'block' in kwargs:
            raise TypeError(
                "map_async doesn't take a `block` keyword argument.")
        kwargs['block'] = False
        return self.map(f, *sequences, **kwargs)

    def map_sync(self, f, *sequences, **kwargs):
        """Parallel version of builtin `map`, using this view's engines.

        This is equivalent to map(..., block=True)

        See `self.map` for details.
        """
        if 'block' in kwargs:
            raise TypeError(
                "map_sync doesn't take a `block` keyword argument.")
        kwargs['block'] = True
        return self.map(f, *sequences, **kwargs)

    def imap(self, f, *sequences, **kwargs):
        """Parallel version of `itertools.imap`.

        See `self.map` for details.

        """

        return iter(self.map_async(f, *sequences, **kwargs))

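    # A hedged usage sketch (`view` bound to a running cluster): map_sync
    # blocks for the results, map_async returns immediately, and imap
    # yields results as they arrive:
    #
    #     squares = view.map_sync(lambda x: x ** 2, range(10))
    #     ar = view.map_async(lambda x: x ** 2, range(10))
    #     for sq in view.imap(lambda x: x ** 2, range(10)):
    #         print sq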
    #-------------------------------------------------------------------
    # Decorators
    #-------------------------------------------------------------------

    def remote(self, block=None, **flags):
        """Decorator for making a RemoteFunction"""
        block = self.block if block is None else block
        return remote(self, block=block, **flags)

    def parallel(self, dist='b', block=None, **flags):
        """Decorator for making a ParallelFunction"""
        block = self.block if block is None else block
        return parallel(self, dist=dist, block=block, **flags)
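
The remote and parallel decorators above wrap plain functions in RemoteFunction and ParallelFunction objects. A minimal usage sketch (illustrative, assuming `view` is a View bound to a running cluster; the function names are hypothetical):

@view.remote(block=True)
def getpid_everywhere():
    import os
    return os.getpid()

@view.parallel(dist='b', block=True)
def double(xs):
    return [2 * n for n in xs]

getpid_everywhere()  # executes on the view's engines
double(range(10))    # scatters range(10), computes chunks, gathers
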
Example #9
class Widget(LoggingConfigurable):
    #-------------------------------------------------------------------------
    # Class attributes
    #-------------------------------------------------------------------------
    _widget_construction_callback = None
    widgets = {}
    widget_types = {}

    @staticmethod
    def on_widget_constructed(callback):
        """Registers a callback to be called when a widget is constructed.

        The callback must have the following signature:
        callback(widget)"""
        Widget._widget_construction_callback = callback

    @staticmethod
    def _call_widget_constructed(widget):
        """Static method, called when a widget is constructed."""
        if Widget._widget_construction_callback is not None and callable(Widget._widget_construction_callback):
            Widget._widget_construction_callback(widget)

    @staticmethod
    def handle_comm_opened(comm, msg):
        """Static method, called when a widget is constructed."""
        widget_class = import_item(msg['content']['data']['widget_class'])
        widget = widget_class(comm=comm)


    #-------------------------------------------------------------------------
    # Traits
    #-------------------------------------------------------------------------
    _model_module = Unicode(None, allow_none=True, help="""A requirejs module name
        in which to find _model_name. If empty, look in the global registry.""")
    _model_name = Unicode('WidgetModel', help="""Name of the backbone model 
        registered in the front-end to create and sync this widget with.""")
    _view_module = Unicode(help="""A requirejs module in which to find _view_name.
        If empty, look in the global registry.""", sync=True)
    _view_name = Unicode(None, allow_none=True, help="""Default view registered in the front-end
        to use to represent the widget.""", sync=True)
    comm = Instance('IPython.kernel.comm.Comm')
    
    msg_throttle = Int(3, sync=True, help="""Maximum number of msgs the 
        front-end can send before receiving an idle msg from the back-end.""")
    
    version = Int(0, sync=True, help="""Widget's version""")
    keys = List()
    def _keys_default(self):
        return [name for name in self.traits(sync=True)]
    
    _property_lock = Tuple((None, None))
    _send_state_lock = Int(0)
    _states_to_send = Set(allow_none=False)
    _display_callbacks = Instance(CallbackDispatcher, ())
    _msg_callbacks = Instance(CallbackDispatcher, ())
    
    #-------------------------------------------------------------------------
    # (Con/de)structor
    #-------------------------------------------------------------------------
    def __init__(self, **kwargs):
        """Public constructor"""
        self._model_id = kwargs.pop('model_id', None)
        super(Widget, self).__init__(**kwargs)

        Widget._call_widget_constructed(self)
        self.open()

    def __del__(self):
        """Object disposal"""
        self.close()

    #-------------------------------------------------------------------------
    # Properties
    #-------------------------------------------------------------------------

    def open(self):
        """Open a comm to the frontend if one isn't already open."""
        if self.comm is None:
            args = dict(target_name='ipython.widget',
                        data={'model_name': self._model_name,
                              'model_module': self._model_module})
            if self._model_id is not None:
                args['comm_id'] = self._model_id
            self.comm = Comm(**args)

    def _comm_changed(self, name, new):
        """Called when the comm is changed."""
        if new is None:
            return
        self._model_id = self.model_id
        
        self.comm.on_msg(self._handle_msg)
        Widget.widgets[self.model_id] = self
        
        # first update
        self.send_state()

    @property
    def model_id(self):
        """Gets the model id of this widget.

        If a Comm doesn't exist yet, a Comm will be created automagically."""
        return self.comm.comm_id

    #-------------------------------------------------------------------------
    # Methods
    #-------------------------------------------------------------------------

    def close(self):
        """Close method.

        Closes the underlying comm.
        When the comm is closed, all of the widget views are automatically
        removed from the front-end."""
        if self.comm is not None:
            Widget.widgets.pop(self.model_id, None)
            self.comm.close()
            self.comm = None
    
    def send_state(self, key=None):
        """Sends the widget state, or a piece of it, to the front-end.

        Parameters
        ----------
        key : unicode, or iterable (optional)
            A single property's name or iterable of property names to sync with the front-end.
        """
        self._send({
            "method" : "update",
            "state"  : self.get_state(key=key)
        })

    def get_state(self, key=None):
        """Gets the widget state, or a piece of it.

        Parameters
        ----------
        key : unicode or iterable (optional)
            A single property's name or iterable of property names to get.
        """
        if key is None:
            keys = self.keys
        elif isinstance(key, string_types):
            keys = [key]
        elif isinstance(key, collections.Iterable):
            keys = key
        else:
            raise ValueError("key must be a string, an iterable of keys, or None")
        state = {}
        for k in keys:
            f = self.trait_metadata(k, 'to_json', self._trait_to_json)
            value = getattr(self, k)
            state[k] = f(value)
        return state

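    # A hedged sketch of per-trait serializers: get_state/set_state honor
    # 'to_json'/'from_json' trait metadata when present (the trait name and
    # parse_iso helper below are illustrative):
    #
    #     stamp = Instance(datetime, sync=True,
    #                      to_json=lambda d: d.isoformat(),
    #                      from_json=lambda s: parse_iso(s))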
    def set_state(self, sync_data):
        """Called when a state is received from the front-end."""
        for name in self.keys:
            if name in sync_data:
                json_value = sync_data[name]
                from_json = self.trait_metadata(name, 'from_json', self._trait_from_json)
                with self._lock_property(name, json_value):
                    setattr(self, name, from_json(json_value))
    
    def send(self, content):
        """Sends a custom msg to the widget model in the front-end.

        Parameters
        ----------
        content : dict
            Content of the message to send.
        """
        self._send({"method": "custom", "content": content})

    def on_msg(self, callback, remove=False):
        """(Un)Register a custom msg receive callback.

        Parameters
        ----------
        callback: callable
            callback will be passed two arguments when a message arrives::
            
                callback(widget, content)
            
        remove: bool
            True if the callback should be unregistered."""
        self._msg_callbacks.register_callback(callback, remove=remove)

    def on_displayed(self, callback, remove=False):
        """(Un)Register a widget displayed callback.

        Parameters
        ----------
        callback: method handler
            Must have a signature of::
            
                callback(widget, **kwargs)
            
            kwargs from display are passed through without modification.
        remove: bool
            True if the callback should be unregistered."""
        self._display_callbacks.register_callback(callback, remove=remove)

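    # A hedged sketch of the custom message channel (the handler and keys
    # are illustrative): callbacks registered via on_msg receive
    # (widget, content) for every 'custom' msg from the front-end:
    #
    #     def _on_custom(widget, content):
    #         if content.get('event') == 'ping':
    #             widget.send({'event': 'pong'})
    #     w.on_msg(_on_custom)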
    #-------------------------------------------------------------------------
    # Support methods
    #-------------------------------------------------------------------------
    @contextmanager
    def _lock_property(self, key, value):
        """Lock a property-value pair.

        The value should be the JSON state of the property.

        NOTE: This, in addition to the single lock for all state changes, is 
        flawed.  In the future we may want to look into buffering state changes 
        back to the front-end."""
        self._property_lock = (key, value)
        try:
            yield
        finally:
            self._property_lock = (None, None)

    @contextmanager
    def hold_sync(self):
        """Hold syncing any state until the context manager is released"""
        # We increment a value so that this can be nested.  Syncing will happen when
        # all levels have been released.
        self._send_state_lock += 1
        try:
            yield
        finally:
            self._send_state_lock -=1
            if self._send_state_lock == 0:
                self.send_state(self._states_to_send)
                self._states_to_send.clear()

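    # A brief usage sketch (assuming `w` is an open widget): hold_sync
    # coalesces several trait assignments into a single update message:
    #
    #     with w.hold_sync():
    #         w.value = 5
    #         w.description = u'count'
    #     # exactly one 'update' msg is sent here, carrying both keys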
    def _should_send_property(self, key, value):
        """Check the property lock (property_lock)"""
        to_json = self.trait_metadata(key, 'to_json', self._trait_to_json)
        if (key == self._property_lock[0]
            and to_json(value) == self._property_lock[1]):
            return False
        elif self._send_state_lock > 0:
            self._states_to_send.add(key)
            return False
        else:
            return True
    
    # Event handlers
    @_show_traceback
    def _handle_msg(self, msg):
        """Called when a msg is received from the front-end"""
        data = msg['content']['data']
        method = data['method']

        # Handle backbone sync methods CREATE, PATCH, and UPDATE all in one.
        if method == 'backbone':
            if 'sync_data' in data:
                sync_data = data['sync_data']
                self.set_state(sync_data) # handles all methods

        # Handle a state request.
        elif method == 'request_state':
            self.send_state()

        # Handle a custom msg from the front-end.
        elif method == 'custom':
            if 'content' in data:
                self._handle_custom_msg(data['content'])

        # Catch remainder.
        else:
            self.log.error('Unknown front-end to back-end widget msg with method "%s"' % method)

    def _handle_custom_msg(self, content):
        """Called when a custom msg is received."""
        self._msg_callbacks(self, content)

    def _notify_trait(self, name, old_value, new_value):
        """Called when a property has been changed."""
        # Trigger default traitlet callback machinery.  This allows any user
        # registered validation to be processed prior to allowing the widget
        # machinery to handle the state.
        LoggingConfigurable._notify_trait(self, name, old_value, new_value)

        # Send the state after the user registered callbacks for trait changes
        # have all fired (allows for user to validate values).
        if self.comm is not None and name in self.keys:
            # Make sure this isn't information that the front-end just sent us.
            if self._should_send_property(name, new_value):
                # Send new state to front-end
                self.send_state(key=name)

    def _handle_displayed(self, **kwargs):
        """Called when a view has been displayed for this widget instance"""
        self._display_callbacks(self, **kwargs)

    def _trait_to_json(self, x):
        """Convert a trait value to json

        Traverse lists/tuples and dicts and serialize their values as well.
        Replace any widgets with their model_id
        """
        if isinstance(x, dict):
            return {k: self._trait_to_json(v) for k, v in x.items()}
        elif isinstance(x, (list, tuple)):
            return [self._trait_to_json(v) for v in x]
        elif isinstance(x, Widget):
            return "IPY_MODEL_" + x.model_id
        else:
            return x # Value must be JSON-able

    def _trait_from_json(self, x):
        """Convert json values to objects

        Replace any strings representing valid model id values with Widget references.
        """
        if isinstance(x, dict):
            return {k: self._trait_from_json(v) for k, v in x.items()}
        elif isinstance(x, (list, tuple)):
            return [self._trait_from_json(v) for v in x]
        elif isinstance(x, string_types) and x.startswith('IPY_MODEL_') and x[10:] in Widget.widgets:
            # we want to support having child widgets at any level in a hierarchy
            # trusting that a widget UUID will not appear out in the wild
            return Widget.widgets[x[10:]]
        else:
            return x

    def _ipython_display_(self, **kwargs):
        """Called when `IPython.display.display` is called on the widget."""
        # Show view.
        if self._view_name is not None:
            self._send({"method": "display"})
            self._handle_displayed(**kwargs)

    def _send(self, msg):
        """Sends a message to the model in the front-end."""
        self.comm.send(msg)
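
A minimal subclass sketch (illustrative names, not from the source): any trait tagged sync=True is picked up by _keys_default and round-trips through send_state/set_state, provided a matching view is registered in the front-end:

class CounterWidget(Widget):
    _view_name = Unicode('CounterView', sync=True)
    value = Int(0, sync=True)

w = CounterWidget()  # __init__ -> open() creates the comm
w.value = 3          # _notify_trait -> send_state(key='value')
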
Example #10
class TaskScheduler(SessionFactory):
    """Python TaskScheduler object.
    
    This is the simplest object that supports msg_id based
    DAG dependencies. *Only* task msg_ids are checked, not
    msg_ids of jobs submitted via the MUX queue.
    
    """

    # input arguments:
    scheme = Instance(
        FunctionType,
        default=leastload)  # function for determining the destination
    client_stream = Instance(zmqstream.ZMQStream)  # client-facing stream
    engine_stream = Instance(zmqstream.ZMQStream)  # engine-facing stream
    notifier_stream = Instance(zmqstream.ZMQStream)  # hub-facing sub stream
    mon_stream = Instance(zmqstream.ZMQStream)  # hub-facing pub stream

    # internals:
    graph = Dict()  # dict by msg_id of [ msg_ids that depend on key ]
    depending = Dict()  # dict by msg_id of (msg_id, raw_msg, after, follow)
    pending = Dict()  # dict by engine_uuid of submitted tasks
    completed = Dict()  # dict by engine_uuid of completed tasks
    failed = Dict()  # dict by engine_uuid of failed tasks
    destinations = Dict(
    )  # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
    clients = Dict()  # dict by msg_id for who submitted the task
    targets = List()  # list of target IDENTs
    loads = List()  # list of engine loads
    all_completed = Set()  # set of all completed tasks
    all_failed = Set()  # set of all failed tasks
    all_done = Set()  # set of all finished tasks=union(completed,failed)
    all_ids = Set()  # set of all submitted task IDs
    blacklist = Dict(
    )  # dict by msg_id of locations where a job has encountered UnmetDependency
    auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')

    def start(self):
        self.engine_stream.on_recv(self.dispatch_result, copy=False)
        self._notification_handlers = dict(
            registration_notification=self._register_engine,
            unregistration_notification=self._unregister_engine)
        self.notifier_stream.on_recv(self.dispatch_notification)
        self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3,
                                               self.loop)  # every 2s
        self.auditor.start()
        self.log.info("Scheduler started...%r" % self)

    def resume_receiving(self):
        """Resume accepting jobs."""
        self.client_stream.on_recv(self.dispatch_submission, copy=False)

    def stop_receiving(self):
        """Stop accepting jobs while there are no engines.
        Leave them in the ZMQ queue."""
        self.client_stream.on_recv(None)

    #-----------------------------------------------------------------------
    # [Un]Registration Handling
    #-----------------------------------------------------------------------

    def dispatch_notification(self, msg):
        """dispatch register/unregister events."""
        idents, msg = self.session.feed_identities(msg)
        msg = self.session.unpack_message(msg)
        msg_type = msg['msg_type']
        handler = self._notification_handlers.get(msg_type, None)
        if handler is None:
            raise Exception("Unhandled message type: %s" % msg_type)
        else:
            try:
                handler(str(msg['content']['queue']))
            except KeyError:
                self.log.error("task::Invalid notification msg: %s" % msg)

    @logged
    def _register_engine(self, uid):
        """New engine with ident `uid` became available."""
        # head of the line:
        self.targets.insert(0, uid)
        self.loads.insert(0, 0)
        # initialize sets
        self.completed[uid] = set()
        self.failed[uid] = set()
        self.pending[uid] = {}
        if len(self.targets) == 1:
            self.resume_receiving()

    def _unregister_engine(self, uid):
        """Existing engine with ident `uid` became unavailable."""
        if len(self.targets) == 1:
            # this was our only engine
            self.stop_receiving()

        # handle any potentially finished tasks:
        self.engine_stream.flush()

        self.completed.pop(uid)
        self.failed.pop(uid)
        # don't pop destinations, because it might be used later
        # map(self.destinations.pop, self.completed.pop(uid))
        # map(self.destinations.pop, self.failed.pop(uid))

        idx = self.targets.index(uid)
        self.targets.pop(idx)
        self.loads.pop(idx)

        # wait 5 seconds before cleaning up pending jobs, since the results might
        # still be incoming
        if self.pending[uid]:
            dc = ioloop.DelayedCallback(
                lambda: self.handle_stranded_tasks(uid), 5000, self.loop)
            dc.start()

    @logged
    def handle_stranded_tasks(self, engine):
        """Deal with jobs resident in an engine that died."""
        lost = self.pending.pop(engine)

        for msg_id, (raw_msg, targets, MET, follow,
                     timeout) in lost.iteritems():
            self.all_failed.add(msg_id)
            self.all_done.add(msg_id)
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.unpack_message(msg, copy=False, content=False)
            parent = msg['header']
            idents = [idents[0], engine] + idents[1:]
            # print (idents)
            try:
                raise error.EngineError(
                    "Engine %r died while running task %r" % (engine, msg_id))
            except:
                content = error.wrap_exception()
            msg = self.session.send(self.client_stream,
                                    'apply_reply',
                                    content,
                                    parent=parent,
                                    ident=idents)
            self.session.send(self.mon_stream, msg, ident=['outtask'] + idents)
            self.update_graph(msg_id)

    #-----------------------------------------------------------------------
    # Job Submission
    #-----------------------------------------------------------------------
    @logged
    def dispatch_submission(self, raw_msg):
        """Dispatch job submission to appropriate handlers."""
        # ensure targets up to date:
        self.notifier_stream.flush()
        try:
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.unpack_message(msg, content=False, copy=False)
        except:
            self.log.error("task::Invaid task: %s" % raw_msg, exc_info=True)
            return

        # send to monitor
        self.mon_stream.send_multipart(['intask'] + raw_msg, copy=False)

        header = msg['header']
        msg_id = header['msg_id']
        self.all_ids.add(msg_id)

        # targets
        targets = set(header.get('targets', []))

        # time dependencies
        after = Dependency(header.get('after', []))
        if after.all:
            if after.success:
                after.difference_update(self.all_completed)
            if after.failure:
                after.difference_update(self.all_failed)
        if after.check(self.all_completed, self.all_failed):
            # recast as empty set, if `after` already met,
            # to prevent unnecessary set comparisons
            after = MET

        # location dependencies
        follow = Dependency(header.get('follow', []))

        # turn timeouts into datetime objects:
        timeout = header.get('timeout', None)
        if timeout:
            timeout = datetime.now() + timedelta(0, timeout, 0)

        args = [raw_msg, targets, after, follow, timeout]

        # validate and reduce dependencies:
        for dep in after, follow:
            # check valid:
            if msg_id in dep or dep.difference(self.all_ids):
                self.depending[msg_id] = args
                return self.fail_unreachable(msg_id, error.InvalidDependency)
            # check if unreachable:
            if dep.unreachable(self.all_completed, self.all_failed):
                self.depending[msg_id] = args
                return self.fail_unreachable(msg_id)

        if after.check(self.all_completed, self.all_failed):
            # time deps already met, try to run
            if not self.maybe_run(msg_id, *args):
                # can't run yet
                self.save_unmet(msg_id, *args)
        else:
            self.save_unmet(msg_id, *args)

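    # A hedged client-side sketch of the `after`/`follow` headers consumed
    # above (`lview` is a load-balanced view; stage names are illustrative):
    #
    #     ar_a = lview.apply(stage_one)
    #     with lview.temp_flags(after=[ar_a], timeout=30):
    #         ar_b = lview.apply(stage_two)  # queued until stage_one is done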
    # @logged
    def audit_timeouts(self):
        """Audit all waiting tasks for expired timeouts."""
        now = datetime.now()
        for msg_id in self.depending.keys():
            # must recheck, in case one failure cascaded to another:
            if msg_id in self.depending:
                raw_msg, targets, after, follow, timeout = self.depending[msg_id]
                if timeout and timeout < now:
                    self.fail_unreachable(msg_id, timeout=True)

    @logged
    def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
        """a task has become unreachable, send a reply with an ImpossibleDependency
        error."""
        if msg_id not in self.depending:
            self.log.error("msg %r already failed!" % msg_id)
            return
        raw_msg, targets, after, follow, timeout = self.depending.pop(msg_id)
        for mid in follow.union(after):
            if mid in self.graph:
                self.graph[mid].remove(msg_id)

        # FIXME: unpacking a message I've already unpacked, but didn't save:
        idents, msg = self.session.feed_identities(raw_msg, copy=False)
        msg = self.session.unpack_message(msg, copy=False, content=False)
        header = msg['header']

        try:
            raise why()
        except:
            content = error.wrap_exception()

        self.all_done.add(msg_id)
        self.all_failed.add(msg_id)

        msg = self.session.send(self.client_stream,
                                'apply_reply',
                                content,
                                parent=header,
                                ident=idents)
        self.session.send(self.mon_stream, msg, ident=['outtask'] + idents)

        self.update_graph(msg_id, success=False)

    @logged
    def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
        """check location dependencies, and run if they are met."""
        blacklist = self.blacklist.setdefault(msg_id, set())
        if follow or targets or blacklist:
            # we need a can_run filter
            def can_run(idx):
                target = self.targets[idx]
                # check targets
                if targets and target not in targets:
                    return False
                # check blacklist
                if target in blacklist:
                    return False
                # check follow
                return follow.check(self.completed[target],
                                    self.failed[target])

            indices = filter(can_run, range(len(self.targets)))
            if not indices:
                # couldn't run
                if follow.all:
                    # check follow for impossibility
                    dests = set()
                    relevant = set()
                    if follow.success:
                        relevant = self.all_completed
                    if follow.failure:
                        relevant = relevant.union(self.all_failed)
                    for m in follow.intersection(relevant):
                        dests.add(self.destinations[m])
                    if len(dests) > 1:
                        self.fail_unreachable(msg_id)
                        return False
                if targets:
                    # check blacklist+targets for impossibility
                    targets.difference_update(blacklist)
                    if not targets or not targets.intersection(self.targets):
                        self.fail_unreachable(msg_id)
                        return False
                return False
        else:
            indices = None

        self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
        return True

    @logged
    def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
        """Save a message for later submission when its dependencies are met."""
        self.depending[msg_id] = [raw_msg, targets, after, follow, timeout]
        # track the ids in follow or after, but not those already finished
        for dep_id in after.union(follow).difference(self.all_done):
            if dep_id not in self.graph:
                self.graph[dep_id] = set()
            self.graph[dep_id].add(msg_id)

    @logged
    def submit_task(self,
                    msg_id,
                    raw_msg,
                    targets,
                    follow,
                    timeout,
                    indices=None):
        """Submit a task to any of a subset of our targets."""
        if indices:
            loads = [self.loads[i] for i in indices]
        else:
            loads = self.loads
        idx = self.scheme(loads)
        if indices:
            idx = indices[idx]
        target = self.targets[idx]
        # print (target, map(str, msg[:3]))
        self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
        self.engine_stream.send_multipart(raw_msg, copy=False)
        self.add_job(idx)
        self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
        content = dict(msg_id=msg_id, engine_id=target)
        self.session.send(self.mon_stream,
                          'task_destination',
                          content=content,
                          ident=['tracktask', self.session.session])

    #-----------------------------------------------------------------------
    # Result Handling
    #-----------------------------------------------------------------------
    @logged
    def dispatch_result(self, raw_msg):
        """dispatch method for result replies"""
        try:
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.unpack_message(msg, content=False, copy=False)
        except:
            self.log.error("task::Invaid result: %s" % raw_msg, exc_info=True)
            return

        header = msg['header']
        if header.get('dependencies_met', True):
            success = (header['status'] == 'ok')
            self.handle_result(idents, msg['parent_header'], raw_msg, success)
            # send to Hub monitor
            self.mon_stream.send_multipart(['outtask'] + raw_msg, copy=False)
        else:
            self.handle_unmet_dependency(idents, msg['parent_header'])

    @logged
    def handle_result(self, idents, parent, raw_msg, success=True):
        """handle a real task result, either success or failure"""
        # first, relay result to client
        engine = idents[0]
        client = idents[1]
        # swap_ids for XREP-XREP mirror
        raw_msg[:2] = [client, engine]
        # print (map(str, raw_msg[:4]))
        self.client_stream.send_multipart(raw_msg, copy=False)
        # now, update our data structures
        msg_id = parent['msg_id']
        self.blacklist.pop(msg_id, None)
        self.pending[engine].pop(msg_id)
        if success:
            self.completed[engine].add(msg_id)
            self.all_completed.add(msg_id)
        else:
            self.failed[engine].add(msg_id)
            self.all_failed.add(msg_id)
        self.all_done.add(msg_id)
        self.destinations[msg_id] = engine

        self.update_graph(msg_id, success)

    @logged
    def handle_unmet_dependency(self, idents, parent):
        """handle an unmet dependency"""
        engine = idents[0]
        msg_id = parent['msg_id']

        if msg_id not in self.blacklist:
            self.blacklist[msg_id] = set()
        self.blacklist[msg_id].add(engine)

        args = self.pending[engine].pop(msg_id)
        raw, targets, after, follow, timeout = args

        if self.blacklist[msg_id] == targets:
            self.depending[msg_id] = args
            return self.fail_unreachable(msg_id)

        elif not self.maybe_run(msg_id, *args):
            # resubmit failed, put it back in our dependency tree
            self.save_unmet(msg_id, *args)

    @logged
    def update_graph(self, dep_id, success=True):
        """dep_id just finished. Update our dependency
        graph and submit any jobs that just became runnable."""
        # print ("\n\n***********")
        # pprint (dep_id)
        # pprint (self.graph)
        # pprint (self.depending)
        # pprint (self.all_completed)
        # pprint (self.all_failed)
        # print ("\n\n***********\n\n")
        if dep_id not in self.graph:
            return
        jobs = self.graph.pop(dep_id)

        for msg_id in jobs:
            raw_msg, targets, after, follow, timeout = self.depending[msg_id]

            if after.unreachable(self.all_completed,
                                 self.all_failed) or follow.unreachable(
                                     self.all_completed, self.all_failed):
                self.fail_unreachable(msg_id)

            elif after.check(self.all_completed,
                             self.all_failed):  # time deps met, maybe run
                if self.maybe_run(msg_id, raw_msg, targets, MET, follow,
                                  timeout):

                    self.depending.pop(msg_id)
                    for mid in follow.union(after):
                        if mid in self.graph:
                            self.graph[mid].remove(msg_id)

    #----------------------------------------------------------------------
    # methods to be overridden by subclasses
    #----------------------------------------------------------------------

    def add_job(self, idx):
        """Called after self.targets[idx] just got the job with header.
        Override with subclasses.  The default ordering is simple LRU.
        The default loads are the number of outstanding jobs."""
        self.loads[idx] += 1
        for lis in (self.targets, self.loads):
            lis.append(lis.pop(idx))

    def finish_job(self, idx):
        """Called after self.targets[idx] just finished a job.
        Override with subclasses."""
        self.loads[idx] -= 1
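
The scheme trait above is just a function from the current list of engine loads to an index into self.targets. A sketch of the default leastload, plus a hedged variant with the same signature (two_choice is illustrative, not from the source):

import random

def leastload(loads):
    """Pick the engine with the fewest outstanding tasks."""
    return loads.index(min(loads))

def two_choice(loads):
    """Illustrative variant: the lighter-loaded of two random engines.
    Assumes at least two engines are registered."""
    a, b = random.sample(range(len(loads)), 2)
    return a if loads[a] <= loads[b] else b
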
Example #11
class Client(HasTraits):
    """A semi-synchronous client to the IPython ZMQ cluster
    
    Parameters
    ----------
    
    url_or_file : bytes; zmq url or path to ipcontroller-client.json
        Connection information for the Hub's registration.  If a json connector
        file is given, then likely no further configuration is necessary.
        [Default: use profile]
    profile : bytes
        The name of the Cluster profile to be used to find connector information.
        [Default: 'default']
    context : zmq.Context
        Pass an existing zmq.Context instance, otherwise the client will create its own.
    username : bytes
        set username to be passed to the Session object
    debug : bool
        flag for lots of message printing for debug purposes
 
    #-------------- ssh related args ----------------
    # These are args for configuring the ssh tunnel to be used
    # credentials are used to forward connections over ssh to the Controller
    # Note that the ip given in `addr` needs to be relative to sshserver
    # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
    # and set sshserver as the same machine the Controller is on. However, 
    # the only requirement is that sshserver is able to see the Controller
    # (i.e. is within the same trusted network).
    
    sshserver : str
        A string of the form passed to ssh, i.e. 'server.tld' or '[email protected]:port'
        If keyfile or password is specified, and this is not, it will default to
        the ip given in addr.
    sshkey : str; path to public ssh key file
        This specifies a key to be used in ssh login, default None.
        Regular default ssh keys will be used without specifying this argument.
    password : str 
        Your ssh password to sshserver. Note that if this is left None,
        you will be prompted for it if passwordless key based login is unavailable.
    paramiko : bool
        flag for whether to use paramiko instead of shell ssh for tunneling.
        [default: True on win32, False else]
    
    ------- exec authentication args -------
    If even localhost is untrusted, you can have some protection against
    unauthorized execution by using a key.  Messages are still sent
    as cleartext, so if someone can snoop your loopback traffic this will
    not help against malicious attacks.
    
    exec_key : str
        an authentication key or file containing a key
        default: None
    
    
    Attributes
    ----------
    
    ids : list of int engine IDs
        requesting the ids attribute always synchronizes
        the registration state. To request ids without synchronization,
        use the semi-private _ids attribute.
    
    history : list of msg_ids
        a list of msg_ids, keeping track of all the execution
        messages you have submitted in order.
    
    outstanding : set of msg_ids
        a set of msg_ids that have been submitted, but whose
        results have not yet been received.
    
    results : dict
        a dict of all our results, keyed by msg_id
    
    block : bool
        determines default behavior when block not specified
        in execution methods
    
    Methods
    -------
    
    spin
        flushes incoming results and registration state changes
        control methods spin, and requesting `ids` also ensures up to date
    
    wait
        wait on one or more msg_ids
    
    execution methods
        apply
        legacy: execute, run
    
    data movement
        push, pull, scatter, gather
    
    query methods
        queue_status, get_result, purge, result_status
    
    control methods
        abort, shutdown
    
    """

    block = Bool(False)
    outstanding = Set()
    results = Instance('collections.defaultdict', (dict, ))
    metadata = Instance('collections.defaultdict', (Metadata, ))
    history = List()
    debug = Bool(False)
    profile = CUnicode('default')

    _outstanding_dict = Instance('collections.defaultdict', (set, ))
    _ids = List()
    _connected = Bool(False)
    _ssh = Bool(False)
    _context = Instance('zmq.Context')
    _config = Dict()
    _engines = Instance(util.ReverseDict, (), {})
    # _hub_socket=Instance('zmq.Socket')
    _query_socket = Instance('zmq.Socket')
    _control_socket = Instance('zmq.Socket')
    _iopub_socket = Instance('zmq.Socket')
    _notification_socket = Instance('zmq.Socket')
    _mux_socket = Instance('zmq.Socket')
    _task_socket = Instance('zmq.Socket')
    _task_scheme = Str()
    _closed = False
    _ignored_control_replies = Int(0)
    _ignored_hub_replies = Int(0)

    def __init__(self,
                 url_or_file=None,
                 profile='default',
                 cluster_dir=None,
                 ipython_dir=None,
                 context=None,
                 username=None,
                 debug=False,
                 exec_key=None,
                 sshserver=None,
                 sshkey=None,
                 password=None,
                 paramiko=None,
                 timeout=10):
        super(Client, self).__init__(debug=debug, profile=profile)
        if context is None:
            context = zmq.Context.instance()
        self._context = context

        self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
        if self._cd is not None:
            if url_or_file is None:
                url_or_file = pjoin(self._cd.security_dir,
                                    'ipcontroller-client.json')
        assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
            " Please specify at least one of url_or_file or profile."

        try:
            util.validate_url(url_or_file)
        except AssertionError:
            if not os.path.exists(url_or_file):
                if self._cd:
                    url_or_file = os.path.join(self._cd.security_dir,
                                               url_or_file)
                assert os.path.exists(
                    url_or_file
                ), "Not a valid connection file or url: %r" % url_or_file
            with open(url_or_file) as f:
                cfg = json.loads(f.read())
        else:
            cfg = {'url': url_or_file}

        # sync defaults from args, json:
        if sshserver:
            cfg['ssh'] = sshserver
        if exec_key:
            cfg['exec_key'] = exec_key
        exec_key = cfg['exec_key']
        sshserver = cfg['ssh']
        url = cfg['url']
        location = cfg.setdefault('location', None)
        cfg['url'] = util.disambiguate_url(cfg['url'], location)
        url = cfg['url']

        self._config = cfg

        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to ssh via localhost
            sshserver = url.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                password = False
            else:
                password = getpass("SSH Password for %s: " % sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
        if exec_key is not None and os.path.isfile(exec_key):
            arg = 'keyfile'
        else:
            arg = 'key'
        key_arg = {arg: exec_key}
        if username is None:
            self.session = ss.StreamSession(**key_arg)
        else:
            self.session = ss.StreamSession(username, **key_arg)
        self._query_socket = self._context.socket(zmq.XREQ)
        self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
        if self._ssh:
            tunnel.tunnel_connection(self._query_socket, url, sshserver,
                                     **ssh_kwargs)
        else:
            self._query_socket.connect(url)

        self.session.debug = self.debug

        self._notification_handlers = {
            'registration_notification': self._register_engine,
            'unregistration_notification': self._unregister_engine,
            'shutdown_notification': lambda msg: self.close(),
        }
        self._queue_handlers = {
            'execute_reply': self._handle_execute_reply,
            'apply_reply': self._handle_apply_reply
        }
        self._connect(sshserver, ssh_kwargs, timeout)

    def __del__(self):
        """cleanup sockets, but _not_ context."""
        self.close()

    def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
        if ipython_dir is None:
            ipython_dir = get_ipython_dir()
        if cluster_dir is not None:
            try:
                self._cd = ClusterDir.find_cluster_dir(cluster_dir)
                return
            except ClusterDirError:
                pass
        elif profile is not None:
            try:
                self._cd = ClusterDir.find_cluster_dir_by_profile(
                    ipython_dir, profile)
                return
            except ClusterDirError:
                pass
        self._cd = None

    def _update_engines(self, engines):
        """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
        for k, v in engines.iteritems():
            eid = int(k)
            self._engines[eid] = bytes(v)  # force not unicode
            self._ids.append(eid)
        self._ids = sorted(self._ids)
        if sorted(self._engines.keys()) != range(len(self._engines)) and \
                        self._task_scheme == 'pure' and self._task_socket:
            self._stop_scheduling_tasks()

    def _stop_scheduling_tasks(self):
        """Stop scheduling tasks because an engine has been unregistered
        from a pure ZMQ scheduler.
        """
        self._task_socket.close()
        self._task_socket = None
        msg = "An engine has been unregistered, and we are using pure " +\
              "ZMQ task scheduling.  Task farming will be disabled."
        if self.outstanding:
            msg += " If you were running tasks when this happened, " +\
                   "some `outstanding` msg_ids may never resolve."
        warnings.warn(msg, RuntimeWarning)

    def _build_targets(self, targets):
        """Turn valid target IDs or 'all' into two lists:
        (int_ids, uuids).
        """
        if targets is None:
            targets = self._ids
        elif isinstance(targets, str):
            if targets.lower() == 'all':
                targets = self._ids
            else:
                raise TypeError("%r not valid str target, must be 'all'" %
                                (targets))
        elif isinstance(targets, int):
            if targets < 0:
                targets = self.ids[targets]
            if targets not in self.ids:
                raise IndexError("No such engine: %i" % targets)
            targets = [targets]

        if isinstance(targets, slice):
            indices = range(len(self._ids))[targets]
            ids = self.ids
            targets = [ids[i] for i in indices]

        if not isinstance(targets, (tuple, list, xrange)):
            raise TypeError(
                "targets by int/slice/collection of ints only, not %s" %
                (type(targets)))

        return [self._engines[t] for t in targets], list(targets)

    def _connect(self, sshserver, ssh_kwargs, timeout):
        """setup all our socket connections to the cluster. This is called from
        __init__."""

        # Maybe allow reconnecting?
        if self._connected:
            return
        self._connected = True

        def connect_socket(s, url):
            url = util.disambiguate_url(url, self._config['location'])
            if self._ssh:
                return tunnel.tunnel_connection(s, url, sshserver,
                                                **ssh_kwargs)
            else:
                return s.connect(url)

        self.session.send(self._query_socket, 'connection_request')
        r, w, x = zmq.select([self._query_socket], [], [], timeout)
        if not r:
            raise error.TimeoutError("Hub connection request timed out")
        idents, msg = self.session.recv(self._query_socket, mode=0)
        if self.debug:
            pprint(msg)
        msg = ss.Message(msg)
        content = msg.content
        self._config['registration'] = dict(content)
        if content.status == 'ok':
            if content.mux:
                self._mux_socket = self._context.socket(zmq.XREQ)
                self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._mux_socket, content.mux)
            if content.task:
                self._task_scheme, task_addr = content.task
                self._task_socket = self._context.socket(zmq.XREQ)
                self._task_socket.setsockopt(zmq.IDENTITY,
                                             self.session.session)
                connect_socket(self._task_socket, task_addr)
            if content.notification:
                self._notification_socket = self._context.socket(zmq.SUB)
                connect_socket(self._notification_socket, content.notification)
                self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
            # if content.query:
            #     self._query_socket = self._context.socket(zmq.XREQ)
            #     self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
            #     connect_socket(self._query_socket, content.query)
            if content.control:
                self._control_socket = self._context.socket(zmq.XREQ)
                self._control_socket.setsockopt(zmq.IDENTITY,
                                                self.session.session)
                connect_socket(self._control_socket, content.control)
            if content.iopub:
                self._iopub_socket = self._context.socket(zmq.SUB)
                self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
                self._iopub_socket.setsockopt(zmq.IDENTITY,
                                              self.session.session)
                connect_socket(self._iopub_socket, content.iopub)
            self._update_engines(dict(content.engines))
        else:
            self._connected = False
            raise Exception("Failed to connect!")

    #--------------------------------------------------------------------------
    # handlers and callbacks for incoming messages
    #--------------------------------------------------------------------------

    def _unwrap_exception(self, content):
        """unwrap exception, and remap engine_id to int."""
        e = error.unwrap_exception(content)
        # print e.traceback
        if e.engine_info:
            e_uuid = e.engine_info['engine_uuid']
            eid = self._engines[e_uuid]
            e.engine_info['engine_id'] = eid
        return e

    def _extract_metadata(self, header, parent, content):
        md = {
            'msg_id': parent['msg_id'],
            'received': datetime.now(),
            'engine_uuid': header.get('engine', None),
            'follow': parent.get('follow', []),
            'after': parent.get('after', []),
            'status': content['status'],
        }

        if md['engine_uuid'] is not None:
            md['engine_id'] = self._engines.get(md['engine_uuid'], None)

        if 'date' in parent:
            md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
        if 'started' in header:
            md['started'] = datetime.strptime(header['started'], util.ISO8601)
        if 'date' in header:
            md['completed'] = datetime.strptime(header['date'], util.ISO8601)
        return md

    def _register_engine(self, msg):
        """Register a new engine, and update our connection info."""
        content = msg['content']
        eid = content['id']
        d = {eid: content['queue']}
        self._update_engines(d)

    def _unregister_engine(self, msg):
        """Unregister an engine that has died."""
        content = msg['content']
        eid = int(content['id'])
        if eid in self._ids:
            self._ids.remove(eid)
            uuid = self._engines.pop(eid)

            self._handle_stranded_msgs(eid, uuid)

        if self._task_socket and self._task_scheme == 'pure':
            self._stop_scheduling_tasks()

    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.
        
        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        of the unregistration and later receive the successful result.
        """

        outstanding = self._outstanding_dict[uuid]

        for msg_id in list(outstanding):
            if msg_id in self.results:
                # we already have the result
                continue
            try:
                raise error.EngineError(
                    "Engine %r died while running task %r" % (eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake message:
            parent = {}
            header = {}
            parent['msg_id'] = msg_id
            header['engine'] = uuid
            header['date'] = datetime.now().strftime(util.ISO8601)
            msg = dict(parent_header=parent, header=header, content=content)
            self._handle_apply_reply(msg)

    def _handle_execute_reply(self, msg):
        """Save the reply to an execute_request into our results.
        
        execute messages are never actually used. apply is used instead.
        """

        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            if msg_id in self.history:
                print("got stale result: %s" % msg_id)
            else:
                print("got unknown result: %s" % msg_id)
        else:
            self.outstanding.remove(msg_id)
        self.results[msg_id] = self._unwrap_exception(msg['content'])

    def _handle_apply_reply(self, msg):
        """Save the reply to an apply_request into our results."""
        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            if msg_id in self.history:
                print("got stale result: %s" % msg_id)
                print(self.results[msg_id])
                print(msg)
            else:
                print("got unknown result: %s" % msg_id)
        else:
            self.outstanding.remove(msg_id)
        content = msg['content']
        header = msg['header']

        # construct metadata:
        md = self.metadata[msg_id]
        md.update(self._extract_metadata(header, parent, content))
        # is this redundant?
        self.metadata[msg_id] = md

        e_outstanding = self._outstanding_dict[md['engine_uuid']]
        if msg_id in e_outstanding:
            e_outstanding.remove(msg_id)

        # construct result:
        if content['status'] == 'ok':
            self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
        elif content['status'] == 'aborted':
            self.results[msg_id] = error.TaskAborted(msg_id)
        elif content['status'] == 'resubmitted':
            # TODO: handle resubmission
            pass
        else:
            self.results[msg_id] = self._unwrap_exception(content)

    def _flush_notifications(self):
        """Flush notifications of engine registrations waiting
        in ZMQ queue."""
        msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg = msg[-1]
            msg_type = msg['msg_type']
            handler = self._notification_handlers.get(msg_type, None)
            if handler is None:
                raise Exception("Unhandled message type: %s" % msg.msg_type)
            else:
                handler(msg)
            msg = self.session.recv(self._notification_socket,
                                    mode=zmq.NOBLOCK)

    def _flush_results(self, sock):
        """Flush task or queue results waiting in ZMQ queue."""
        msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg = msg[-1]
            msg_type = msg['msg_type']
            handler = self._queue_handlers.get(msg_type, None)
            if handler is None:
                raise Exception("Unhandled message type: %s" % msg.msg_type)
            else:
                handler(msg)
            msg = self.session.recv(sock, mode=zmq.NOBLOCK)

    def _flush_control(self, sock):
        """Flush replies from the control channel waiting
        in the ZMQ queue.
        
        Currently: ignore them."""
        if self._ignored_control_replies <= 0:
            return
        msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            self._ignored_control_replies -= 1
            if self.debug:
                pprint(msg)
            msg = self.session.recv(sock, mode=zmq.NOBLOCK)

    def _flush_ignored_control(self):
        """flush ignored control replies"""
        while self._ignored_control_replies > 0:
            self.session.recv(self._control_socket)
            self._ignored_control_replies -= 1

    def _flush_ignored_hub_replies(self):
        msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
        while msg is not None:
            msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)

    def _flush_iopub(self, sock):
        """Flush replies from the iopub channel waiting
        in the ZMQ queue.
        """
        msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg = msg[-1]
            parent = msg['parent_header']
            msg_id = parent['msg_id']
            content = msg['content']
            header = msg['header']
            msg_type = msg['msg_type']

            # init metadata:
            md = self.metadata[msg_id]

            if msg_type == 'stream':
                name = content['name']
                s = md[name] or ''
                md[name] = s + content['data']
            elif msg_type == 'pyerr':
                md.update({'pyerr': self._unwrap_exception(content)})
            elif msg_type == 'pyin':
                md.update({'pyin': content['code']})
            else:
                md.update({msg_type: content.get('data', '')})

            # redundant?
            self.metadata[msg_id] = md

            msg = self.session.recv(sock, mode=zmq.NOBLOCK)

    #--------------------------------------------------------------------------
    # len, getitem
    #--------------------------------------------------------------------------

    def __len__(self):
        """len(client) returns # of engines."""
        return len(self.ids)

    def __getitem__(self, key):
        """index access returns DirectView multiplexer objects
        
        Must be int, slice, or list/tuple/xrange of ints"""
        if not isinstance(key, (int, slice, tuple, list, xrange)):
            raise TypeError("key by int/slice/iterable of ints only, not %s" %
                            (type(key)))
        else:
            return self.direct_view(key)
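
    # Hedged indexing examples (assume a connected Client `rc`); each form
    # is routed through direct_view():
    #
    #   rc[0]        # DirectView on engine 0
    #   rc[::2]      # DirectView on every other engine
    #   rc[[0, 2]]   # DirectView on engines 0 and 2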

    #--------------------------------------------------------------------------
    # Begin public methods
    #--------------------------------------------------------------------------

    @property
    def ids(self):
        """Always up-to-date ids property."""
        self._flush_notifications()
        # always copy:
        return list(self._ids)

    def close(self):
        if self._closed:
            return
        snames = filter(lambda n: n.endswith('socket'), dir(self))
        for socket in map(lambda name: getattr(self, name), snames):
            if isinstance(socket, zmq.Socket) and not socket.closed:
                socket.close()
        self._closed = True

    def spin(self):
        """Flush any registration notifications and execution results
        waiting in the ZMQ queue.
        """
        if self._notification_socket:
            self._flush_notifications()
        if self._mux_socket:
            self._flush_results(self._mux_socket)
        if self._task_socket:
            self._flush_results(self._task_socket)
        if self._control_socket:
            self._flush_control(self._control_socket)
        if self._iopub_socket:
            self._flush_iopub(self._iopub_socket)
        if self._query_socket:
            self._flush_ignored_hub_replies()

    def wait(self, jobs=None, timeout=-1):
        """waits on one or more `jobs`, for up to `timeout` seconds.
        
        Parameters
        ----------
        
        jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
                ints are indices to self.history
                strs are msg_ids
                default: wait on all outstanding messages
        timeout : float
                a time in seconds, after which to give up.
                default is -1, which means no timeout
        
        Returns
        -------
        
        True : when all msg_ids are done
        False : timeout reached, some msg_ids still outstanding
        """
        tic = time.time()
        if jobs is None:
            theids = self.outstanding
        else:
            if isinstance(jobs, (int, str, AsyncResult)):
                jobs = [jobs]
            theids = set()
            for job in jobs:
                if isinstance(job, int):
                    # index access
                    job = self.history[job]
                elif isinstance(job, AsyncResult):
                    theids.update(job.msg_ids)
                    continue
                theids.add(job)
        if not theids.intersection(self.outstanding):
            return True
        self.spin()
        while theids.intersection(self.outstanding):
            if timeout >= 0 and (time.time() - tic) > timeout:
                break
            time.sleep(1e-3)
            self.spin()
        return len(theids.intersection(self.outstanding)) == 0
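
    # A hedged usage sketch for wait() (names are illustrative):
    #
    #   ar = rc[:].apply_async(f)            # submit, get an AsyncResult
    #   while not rc.wait(ar, timeout=0.5):  # False while msg_ids outstanding
    #       pass                             # do other work between polls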

    #--------------------------------------------------------------------------
    # Control methods
    #--------------------------------------------------------------------------

    @spin_first
    @default_block
    def clear(self, targets=None, block=None):
        """Clear the namespace in target(s)."""
        targets = self._build_targets(targets)[0]
        for t in targets:
            self.session.send(self._control_socket,
                              'clear_request',
                              content={},
                              ident=t)
        error = False
        if self.block:
            self._flush_ignored_control()
            for i in range(len(targets)):
                idents, msg = self.session.recv(self._control_socket, 0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            self._ignored_control_replies += len(targets)
        if error:
            raise error

    @spin_first
    @default_block
    def abort(self, jobs=None, targets=None, block=None):
        """Abort specific jobs from the execution queues of target(s).
        
        This is a mechanism to prevent jobs that have already been submitted
        from executing.
        
        Parameters
        ----------
        
        jobs : msg_id, list of msg_ids, or AsyncResult
            The jobs to be aborted
        
        
        """
        targets = self._build_targets(targets)[0]
        msg_ids = []
        if isinstance(jobs, (basestring, AsyncResult)):
            jobs = [jobs]
        bad_ids = filter(
            lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
        if bad_ids:
            raise TypeError(
                "Invalid msg_id type %r, expected str or AsyncResult" %
                bad_ids[0])
        for j in jobs:
            if isinstance(j, AsyncResult):
                msg_ids.extend(j.msg_ids)
            else:
                msg_ids.append(j)
        content = dict(msg_ids=msg_ids)
        for t in targets:
            self.session.send(self._control_socket,
                              'abort_request',
                              content=content,
                              ident=t)
        error = False
        if self.block:
            self._flush_ignored_control()
            for i in range(len(targets)):
                idents, msg = self.session.recv(self._control_socket, 0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            self._ignored_control_replies += len(targets)
        if error:
            raise error

    @spin_first
    @default_block
    def shutdown(self, targets=None, restart=False, hub=False, block=None):
        """Terminates one or more engine processes, optionally including the hub."""
        if hub:
            targets = 'all'
        targets = self._build_targets(targets)[0]
        for t in targets:
            self.session.send(self._control_socket,
                              'shutdown_request',
                              content={'restart': restart},
                              ident=t)
        error = False
        if block or hub:
            self._flush_ignored_control()
            for i in range(len(targets)):
                idents, msg = self.session.recv(self._control_socket, 0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            self._ignored_control_replies += len(targets)

        if hub:
            time.sleep(0.25)
            self.session.send(self._query_socket, 'shutdown_request')
            idents, msg = self.session.recv(self._query_socket, 0)
            if self.debug:
                pprint(msg)
            if msg['content']['status'] != 'ok':
                error = self._unwrap_exception(msg['content'])

        if error:
            raise error

    #--------------------------------------------------------------------------
    # Execution methods
    #--------------------------------------------------------------------------

    @default_block
    def _execute(self, code, targets='all', block=None):
        """Executes `code` on `targets` in blocking or nonblocking manner.
        
        ``execute`` is always `bound` (affects engine namespace)
        
        Parameters
        ----------
        
        code : str
                the code string to be executed
        targets : int/str/list of ints/strs
                the engines on which to execute
                default : all
        block : bool
                whether or not to wait until done to return
                default: self.block
        """
        return self[targets].execute(code, block=block)

    def _maybe_raise(self, result):
        """wrapper for maybe raising an exception if apply failed."""
        if isinstance(result, error.RemoteError):
            raise result

        return result

    def send_apply_message(self,
                           socket,
                           f,
                           args=None,
                           kwargs=None,
                           subheader=None,
                           track=False,
                           ident=None):
        """construct and send an apply message via a socket.
        
        This is the principal method with which all engine execution is performed by views.
        """

        assert not self._closed, "cannot use me anymore, I'm closed!"
        # defaults:
        args = args if args is not None else []
        kwargs = kwargs if kwargs is not None else {}
        subheader = subheader if subheader is not None else {}

        # validate arguments
        if not callable(f):
            raise TypeError("f must be callable, not %s" % type(f))
        if not isinstance(args, (tuple, list)):
            raise TypeError("args must be tuple or list, not %s" % type(args))
        if not isinstance(kwargs, dict):
            raise TypeError("kwargs must be dict, not %s" % type(kwargs))
        if not isinstance(subheader, dict):
            raise TypeError("subheader must be dict, not %s" % type(subheader))

        if not self._ids:
            # flush notification socket if no engines yet
            any_ids = self.ids
            if not any_ids:
                raise error.NoEnginesRegistered(
                    "Can't execute without any connected engines.")

        bufs = util.pack_apply_message(f, args, kwargs)

        msg = self.session.send(socket,
                                "apply_request",
                                buffers=bufs,
                                ident=ident,
                                subheader=subheader,
                                track=track)

        msg_id = msg['msg_id']
        self.outstanding.add(msg_id)
        if ident:
            # possibly routed to a specific engine
            if isinstance(ident, list):
                ident = ident[-1]
            if ident in self._engines.values():
                # save for later, in case of engine death
                self._outstanding_dict[ident].add(msg_id)
        self.history.append(msg_id)
        self.metadata[msg_id]['submitted'] = datetime.now()

        return msg
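
    # A hedged sketch of calling send_apply_message directly (Views normally
    # do this for you; the socket choice here is illustrative):
    #
    #   msg = client.send_apply_message(client._task_socket, f,
    #                                   args=(1,), track=True)
    #   msg_id = msg['msg_id']   # then poll with client.wait([msg_id])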

    #--------------------------------------------------------------------------
    # construct a View object
    #--------------------------------------------------------------------------

    def load_balanced_view(self, targets=None):
        """construct a DirectView object.
        
        If no arguments are specified, create a LoadBalancedView
        using all engines.
        
        Parameters
        ----------
        
        targets: list,slice,int,etc. [default: use all engines]
            The subset of engines across which to load-balance
        """
        if targets is not None:
            targets = self._build_targets(targets)[1]
        return LoadBalancedView(client=self,
                                socket=self._task_socket,
                                targets=targets)

    def direct_view(self, targets='all'):
        """construct a DirectView object.
        
        If no targets are specified, create a DirectView
        using all engines.
        
        Parameters
        ----------
        
        targets: list,slice,int,etc. [default: use all engines]
            The engines to use for the View
        """
        single = isinstance(targets, int)
        targets = self._build_targets(targets)[1]
        if single:
            targets = targets[0]
        return DirectView(client=self,
                          socket=self._mux_socket,
                          targets=targets)

    #--------------------------------------------------------------------------
    # Data movement (TO BE REMOVED)
    #--------------------------------------------------------------------------

    @default_block
    def _push(self, ns, targets='all', block=None, track=False):
        """Push the contents of `ns` into the namespace on `target`"""
        if not isinstance(ns, dict):
            raise TypeError("Must be a dict, not %s" % type(ns))
        result = self.apply(util._push,
                            kwargs=ns,
                            targets=targets,
                            block=block,
                            bound=True,
                            balanced=False,
                            track=track)
        if not block:
            return result

    @default_block
    def _pull(self, keys, targets='all', block=None):
        """Pull objects from `target`'s namespace by `keys`"""
        if isinstance(keys, basestring):
            pass
        elif isinstance(keys, (list, tuple, set)):
            for key in keys:
                if not isinstance(key, basestring):
                    raise TypeError("keys must be str, not type %r" %
                                    type(key))
        else:
            raise TypeError("keys must be strs, not %r" % keys)
        result = self.apply(util._pull, (keys, ),
                            targets=targets,
                            block=block,
                            bound=True,
                            balanced=False)
        return result

    #--------------------------------------------------------------------------
    # Query methods
    #--------------------------------------------------------------------------

    @spin_first
    @default_block
    def get_result(self, indices_or_msg_ids=None, block=None):
        """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
        
        If the client already has the results, no request to the Hub will be made.
        
        This is a convenient way to construct AsyncResult objects, which are wrappers
        that include metadata about execution, and allow for awaiting results that
        were not submitted by this Client.
        
        It can also be a convenient way to retrieve the metadata associated with
        blocking execution, since it always retrieves the metadata recorded for
        each request.
        
        Examples
        --------
        ::
        
            In [10]: r = client.apply()
        
        Parameters
        ----------
        
        indices_or_msg_ids : integer history index, str msg_id, or list of either
            The history indices or msg_ids of the results to be retrieved
        
        block : bool
            Whether to wait for the result to be done
        
        Returns
        -------
        
        AsyncResult
            A single AsyncResult object will always be returned.
        
        AsyncHubResult
            A subclass of AsyncResult that retrieves results from the Hub
        
        """
        if indices_or_msg_ids is None:
            indices_or_msg_ids = -1

        if not isinstance(indices_or_msg_ids, (list, tuple)):
            indices_or_msg_ids = [indices_or_msg_ids]

        theids = []
        for id in indices_or_msg_ids:
            if isinstance(id, int):
                id = self.history[id]
            if not isinstance(id, str):
                raise TypeError("indices must be str or int, not %r" % id)
            theids.append(id)

        local_ids = filter(
            lambda msg_id: msg_id in self.history or msg_id in self.results,
            theids)
        remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)

        if remote_ids:
            ar = AsyncHubResult(self, msg_ids=theids)
        else:
            ar = AsyncResult(self, msg_ids=theids)

        if block:
            ar.wait()

        return ar
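
    # A hedged usage sketch (assumes prior submissions on a Client `rc`):
    #
    #   ar = rc.get_result(-1)   # AsyncResult for the most recent msg_id
    #   ar.wait()
    #   print(ar.metadata)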

    @spin_first
    def result_status(self, msg_ids, status_only=True):
        """Check on the status of the result(s) of the apply request with `msg_ids`.
        
        If status_only is False, then the actual results will be retrieved, else
        only the status of the results will be checked.
        
        Parameters
        ----------
        
        msg_ids : list of msg_ids
            if int:
                Passed as index to self.history for convenience.
        status_only : bool (default: True)
            if False:
                Retrieve the actual results of completed tasks.
        
        Returns
        -------
        
        results : dict
            There will always be the keys 'pending' and 'completed', which will
            be lists of msg_ids that are incomplete or complete. If `status_only`
            is False, then completed results will be keyed by their `msg_id`.
        """
        if not isinstance(msg_ids, (list, tuple)):
            msg_ids = [msg_ids]

        theids = []
        for msg_id in msg_ids:
            if isinstance(msg_id, int):
                msg_id = self.history[msg_id]
            if not isinstance(msg_id, basestring):
                raise TypeError("msg_ids must be str, not %r" % msg_id)
            theids.append(msg_id)

        completed = []
        local_results = {}

        # comment this block out to temporarily disable local shortcut:
        for msg_id in theids[:]:  # iterate over a copy, since we mutate theids
            if msg_id in self.results:
                completed.append(msg_id)
                local_results[msg_id] = self.results[msg_id]
                theids.remove(msg_id)

        if theids:  # some not locally cached
            content = dict(msg_ids=theids, status_only=status_only)
            msg = self.session.send(self._query_socket,
                                    "result_request",
                                    content=content)
            zmq.select([self._query_socket], [], [])
            idents, msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
            if self.debug:
                pprint(msg)
            content = msg['content']
            if content['status'] != 'ok':
                raise self._unwrap_exception(content)
            buffers = msg['buffers']
        else:
            content = dict(completed=[], pending=[])

        content['completed'].extend(completed)

        if status_only:
            return content

        failures = []
        # load cached results into result:
        content.update(local_results)
        # update cache with results:
        for msg_id in sorted(theids):
            if msg_id in content['completed']:
                rec = content[msg_id]
                parent = rec['header']
                header = rec['result_header']
                rcontent = rec['result_content']
                iodict = rec['io']
                if isinstance(rcontent, str):
                    rcontent = self.session.unpack(rcontent)

                md = self.metadata[msg_id]
                md.update(self._extract_metadata(header, parent, rcontent))
                md.update(iodict)

                if rcontent['status'] == 'ok':
                    res, buffers = util.unserialize_object(buffers)
                else:
                    print(rcontent)
                    res = self._unwrap_exception(rcontent)
                    failures.append(res)

                self.results[msg_id] = res
                content[msg_id] = res

        if len(theids) == 1 and failures:
            raise failures[0]

        error.collect_exceptions(failures, "result_status")
        return content

    @spin_first
    def queue_status(self, targets='all', verbose=False):
        """Fetch the status of engine queues.
        
        Parameters
        ----------
        
        targets : int/str/list of ints/strs
                the engines whose states are to be queried.
                default : all
        verbose : bool
                Whether to return lengths only, or lists of ids for each element
        """
        engine_ids = self._build_targets(targets)[1]
        content = dict(targets=engine_ids, verbose=verbose)
        self.session.send(self._query_socket, "queue_request", content=content)
        idents, msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        status = content.pop('status')
        if status != 'ok':
            raise self._unwrap_exception(content)
        content = util.rekey(content)
        if isinstance(targets, int):
            return content[targets]
        else:
            return content

    @spin_first
    def purge_results(self, jobs=[], targets=[]):
        """Tell the Hub to forget results.
        
        Individual results can be purged by msg_id, or the entire
        history of specific targets can be purged.
        
        Parameters
        ----------
        
        jobs : str or list of str or AsyncResult objects
                the msg_ids whose results should be forgotten.
        targets : int/str/list of ints/strs
                The targets, by uuid or int_id, whose entire history is to be purged.
                Use `targets='all'` to scrub everything from the Hub's memory.
                
                default : []
        """
        if not targets and not jobs:
            raise ValueError(
                "Must specify at least one of `targets` and `jobs`")
        if targets:
            targets = self._build_targets(targets)[1]

        # construct msg_ids from jobs
        msg_ids = []
        if isinstance(jobs, (basestring, AsyncResult)):
            jobs = [jobs]
        bad_ids = filter(
            lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
        if bad_ids:
            raise TypeError(
                "Invalid msg_id type %r, expected str or AsyncResult" %
                bad_ids[0])
        for j in jobs:
            if isinstance(j, AsyncResult):
                msg_ids.extend(j.msg_ids)
            else:
                msg_ids.append(j)

        content = dict(targets=targets, msg_ids=msg_ids)
        self.session.send(self._query_socket, "purge_request", content=content)
        idents, msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)
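
A minimal, hedged usage sketch for the Client API above (assumes a running
IPython.parallel cluster, e.g. started with `ipcluster start`; `rc` and the
submitted lambda are illustrative):

from IPython.parallel import Client

rc = Client()                       # connect to the default cluster
view = rc[:]                        # DirectView on all engines (__getitem__)

ar = view.apply_async(lambda x: x * 2, 21)
rc.wait(ar, timeout=5)              # poll until done (or give up after 5s)
print(rc.get_result(ar.msg_ids[0]).get())

print(rc.queue_status())            # per-engine queue lengths
rc.purge_results(jobs=ar.msg_ids)   # tell the Hub to forget these results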
Example #12
class ExtractOutputPreprocessor(Preprocessor):
    """
    Extracts all of the outputs from the notebook file.  The extracted 
    outputs are returned in the 'resources' dictionary.
    """

    output_filename_template = Unicode(
        "{unique_key}_{cell_index}_{index}{extension}", config=True)

    extract_output_types = Set({'png', 'jpeg', 'svg', 'application/pdf'},
                               config=True)

    def preprocess_cell(self, cell, resources, cell_index):
        """
        Apply a transformation on each cell,
        
        Parameters
        ----------
        cell : NotebookNode cell
            Notebook cell being processed
        resources : dictionary
            Additional resources used in the conversion process.  Allows
            preprocessors to pass variables into the Jinja engine.
        cell_index : int
            Index of the cell being processed (see base.py)
        """

        #Get the unique key from the resource dict if it exists.  If it does not
        #exist, use 'output' as the default.  Also, get files directory if it
        #has been specified
        unique_key = resources.get('unique_key', 'output')
        output_files_dir = resources.get('output_files_dir', None)

        #Make sure outputs key exists
        if not isinstance(resources['outputs'], dict):
            resources['outputs'] = {}

        #Loop through all of the outputs in the cell
        for index, out in enumerate(cell.get('outputs', [])):

            #Get the output in data formats that the template needs extracted
            for out_type in self.extract_output_types:
                if out_type in out:
                    data = out[out_type]

                    #Binary files are base64-encoded, SVG is already XML
                    if out_type in {'png', 'jpeg', 'application/pdf'}:

                        # data is b64-encoded as text (str, unicode)
                        # decodestring only accepts bytes
                        data = py3compat.cast_bytes(data)
                        data = base64.decodestring(data)
                    elif sys.platform == 'win32':
                        data = data.replace('\n', '\r\n').encode("UTF-8")
                    else:
                        data = data.encode("UTF-8")

                    # Build an output name
                    # filthy hack while we have some mimetype output, and some not
                    if '/' in out_type:
                        ext = guess_extension(out_type)
                        if ext is None:
                            ext = '.' + out_type.rsplit('/')[-1]
                    else:
                        ext = '.' + out_type

                    filename = self.output_filename_template.format(
                        unique_key=unique_key,
                        cell_index=cell_index,
                        index=index,
                        extension=ext)

                    #On the cell, make the figure available via
                    #   cell.outputs[i].svg_filename  ... etc (svg in example)
                    # Where
                    #   cell.outputs[i].svg  contains the data
                    if output_files_dir is not None:
                        filename = os.path.join(output_files_dir, filename)
                    out[out_type + '_filename'] = filename

                    #In the resources, make the figure available via
                    #   resources['outputs']['filename'] = data
                    resources['outputs'][filename] = data

        return cell, resources
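
A hedged illustration of how `output_filename_template` above expands (the
values plugged in are made up):

template = "{unique_key}_{cell_index}_{index}{extension}"
name = template.format(unique_key='output', cell_index=3, index=0,
                       extension='.png')
print(name)   # -> output_3_0.png, the key used in resources['outputs']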
Example #13
class Kernel(SessionFactory):

    #---------------------------------------------------------------------------
    # Kernel interface
    #---------------------------------------------------------------------------

    # kwargs:
    exec_lines = List(Unicode, config=True, help="List of lines to execute")

    # identities:
    int_id = Int(-1)
    bident = CBytes()
    ident = Unicode()

    def _ident_changed(self, name, old, new):
        self.bident = asbytes(new)

    user_ns = Dict(config=True,
                   help="""Set the user's namespace of the Kernel""")

    control_stream = Instance(zmqstream.ZMQStream)
    task_stream = Instance(zmqstream.ZMQStream)
    iopub_stream = Instance(zmqstream.ZMQStream)
    client = Instance('IPython.parallel.Client')

    # internals
    shell_streams = List()
    compiler = Instance(CommandCompiler, (), {})
    completer = Instance(KernelCompleter)

    aborted = Set()
    shell_handlers = Dict()
    control_handlers = Dict()

    def _set_prefix(self):
        self.prefix = "engine.%s" % self.int_id

    def _connect_completer(self):
        self.completer = KernelCompleter(self.user_ns)

    def __init__(self, **kwargs):
        super(Kernel, self).__init__(**kwargs)
        self._set_prefix()
        self._connect_completer()

        self.on_trait_change(self._set_prefix, 'id')
        self.on_trait_change(self._connect_completer, 'user_ns')

        # Build dict of handlers for message types
        for msg_type in [
                'execute_request', 'complete_request', 'apply_request',
                'clear_request'
        ]:
            self.shell_handlers[msg_type] = getattr(self, msg_type)

        for msg_type in ['shutdown_request', 'abort_request'
                         ] + self.shell_handlers.keys():
            self.control_handlers[msg_type] = getattr(self, msg_type)

        self._initial_exec_lines()

    def _wrap_exception(self, method=None):
        e_info = dict(engine_uuid=self.ident,
                      engine_id=self.int_id,
                      method=method)
        content = wrap_exception(e_info)
        return content

    def _initial_exec_lines(self):
        s = _Passer()
        content = dict(silent=True, user_variables=[], user_expressions=[])
        for line in self.exec_lines:
            self.log.debug("executing initialization: %s" % line)
            content.update({'code': line})
            msg = self.session.msg('execute_request', content)
            self.execute_request(s, [], msg)

    #-------------------- control handlers -----------------------------
    def abort_queues(self):
        for stream in self.shell_streams:
            if stream:
                self.abort_queue(stream)

    def abort_queue(self, stream):
        while True:
            idents, msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
            if msg is None:
                return

            self.log.info("Aborting:")
            self.log.info(str(msg))
            msg_type = msg['header']['msg_type']
            reply_type = msg_type.split('_')[0] + '_reply'
            # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
            # self.reply_socket.send(ident,zmq.SNDMORE)
            # self.reply_socket.send_json(reply_msg)
            reply_msg = self.session.send(stream,
                                          reply_type,
                                          content={'status': 'aborted'},
                                          parent=msg,
                                          ident=idents)
            self.log.debug(str(reply_msg))
            # We need to wait a bit for requests to come in. This can probably
            # be set shorter for true asynchronous clients.
            time.sleep(0.05)

    def abort_request(self, stream, ident, parent):
        """abort a specifig msg by id"""
        msg_ids = parent['content'].get('msg_ids', None)
        if isinstance(msg_ids, basestring):
            msg_ids = [msg_ids]
        if not msg_ids:
            self.abort_queues()
        for mid in msg_ids:
            self.aborted.add(str(mid))

        content = dict(status='ok')
        reply_msg = self.session.send(stream,
                                      'abort_reply',
                                      content=content,
                                      parent=parent,
                                      ident=ident)
        self.log.debug(str(reply_msg))

    def shutdown_request(self, stream, ident, parent):
        """kill ourself.  This should really be handled in an external process"""
        try:
            self.abort_queues()
        except:
            content = self._wrap_exception('shutdown')
        else:
            content = dict(parent['content'])
            content['status'] = 'ok'
        msg = self.session.send(stream,
                                'shutdown_reply',
                                content=content,
                                parent=parent,
                                ident=ident)
        self.log.debug(str(msg))
        dc = ioloop.DelayedCallback(lambda: sys.exit(0), 1000, self.loop)
        dc.start()

    def dispatch_control(self, msg):
        idents, msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.unserialize(msg, content=True, copy=False)
        except:
            self.log.error("Invalid Message", exc_info=True)
            return
        else:
            self.log.debug("Control received, %s", msg)

        header = msg['header']
        msg_id = header['msg_id']
        msg_type = header['msg_type']

        handler = self.control_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r" % msg_type)
        else:
            handler(self.control_stream, idents, msg)

    #-------------------- queue helpers ------------------------------

    def check_dependencies(self, dependencies):
        if not dependencies:
            return True
        if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
            anyorall = dependencies[0]
            dependencies = dependencies[1]
        else:
            anyorall = 'all'
        results = self.client.get_results(dependencies, status_only=True)
        if results['status'] != 'ok':
            return False

        if anyorall == 'any':
            if not results['completed']:
                return False
        else:
            if results['pending']:
                return False

        return True
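
    # Hedged illustration of the two accepted `dependencies` shapes
    # (the msg_ids are made up):
    #
    #   check_dependencies(['msg-a', 'msg-b'])            # implicit 'all'
    #   check_dependencies(['any', ['msg-a', 'msg-b']])   # any one suffices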

    def check_aborted(self, msg_id):
        return msg_id in self.aborted

    #-------------------- queue handlers -----------------------------

    def clear_request(self, stream, idents, parent):
        """Clear our namespace."""
        self.user_ns = {}
        msg = self.session.send(stream,
                                'clear_reply',
                                ident=idents,
                                parent=parent,
                                content=dict(status='ok'))
        self._initial_exec_lines()

    def execute_request(self, stream, ident, parent):
        self.log.debug('execute request %s' % parent)
        try:
            code = parent[u'content'][u'code']
        except:
            self.log.error("Got bad msg: %s" % parent, exc_info=True)
            return
        self.session.send(self.iopub_stream,
                          u'pyin', {u'code': code},
                          parent=parent,
                          ident=asbytes('%s.pyin' % self.prefix))
        started = datetime.now()
        try:
            comp_code = self.compiler(code, '<zmq-kernel>')
            # allow for not overriding displayhook
            if hasattr(sys.displayhook, 'set_parent'):
                sys.displayhook.set_parent(parent)
                sys.stdout.set_parent(parent)
                sys.stderr.set_parent(parent)
            exec comp_code in self.user_ns, self.user_ns
        except:
            exc_content = self._wrap_exception('execute')
            # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
            self.session.send(self.iopub_stream,
                              u'pyerr',
                              exc_content,
                              parent=parent,
                              ident=asbytes('%s.pyerr' % self.prefix))
            reply_content = exc_content
        else:
            reply_content = {'status': 'ok'}

        reply_msg = self.session.send(stream,
                                      u'execute_reply',
                                      reply_content,
                                      parent=parent,
                                      ident=ident,
                                      subheader=dict(started=started))
        self.log.debug(str(reply_msg))
        if reply_msg['content']['status'] == u'error':
            self.abort_queues()

    def complete_request(self, stream, ident, parent):
        matches = {'matches': self.complete(parent), 'status': 'ok'}
        completion_msg = self.session.send(stream, 'complete_reply', matches,
                                           parent, ident)
        # print >> sys.__stdout__, completion_msg

    def complete(self, msg):
        return self.completer.complete(msg.content.line, msg.content.text)

    def apply_request(self, stream, ident, parent):
        # flush previous reply, so this request won't block it
        stream.flush(zmq.POLLOUT)
        try:
            content = parent[u'content']
            bufs = parent[u'buffers']
            msg_id = parent['header']['msg_id']
            # bound = parent['header'].get('bound', False)
        except:
            self.log.error("Got bad msg: %s" % parent, exc_info=True)
            return
        # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
        # self.iopub_stream.send(pyin_msg)
        # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
        sub = {
            'dependencies_met': True,
            'engine': self.ident,
            'started': datetime.now()
        }
        try:
            # allow for not overriding displayhook
            if hasattr(sys.displayhook, 'set_parent'):
                sys.displayhook.set_parent(parent)
                sys.stdout.set_parent(parent)
                sys.stderr.set_parent(parent)
            # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
            working = self.user_ns
            prefix = "_" + str(msg_id).replace("-", "") + "_"

            f, args, kwargs = unpack_apply_message(bufs, working, copy=False)
            # if bound:
            #     bound_ns = Namespace(working)
            #     args = [bound_ns]+list(args)

            fname = prefix + "f"
            argname = prefix + "args"
            kwargname = prefix + "kwargs"
            resultname = prefix + "result"

            ns = {fname: f, argname: args, kwargname: kwargs, resultname: None}
            # print ns
            working.update(ns)
            code = "%s=%s(*%s,**%s)" % (resultname, fname, argname, kwargname)
            try:
                exec code in working, working
                result = working.get(resultname)
            finally:
                for key in ns.iterkeys():
                    working.pop(key)
            # if bound:
            #     working.update(bound_ns)

            packed_result, buf = serialize_object(result)
            result_buf = [packed_result] + buf
        except:
            exc_content = self._wrap_exception('apply')
            # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
            self.session.send(self.iopub_stream,
                              u'pyerr',
                              exc_content,
                              parent=parent,
                              ident=asbytes('%s.pyerr' % self.prefix))
            reply_content = exc_content
            result_buf = []

            if exc_content['ename'] == 'UnmetDependency':
                sub['dependencies_met'] = False
        else:
            reply_content = {'status': 'ok'}

        # put 'ok'/'error' status in header, for scheduler introspection:
        sub['status'] = reply_content['status']

        reply_msg = self.session.send(stream,
                                      u'apply_reply',
                                      reply_content,
                                      parent=parent,
                                      ident=ident,
                                      buffers=result_buf,
                                      subheader=sub)

        # flush i/o
        # should this be before reply_msg is sent, like in the single-kernel code,
        # or should nothing get in the way of real results?
        sys.stdout.flush()
        sys.stderr.flush()

    def dispatch_queue(self, stream, msg):
        self.control_stream.flush()
        idents, msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.unserialize(msg, content=True, copy=False)
        except:
            self.log.error("Invalid Message", exc_info=True)
            return
        else:
            self.log.debug("Message received, %s", msg)

        header = msg['header']
        msg_id = header['msg_id']
        msg_type = msg['header']['msg_type']
        if self.check_aborted(msg_id):
            self.aborted.remove(msg_id)
            # is it safe to assume a msg_id will not be resubmitted?
            reply_type = msg_type.split('_')[0] + '_reply'
            status = {'status': 'aborted'}
            reply_msg = self.session.send(stream,
                                          reply_type,
                                          subheader=status,
                                          content=status,
                                          parent=msg,
                                          ident=idents)
            return
        handler = self.shell_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("UNKNOWN MESSAGE TYPE: %r" % msg_type)
        else:
            handler(stream, idents, msg)

    def start(self):
        #### stream mode:
        if self.control_stream:
            self.control_stream.on_recv(self.dispatch_control, copy=False)
            self.control_stream.on_err(printer)

        def make_dispatcher(stream):
            def dispatcher(msg):
                return self.dispatch_queue(stream, msg)

            return dispatcher

        for s in self.shell_streams:
            s.on_recv(make_dispatcher(s), copy=False)
            s.on_err(printer)

        if self.iopub_stream:
            self.iopub_stream.on_err(printer)
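
A standalone, hedged sketch of the namespace trick apply_request uses above:
the function and its arguments are bound to collision-resistant names prefixed
with the msg_id, executed, and cleaned up afterwards (the helper name and the
msg_id are made up):

def _run_in_ns(working, f, args, kwargs, msg_id):
    """Bind f/args/kwargs to msg_id-prefixed names, exec, then clean up."""
    prefix = "_" + str(msg_id).replace("-", "") + "_"
    fname = prefix + "f"
    argname = prefix + "args"
    kwargname = prefix + "kwargs"
    resultname = prefix + "result"
    ns = {fname: f, argname: args, kwargname: kwargs, resultname: None}
    working.update(ns)
    code = "%s = %s(*%s, **%s)" % (resultname, fname, argname, kwargname)
    try:
        exec(code, working, working)  # tuple form works in Python 2 and 3
        return working.get(resultname)
    finally:
        for key in ns:
            working.pop(key)

user_ns = {}
print(_run_in_ns(user_ns, pow, (2, 10), {}, "abc-123"))  # -> 1024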
Example #14
File: hub.py Project: sherjilozair/ipython
class Hub(LoggingFactory):
    """The IPython Controller Hub with 0MQ connections
    
    Parameters
    ==========
    loop: zmq IOLoop instance
    session: StreamSession object
    <removed> context: zmq context for creating new connections (?)
    queue: ZMQStream for monitoring the command queue (SUB)
    query: ZMQStream for engine registration and client queries requests (XREP)
    heartbeat: HeartMonitor object checking the pulse of the engines
    notifier: ZMQStream for broadcasting engine registration changes (PUB)
    db: connection to db for out of memory logging of commands
                NotImplemented
    engine_info: dict of zmq connection information for engines to connect
                to the queues.
    client_info: dict of zmq connection information for clients to connect
                to the queues.
    """
    # internal data structures:
    ids = Set()  # engine IDs
    keytable = Dict()
    by_ident = Dict()
    engines = Dict()
    clients = Dict()
    hearts = Dict()
    pending = Set()
    queues = Dict()  # pending msg_ids keyed by engine_id
    tasks = Dict()  # pending msg_ids submitted as tasks, keyed by client_id
    completed = Dict()  # completed msg_ids keyed by engine_id
    all_completed = Set()  # set of all completed msg_ids
    dead_engines = Set()  # set of uuids of engines that have died
    unassigned = Set()  # set of task msg_ids not yet assigned a destination
    incoming_registrations = Dict()
    registration_timeout = Int()
    _idcounter = Int(0)

    # objects from constructor:
    loop = Instance(ioloop.IOLoop)
    query = Instance(ZMQStream)
    monitor = Instance(ZMQStream)
    heartmonitor = Instance(HeartMonitor)
    notifier = Instance(ZMQStream)
    db = Instance(object)
    client_info = Dict()
    engine_info = Dict()

    def __init__(self, **kwargs):
        """
        # universal:
        loop: IOLoop for creating future connections
        session: streamsession for sending serialized data
        # engine:
        queue: ZMQStream for monitoring queue messages
        query: ZMQStream for engine+client registration and client requests
        heartbeat: HeartMonitor object for tracking engines
        # extra:
        db: ZMQStream for db connection (NotImplemented)
        engine_info: zmq address/protocol dict for engine connections
        client_info: zmq address/protocol dict for client connections
        """

        super(Hub, self).__init__(**kwargs)
        self.registration_timeout = max(5000, 2 * self.heartmonitor.period)

        # validate connection dicts:
        for k, v in self.client_info.iteritems():
            if k == 'task':
                util.validate_url_container(v[1])
            else:
                util.validate_url_container(v)
        # util.validate_url_container(self.client_info)
        util.validate_url_container(self.engine_info)

        # register our callbacks
        self.query.on_recv(self.dispatch_query)
        self.monitor.on_recv(self.dispatch_monitor_traffic)

        self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
        self.heartmonitor.add_new_heart_handler(self.handle_new_heart)

        self.monitor_handlers = {
            'in': self.save_queue_request,
            'out': self.save_queue_result,
            'intask': self.save_task_request,
            'outtask': self.save_task_result,
            'tracktask': self.save_task_destination,
            'incontrol': _passer,
            'outcontrol': _passer,
            'iopub': self.save_iopub_message,
        }

        self.query_handlers = {
            'queue_request': self.queue_status,
            'result_request': self.get_results,
            'history_request': self.get_history,
            'db_request': self.db_query,
            'purge_request': self.purge_results,
            'load_request': self.check_load,
            'resubmit_request': self.resubmit_task,
            'shutdown_request': self.shutdown_request,
            'registration_request': self.register_engine,
            'unregistration_request': self.unregister_engine,
            'connection_request': self.connection_request,
        }

        self.log.info("hub::created hub")

    @property
    def _next_id(self):
        """gemerate a new ID.
        
        No longer reuse old ids, just count from 0."""
        newid = self._idcounter
        self._idcounter += 1
        return newid
        # newid = 0
        # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
        # # print newid, self.ids, self.incoming_registrations
        # while newid in self.ids or newid in incoming:
        #     newid += 1
        # return newid

    #-----------------------------------------------------------------------------
    # message validation
    #-----------------------------------------------------------------------------

    def _validate_targets(self, targets):
        """turn any valid targets argument into a list of integer ids"""
        if targets is None:
            # default to all
            targets = self.ids

        if isinstance(targets, (int, str, unicode)):
            # only one target specified
            targets = [targets]
        _targets = []
        for t in targets:
            # map raw identities to ids
            if isinstance(t, (str, unicode)):
                t = self.by_ident.get(t, t)
            _targets.append(t)
        targets = _targets
        bad_targets = [t for t in targets if t not in self.ids]
        if bad_targets:
            raise IndexError("No Such Engine: %r" % bad_targets)
        if not targets:
            raise IndexError("No Engines Registered")
        return targets
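
    # Hedged examples of the normalization performed above (the uuid is
    # made up; string identities map to ids via self.by_ident):
    #
    #   _validate_targets(None)          # -> list of all registered ids
    #   _validate_targets(3)             # -> [3]
    #   _validate_targets('engine-uuid') # -> [eid] for that identity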

    #-----------------------------------------------------------------------------
    # dispatch methods (1 per stream)
    #-----------------------------------------------------------------------------

    # def dispatch_registration_request(self, msg):
    #     """"""
    #     self.log.debug("registration::dispatch_register_request(%s)"%msg)
    #     idents,msg = self.session.feed_identities(msg)
    #     if not idents:
    #         self.log.error("Bad Query Message: %s"%msg, exc_info=True)
    #         return
    #     try:
    #         msg = self.session.unpack_message(msg,content=True)
    #     except:
    #         self.log.error("registration::got bad registration message: %s"%msg, exc_info=True)
    #         return
    #
    #     msg_type = msg['msg_type']
    #     content = msg['content']
    #
    #     handler = self.query_handlers.get(msg_type, None)
    #     if handler is None:
    #         self.log.error("registration::got bad registration message: %s"%msg)
    #     else:
    #         handler(idents, msg)

    def dispatch_monitor_traffic(self, msg):
        """all ME and Task queue messages come through here, as well as
        IOPub traffic."""
        self.log.debug("monitor traffic: %s" % msg[:2])
        switch = msg[0]
        idents, msg = self.session.feed_identities(msg[1:])
        if not idents:
            self.log.error("Bad Monitor Message: %s" % msg)
            return
        handler = self.monitor_handlers.get(switch, None)
        if handler is not None:
            handler(idents, msg)
        else:
            self.log.error("Invalid monitor topic: %s" % switch)

    def dispatch_query(self, msg):
        """Route registration requests and queries from clients."""
        idents, msg = self.session.feed_identities(msg)
        if not idents:
            self.log.error("Bad Query Message: %s" % msg)
            return
        client_id = idents[0]
        try:
            msg = self.session.unpack_message(msg, content=True)
        except:
            content = error.wrap_exception()
            self.log.error("Bad Query Message: %s" % msg, exc_info=True)
            self.session.send(self.query,
                              "hub_error",
                              ident=client_id,
                              content=content)
            return

        # print client_id, header, parent, content
        #switch on message type:
        msg_type = msg['msg_type']
        self.log.info("client::client %s requested %s" % (client_id, msg_type))
        handler = self.query_handlers.get(msg_type, None)
        try:
            assert handler is not None, "Bad Message Type: %s" % msg_type
        except:
            content = error.wrap_exception()
            self.log.error("Bad Message Type: %s" % msg_type, exc_info=True)
            self.session.send(self.query,
                              "hub_error",
                              ident=client_id,
                              content=content)
            return
        else:
            handler(idents, msg)

    def dispatch_db(self, msg):
        """"""
        raise NotImplementedError

    #---------------------------------------------------------------------------
    # handler methods (1 per event)
    #---------------------------------------------------------------------------

    #----------------------- Heartbeat --------------------------------------

    def handle_new_heart(self, heart):
        """handler to attach to heartbeater.
        Called when a new heart starts to beat.
        Triggers completion of registration."""
        self.log.debug("heartbeat::handle_new_heart(%r)" % heart)
        if heart not in self.incoming_registrations:
            self.log.info("heartbeat::ignoring new heart: %r" % heart)
        else:
            self.finish_registration(heart)

    def handle_heart_failure(self, heart):
        """handler to attach to heartbeater.
        called when a previously registered heart fails to respond to beat request.
        triggers unregistration"""
        self.log.debug("heartbeat::handle_heart_failure(%r)" % heart)
        eid = self.hearts.get(heart, None)
        if eid is None:
            self.log.info("heartbeat::ignoring heart failure %r" % heart)
        else:
            queue = self.engines[eid].queue
            self.unregister_engine(heart,
                                   dict(content=dict(id=eid, queue=queue)))

    #----------------------- MUX Queue Traffic ------------------------------

    def save_queue_request(self, idents, msg):
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %s" % idents)
            return
        queue_id, client_id = idents[:2]
        try:
            msg = self.session.unpack_message(msg, content=False)
        except:
            self.log.error("queue::client %r sent invalid message to %r: %s" %
                           (client_id, queue_id, msg),
                           exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::target %r not registered" % queue_id)
            self.log.debug("queue::    valid are: %s" % (self.by_ident.keys()))
            return

        header = msg['header']
        msg_id = header['msg_id']
        record = init_record(msg)
        record['engine_uuid'] = queue_id
        record['client_uuid'] = client_id
        record['queue'] = 'mux'

        try:
            # it's possible iopub arrived first:
            existing = self.db.get_record(msg_id)
            for key, evalue in existing.iteritems():
                rvalue = record[key]
                if evalue and rvalue and evalue != rvalue:
                    self.log.error(
                        "conflicting initial state for record: %s:%s <> %s" %
                        (msg_id, rvalue, evalue))
                elif evalue and not rvalue:
                    record[key] = evalue
            self.db.update_record(msg_id, record)
        except KeyError:
            self.db.add_record(msg_id, record)

        self.pending.add(msg_id)
        self.queues[eid].append(msg_id)

    def save_queue_result(self, idents, msg):
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %s" % idents)
            return

        client_id, queue_id = idents[:2]
        try:
            msg = self.session.unpack_message(msg, content=False)
        except:
            self.log.error("queue::engine %r sent invalid message to %r: %s" %
                           (queue_id, client_id, msg),
                           exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::unknown engine %r is sending a reply: " %
                           queue_id)
            # self.log.debug("queue::       %s"%msg[2:])
            return

        parent = msg['parent_header']
        if not parent:
            return
        msg_id = parent['msg_id']
        if msg_id in self.pending:
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            self.queues[eid].remove(msg_id)
            self.completed[eid].append(msg_id)
        elif msg_id not in self.all_completed:
            # it could be a result from a dead engine that died before delivering the
            # result
            self.log.warn("queue:: unknown msg finished %s" % msg_id)
            return
        # update record anyway, because the unregistration could have been premature
        rheader = msg['header']
        completed = datetime.strptime(rheader['date'], util.ISO8601)
        started = rheader.get('started', None)
        if started is not None:
            started = datetime.strptime(started, util.ISO8601)
        result = {
            'result_header': rheader,
            'result_content': msg['content'],
            'started': started,
            'completed': completed
        }

        result['result_buffers'] = msg['buffers']
        try:
            self.db.update_record(msg_id, result)
        except Exception:
            self.log.error("DB Error updating record %r" % msg_id,
                           exc_info=True)

    #--------------------- Task Queue Traffic ------------------------------

    def save_task_request(self, idents, msg):
        """Save the submission of a task."""
        client_id = idents[0]

        try:
            msg = self.session.unpack_message(msg, content=False)
        except:
            self.log.error("task::client %r sent invalid task message: %s" %
                           (client_id, msg),
                           exc_info=True)
            return
        record = init_record(msg)

        record['client_uuid'] = client_id
        record['queue'] = 'task'
        header = msg['header']
        msg_id = header['msg_id']
        self.pending.add(msg_id)
        self.unassigned.add(msg_id)
        try:
            # it's possible iopub arrived first:
            existing = self.db.get_record(msg_id)
            for key, evalue in existing.iteritems():
                rvalue = record[key]
                if evalue and rvalue and evalue != rvalue:
                    self.log.error(
                        "conflicting initial state for record: %s:%s <> %s" %
                        (msg_id, rvalue, evalue))
                elif evalue and not rvalue:
                    record[key] = evalue
            self.db.update_record(msg_id, record)
        except KeyError:
            self.db.add_record(msg_id, record)
        except Exception:
            self.log.error("DB Error saving task request %r" % msg_id,
                           exc_info=True)

    def save_task_result(self, idents, msg):
        """save the result of a completed task."""
        client_id = idents[0]
        try:
            msg = self.session.unpack_message(msg, content=False)
        except:
            self.log.error("task::invalid task result message send to %r: %s" %
                           (client_id, msg),
                           exc_info=True)
            raise
            return

        parent = msg['parent_header']
        if not parent:
            # print msg
            self.log.warn("Task %r had no parent!" % msg)
            return
        msg_id = parent['msg_id']
        if msg_id in self.unassigned:
            self.unassigned.remove(msg_id)

        header = msg['header']
        engine_uuid = header.get('engine', None)
        eid = self.by_ident.get(engine_uuid, None)

        if msg_id in self.pending:
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            if eid is not None:
                self.completed[eid].append(msg_id)
                if msg_id in self.tasks[eid]:
                    self.tasks[eid].remove(msg_id)
            completed = datetime.strptime(header['date'], util.ISO8601)
            started = header.get('started', None)
            if started is not None:
                started = datetime.strptime(started, util.ISO8601)
            result = {
                'result_header': header,
                'result_content': msg['content'],
                'started': started,
                'completed': completed,
                'engine_uuid': engine_uuid
            }

            result['result_buffers'] = msg['buffers']
            try:
                self.db.update_record(msg_id, result)
            except Exception:
                self.log.error("DB Error saving task request %r" % msg_id,
                               exc_info=True)

        else:
            self.log.debug("task::unknown task %s finished" % msg_id)

    def save_task_destination(self, idents, msg):
        try:
            msg = self.session.unpack_message(msg, content=True)
        except:
            self.log.error("task::invalid task tracking message",
                           exc_info=True)
            return
        content = msg['content']
        # print (content)
        msg_id = content['msg_id']
        engine_uuid = content['engine_id']
        eid = self.by_ident[engine_uuid]

        self.log.info("task::task %s arrived on %s" % (msg_id, eid))
        if msg_id in self.unassigned:
            self.unassigned.remove(msg_id)
        # else:
        #     self.log.debug("task::task %s not listed as MIA?!"%(msg_id))

        self.tasks[eid].append(msg_id)
        # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
        try:
            self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
        except Exception:
            self.log.error("DB Error saving task destination %r" % msg_id,
                           exc_info=True)

    def mia_task_request(self, idents, msg):
        raise NotImplementedError
        client_id = idents[0]
        # content = dict(mia=self.mia,status='ok')
        # self.session.send('mia_reply', content=content, idents=client_id)

    #--------------------- IOPub Traffic ------------------------------

    def save_iopub_message(self, topics, msg):
        """save an iopub message into the db"""
        # print (topics)
        try:
            msg = self.session.unpack_message(msg, content=True)
        except:
            self.log.error("iopub::invalid IOPub message", exc_info=True)
            return

        parent = msg['parent_header']
        if not parent:
            self.log.error("iopub::invalid IOPub message: %s" % msg)
            return
        msg_id = parent['msg_id']
        msg_type = msg['msg_type']
        content = msg['content']

        # ensure msg_id is in db
        try:
            rec = self.db.get_record(msg_id)
        except KeyError:
            rec = empty_record()
            rec['msg_id'] = msg_id
            self.db.add_record(msg_id, rec)
        # stream
        d = {}
        if msg_type == 'stream':
            name = content['name']
            s = rec[name] or ''
            d[name] = s + content['data']

        elif msg_type == 'pyerr':
            d['pyerr'] = content
        elif msg_type == 'pyin':
            d['pyin'] = content['code']
        else:
            d[msg_type] = content.get('data', '')

        try:
            self.db.update_record(msg_id, d)
        except Exception:
            self.log.error("DB Error saving iopub message %r" % msg_id,
                           exc_info=True)
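
    # Summary of the iopub -> record mapping above:
    #   stream  -> rec['stdout'] / rec['stderr'] (data concatenated per message)
    #   pyerr   -> rec['pyerr'] (the full content dict)
    #   pyin    -> rec['pyin'] (the executed code)
    #   other   -> rec[msg_type] = content.get('data', '')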

    #-------------------------------------------------------------------------
    # Registration requests
    #-------------------------------------------------------------------------

    def connection_request(self, client_id, msg):
        """Reply with connection addresses for clients."""
        self.log.info("client::client %s connected" % client_id)
        content = dict(status='ok')
        content.update(self.client_info)
        jsonable = {}
        for k, v in self.keytable.iteritems():
            if v not in self.dead_engines:
                jsonable[str(k)] = v
        content['engines'] = jsonable
        self.session.send(self.query,
                          'connection_reply',
                          content,
                          parent=msg,
                          ident=client_id)

    def register_engine(self, reg, msg):
        """Register a new engine."""
        content = msg['content']
        try:
            queue = content['queue']
        except KeyError:
            self.log.error("registration::queue not specified", exc_info=True)
            return
        heart = content.get('heartbeat', None)
        """register a new engine, and create the socket(s) necessary"""
        eid = self._next_id
        # print (eid, queue, reg, heart)

        self.log.debug("registration::register_engine(%i, %r, %r, %r)" %
                       (eid, queue, reg, heart))

        content = dict(id=eid, status='ok')
        content.update(self.engine_info)
        # check if requesting available IDs:
        if queue in self.by_ident:
            try:
                raise KeyError("queue_id %r in use" % queue)
            except:
                content = error.wrap_exception()
                self.log.error("queue_id %r in use" % queue, exc_info=True)
        elif heart in self.hearts:  # need to check unique hearts?
            try:
                raise KeyError("heart_id %r in use" % heart)
            except:
                self.log.error("heart_id %r in use" % heart, exc_info=True)
                content = error.wrap_exception()
        else:
            for h, pack in self.incoming_registrations.iteritems():
                if heart == h:
                    try:
                        raise KeyError("heart_id %r in use" % heart)
                    except:
                        self.log.error("heart_id %r in use" % heart,
                                       exc_info=True)
                        content = error.wrap_exception()
                    break
                elif queue == pack[1]:
                    try:
                        raise KeyError("queue_id %r in use" % queue)
                    except:
                        self.log.error("queue_id %r in use" % queue,
                                       exc_info=True)
                        content = error.wrap_exception()
                    break

        msg = self.session.send(self.query,
                                "registration_reply",
                                content=content,
                                ident=reg)

        if content['status'] == 'ok':
            if heart in self.heartmonitor.hearts:
                # already beating
                self.incoming_registrations[heart] = (eid, queue, reg[0], None)
                self.finish_registration(heart)
            else:
                purge = lambda: self._purge_stalled_registration(heart)
                dc = ioloop.DelayedCallback(purge, self.registration_timeout,
                                            self.loop)
                dc.start()
                self.incoming_registrations[heart] = (eid, queue, reg[0], dc)
        else:
            self.log.error("registration::registration %i failed: %s" %
                           (eid, content['evalue']))
        return eid
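
    # Registration is two-phase: the reply above is sent immediately, and the
    # pending registration is stored keyed by the engine's heart. It only
    # completes in finish_registration() once the HeartMonitor has seen a beat
    # from that heart; otherwise _purge_stalled_registration() drops it after
    # registration_timeout.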

    def unregister_engine(self, ident, msg):
        """Unregister an engine that explicitly requested to leave."""
        try:
            eid = msg['content']['id']
        except:
            self.log.error(
                "registration::bad engine id for unregistration: %s" % ident,
                exc_info=True)
            return
        self.log.info("registration::unregister_engine(%s)" % eid)
        # print (eid)
        uuid = self.keytable[eid]
        content = dict(id=eid, queue=uuid)
        self.dead_engines.add(uuid)
        # self.ids.remove(eid)
        # uuid = self.keytable.pop(eid)
        #
        # ec = self.engines.pop(eid)
        # self.hearts.pop(ec.heartbeat)
        # self.by_ident.pop(ec.queue)
        # self.completed.pop(eid)
        handleit = lambda: self._handle_stranded_msgs(eid, uuid)
        dc = ioloop.DelayedCallback(handleit, self.registration_timeout,
                                    self.loop)
        dc.start()
        ############## TODO: HANDLE IT ################

        if self.notifier:
            self.session.send(self.notifier,
                              "unregistration_notification",
                              content=content)

    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.
        
        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        that the result failed and later receive the actual result.
        """

        outstanding = self.queues[eid]

        for msg_id in outstanding:
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            try:
                raise error.EngineError(
                    "Engine %r died while running task %r" % (eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake header:
            header = {}
            header['engine'] = uuid
            header['date'] = datetime.now()
            rec = dict(result_content=content,
                       result_header=header,
                       result_buffers=[])
            rec['completed'] = header['date']
            rec['engine_uuid'] = uuid
            try:
                self.db.update_record(msg_id, rec)
            except Exception:
                self.log.error("DB Error handling stranded msg %r" % msg_id,
                               exc_info=True)

    def finish_registration(self, heart):
        """Second half of engine registration, called after our HeartMonitor
        has received a beat from the Engine's Heart."""
        try:
            (eid, queue, reg, purge) = self.incoming_registrations.pop(heart)
        except KeyError:
            self.log.error(
                "registration::tried to finish nonexistant registration",
                exc_info=True)
            return
        self.log.info("registration::finished registering engine %i:%r" %
                      (eid, queue))
        if purge is not None:
            purge.stop()
        control = queue
        self.ids.add(eid)
        self.keytable[eid] = queue
        self.engines[eid] = EngineConnector(id=eid,
                                            queue=queue,
                                            registration=reg,
                                            control=control,
                                            heartbeat=heart)
        self.by_ident[queue] = eid
        self.queues[eid] = list()
        self.tasks[eid] = list()
        self.completed[eid] = list()
        self.hearts[heart] = eid
        content = dict(id=eid, queue=self.engines[eid].queue)
        if self.notifier:
            self.session.send(self.notifier,
                              "registration_notification",
                              content=content)
        self.log.info("engine::Engine Connected: %i" % eid)

    def _purge_stalled_registration(self, heart):
        if heart in self.incoming_registrations:
            eid = self.incoming_registrations.pop(heart)[0]
            self.log.info("registration::purging stalled registration: %i" %
                          eid)
        else:
            pass

    #-------------------------------------------------------------------------
    # Client Requests
    #-------------------------------------------------------------------------

    def shutdown_request(self, client_id, msg):
        """handle shutdown request."""
        self.session.send(self.query,
                          'shutdown_reply',
                          content={'status': 'ok'},
                          ident=client_id)
        # also notify other clients of shutdown
        self.session.send(self.notifier,
                          'shutdown_notice',
                          content={'status': 'ok'})
        dc = ioloop.DelayedCallback(lambda: self._shutdown(), 1000, self.loop)
        dc.start()

    def _shutdown(self):
        self.log.info("hub::hub shutting down.")
        time.sleep(0.1)
        sys.exit(0)

    def check_load(self, client_id, msg):
        content = msg['content']
        try:
            targets = content['targets']
            targets = self._validate_targets(targets)
        except:
            content = error.wrap_exception()
            self.session.send(self.query,
                              "hub_error",
                              content=content,
                              ident=client_id)
            return

        content = dict(status='ok')
        # loads = {}
        for t in targets:
            content[bytes(t)] = len(self.queues[t]) + len(self.tasks[t])
        self.session.send(self.query,
                          "load_reply",
                          content=content,
                          ident=client_id)

    def queue_status(self, client_id, msg):
        """Return the Queue status of one or more targets.
        if verbose: return the msg_ids
        else: return len of each type.
        keys: queue (pending MUX jobs)
            tasks (pending Task jobs)
            completed (finished jobs from both queues)"""
        content = msg['content']
        targets = content['targets']
        try:
            targets = self._validate_targets(targets)
        except:
            content = error.wrap_exception()
            self.session.send(self.query,
                              "hub_error",
                              content=content,
                              ident=client_id)
            return
        verbose = content.get('verbose', False)
        content = dict(status='ok')
        for t in targets:
            queue = self.queues[t]
            completed = self.completed[t]
            tasks = self.tasks[t]
            if not verbose:
                queue = len(queue)
                completed = len(completed)
                tasks = len(tasks)
            content[bytes(t)] = {
                'queue': queue,
                'completed': completed,
                'tasks': tasks
            }
        content['unassigned'] = list(self.unassigned) if verbose else len(
            self.unassigned)

        self.session.send(self.query,
                          "queue_reply",
                          content=content,
                          ident=client_id)
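
    # Example reply content (a sketch; one engine with id 0, verbose=False,
    # counts hypothetical):
    #   {'status': 'ok',
    #    '0': {'queue': 2, 'completed': 5, 'tasks': 1},
    #    'unassigned': 0}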

    def purge_results(self, client_id, msg):
        """Purge results from memory. This method is more valuable before we move
        to a DB based message storage mechanism."""
        content = msg['content']
        msg_ids = content.get('msg_ids', [])
        reply = dict(status='ok')
        if msg_ids == 'all':
            try:
                self.db.drop_matching_records(dict(completed={'$ne': None}))
            except Exception:
                reply = error.wrap_exception()
        else:
            for msg_id in msg_ids:
                if msg_id in self.all_completed:
                    self.db.drop_record(msg_id)
                else:
                    if msg_id in self.pending:
                        try:
                            raise IndexError("msg pending: %r" % msg_id)
                        except:
                            reply = error.wrap_exception()
                    else:
                        try:
                            raise IndexError("No such msg: %r" % msg_id)
                        except:
                            reply = error.wrap_exception()
                    break
            eids = content.get('engine_ids', [])
            for eid in eids:
                if eid not in self.engines:
                    try:
                        raise IndexError("No such engine: %i" % eid)
                    except:
                        reply = error.wrap_exception()
                    break
                msg_ids = self.completed.pop(eid)
                uid = self.engines[eid].queue
                try:
                    self.db.drop_matching_records(
                        dict(engine_uuid=uid, completed={'$ne': None}))
                except Exception:
                    reply = error.wrap_exception()
                    break

        self.session.send(self.query,
                          'purge_reply',
                          content=reply,
                          ident=client_id)

    def resubmit_task(self, client_id, msg, buffers):
        """Resubmit a task."""
        raise NotImplementedError

    def _extract_record(self, rec):
        """decompose a TaskRecord dict into subsection of reply for get_result"""
        io_dict = {}
        for key in 'pyin pyout pyerr stdout stderr'.split():
            io_dict[key] = rec[key]
        content = {
            'result_content': rec['result_content'],
            'header': rec['header'],
            'result_header': rec['result_header'],
            'io': io_dict,
        }
        if rec['result_buffers']:
            buffers = map(str, rec['result_buffers'])
        else:
            buffers = []

        return content, buffers

    def get_results(self, client_id, msg):
        """Get the result of 1 or more messages."""
        content = msg['content']
        msg_ids = sorted(set(content['msg_ids']))
        statusonly = content.get('status_only', False)
        pending = []
        completed = []
        content = dict(status='ok')
        content['pending'] = pending
        content['completed'] = completed
        buffers = []
        if not statusonly:
            try:
                matches = self.db.find_records(dict(msg_id={'$in': msg_ids}))
                # turn match list into dict, for faster lookup
                records = {}
                for rec in matches:
                    records[rec['msg_id']] = rec
            except Exception:
                content = error.wrap_exception()
                self.session.send(self.query,
                                  "result_reply",
                                  content=content,
                                  parent=msg,
                                  ident=client_id)
                return
        else:
            records = {}
        for msg_id in msg_ids:
            if msg_id in self.pending:
                pending.append(msg_id)
            elif msg_id in self.all_completed or msg_id in records:
                completed.append(msg_id)
                if not statusonly:
                    c, bufs = self._extract_record(records[msg_id])
                    content[msg_id] = c
                    buffers.extend(bufs)
            else:
                try:
                    raise KeyError('No such message: ' + msg_id)
                except:
                    content = error.wrap_exception()
                break
        self.session.send(self.query,
                          "result_reply",
                          content=content,
                          parent=msg,
                          ident=client_id,
                          buffers=buffers)

    def get_history(self, client_id, msg):
        """Get a list of all msg_ids in our DB records"""
        try:
            msg_ids = self.db.get_history()
        except Exception as e:
            content = error.wrap_exception()
        else:
            content = dict(status='ok', history=msg_ids)

        self.session.send(self.query,
                          "history_reply",
                          content=content,
                          parent=msg,
                          ident=client_id)

    def db_query(self, client_id, msg):
        """Perform a raw query on the task record database."""
        content = msg['content']
        query = content.get('query', {})
        keys = content.get('keys', None)
        query = util.extract_dates(query)
        buffers = []
        empty = list()

        try:
            records = self.db.find_records(query, keys)
        except Exception as e:
            content = error.wrap_exception()
        else:
            # extract buffers from reply content:
            if keys is not None:
                buffer_lens = [] if 'buffers' in keys else None
                result_buffer_lens = [] if 'result_buffers' in keys else None
            else:
                buffer_lens = []
                result_buffer_lens = []

            for rec in records:
                # buffers may be None, so double check
                if buffer_lens is not None:
                    b = rec.pop('buffers', empty) or empty
                    buffer_lens.append(len(b))
                    buffers.extend(b)
                if result_buffer_lens is not None:
                    rb = rec.pop('result_buffers', empty) or empty
                    result_buffer_lens.append(len(rb))
                    buffers.extend(rb)
            content = dict(status='ok',
                           records=records,
                           buffer_lens=buffer_lens,
                           result_buffer_lens=result_buffer_lens)

        self.session.send(self.query,
                          "db_reply",
                          content=content,
                          parent=msg,
                          ident=client_id,
                          buffers=buffers)
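
The record-merge step appears twice above (in the MUX queue handler and in save_task_request). A standalone sketch of that rule follows; only the merge behaviour is taken from the Hub code, while the helper name and the sample records are hypothetical:

import logging

log = logging.getLogger(__name__)

def merge_initial_record(existing, record, msg_id):
    """Fill gaps in `record` from `existing`; log (but keep) conflicts.

    Mirrors the Hub's rule: an existing non-empty value only wins when
    the new record has no value for that key.
    """
    for key, evalue in existing.items():
        rvalue = record.get(key)
        if evalue and rvalue and evalue != rvalue:
            log.error("conflicting initial state for record: %s:%s <> %s",
                      msg_id, rvalue, evalue)
        elif evalue and not rvalue:
            record[key] = evalue
    return record

# usage sketch:
existing = {'client_uuid': 'abc', 'stdout': 'hi\n'}
record = {'client_uuid': 'abc', 'stdout': '', 'queue': 'task'}
assert merge_initial_record(existing, record, 'msg-1')['stdout'] == 'hi\n'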
Example #15
class LocalAuthenticator(Authenticator):
    """Base class for Authenticators that work with local *ix users
    
    Checks for local users, and can attempt to create them if they don't exist.
    """
    
    create_system_users = Bool(False, config=True,
        help="""If a user is added that doesn't exist on the system,
        should I try to create the system user?
        """
    )

    group_whitelist = Set(
        config=True,
        help="Automatically whitelist anyone in this group.",
    )

    def _group_whitelist_changed(self, name, old, new):
        if self.whitelist:
            self.log.warn(
                "Ignoring username whitelist because group whitelist supplied!"
            )

    def check_whitelist(self, username):
        if self.group_whitelist:
            return self.check_group_whitelist(username)
        else:
            return super().check_whitelist(username)

    def check_group_whitelist(self, username):
        if not self.group_whitelist:
            return False
        for grnam in self.group_whitelist:
            try:
                group = getgrnam(grnam)
            except KeyError:
                self.log.error('No such group: [%s]' % grnam)
                continue
            if username in group.gr_mem:
                return True
        return False

    @gen.coroutine
    def add_user(self, user):
        """Add a new user
        
        By default, this just adds the user to the whitelist.
        
        Subclasses may do more extensive things,
        such as adding actual unix users.
        """
        user_exists = yield gen.maybe_future(self.system_user_exists(user))
        if not user_exists:
            if self.create_system_users:
                yield gen.maybe_future(self.add_system_user(user))
            else:
                raise KeyError("User %s does not exist." % user.name)
        
        yield gen.maybe_future(super().add_user(user))
    
    @staticmethod
    def system_user_exists(user):
        """Check if the user exists on the system"""
        try:
            pwd.getpwnam(user.name)
        except KeyError:
            return False
        else:
            return True
    
    @staticmethod
    def add_system_user(user):
        """Create a new *ix user on the system. Works on FreeBSD and Linux, at least."""
        name = user.name
        for useradd in (
            ['pw', 'useradd', '-m'],
            ['useradd', '-m'],
        ):
            try:
                check_output(['which', useradd[0]])
            except CalledProcessError:
                continue
            else:
                break
        else:
            raise RuntimeError("I don't know how to add users on this system.")
    
        check_call(useradd + [name])
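
A minimal configuration sketch for this authenticator (in a jupyterhub_config.py file, `c` is the config object; the group name 'jupyter' is hypothetical):

# jupyterhub_config.py -- a sketch
c.LocalAuthenticator.group_whitelist = {'jupyter'}
c.LocalAuthenticator.create_system_users = True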
Example #16
class JupyterHub(Application):
    """An Application for starting a Multi-User Jupyter Notebook server."""
    name = 'jupyterhub'

    description = """Start a multi-user Jupyter Notebook server
    
    Spawns a configurable-http-proxy and multi-user Hub,
    which authenticates users and spawns single-user Notebook servers
    on behalf of users.
    """

    examples = """
    
    generate default config file:
    
        jupyterhub --generate-config -f /etc/jupyterhub/jupyterhub.py
    
    spawn the server on 10.0.1.2:443 with https:
    
        jupyterhub --ip 10.0.1.2 --port 443 --ssl-key my_ssl.key --ssl-cert my_ssl.cert
    """

    aliases = Dict(aliases)
    flags = Dict(flags)

    subcommands = {'token': (NewToken, "Generate an API token for a user")}

    classes = List([
        Spawner,
        LocalProcessSpawner,
        Authenticator,
        PAMAuthenticator,
    ])

    config_file = Unicode(
        'jupyterhub_config.py',
        config=True,
        help="The config file to load",
    )
    generate_config = Bool(
        False,
        config=True,
        help="Generate default config file",
    )
    answer_yes = Bool(
        False,
        config=True,
        help="Answer yes to any questions (e.g. confirm overwrite)")
    pid_file = Unicode('',
                       config=True,
                       help="""File to write PID
        Useful for daemonizing jupyterhub.
        """)
    last_activity_interval = Integer(
        300,
        config=True,
        help=
        "Interval (in seconds) at which to update last-activity timestamps.")
    proxy_check_interval = Integer(
        30,
        config=True,
        help="Interval (in seconds) at which to check if the proxy is running."
    )

    data_files_path = Unicode(
        DATA_FILES_PATH,
        config=True,
        help=
        "The location of jupyterhub data files (e.g. /usr/local/share/jupyter/hub)"
    )

    ssl_key = Unicode(
        '',
        config=True,
        help="""Path to SSL key file for the public facing interface of the proxy
        
        Use with ssl_cert
        """)
    ssl_cert = Unicode(
        '',
        config=True,
        help=
        """Path to SSL certificate file for the public facing interface of the proxy
        
        Use with ssl_key
        """)
    ip = Unicode('', config=True, help="The public facing ip of the proxy")
    port = Integer(8000,
                   config=True,
                   help="The public facing port of the proxy")
    base_url = URLPrefix('/',
                         config=True,
                         help="The base URL of the entire application")

    jinja_environment_options = Dict(
        config=True,
        help="Supply extra arguments that will be passed to Jinja environment."
    )

    proxy_cmd = Unicode('configurable-http-proxy',
                        config=True,
                        help="""The command to start the http proxy.
        
        Only override if configurable-http-proxy is not on your PATH
        """)
    debug_proxy = Bool(False,
                       config=True,
                       help="show debug output in configurable-http-proxy")
    proxy_auth_token = Unicode(config=True,
                               help="""The Proxy Auth token.

        Loaded from the CONFIGPROXY_AUTH_TOKEN env variable by default.
        """)

    def _proxy_auth_token_default(self):
        token = os.environ.get('CONFIGPROXY_AUTH_TOKEN', None)
        if not token:
            self.log.warn('\n'.join([
                "",
                "Generating CONFIGPROXY_AUTH_TOKEN. Restarting the Hub will require restarting the proxy.",
                "Set CONFIGPROXY_AUTH_TOKEN env or JupyterHub.proxy_auth_token config to avoid this message.",
                "",
            ]))
            token = orm.new_token()
        return token
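
    # To avoid the warning above, either export CONFIGPROXY_AUTH_TOKEN in the
    # environment before starting the Hub and the proxy, or set
    # c.JupyterHub.proxy_auth_token in jupyterhub_config.py.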

    proxy_api_ip = Unicode('localhost',
                           config=True,
                           help="The ip for the proxy API handlers")
    proxy_api_port = Integer(config=True,
                             help="The port for the proxy API handlers")

    def _proxy_api_port_default(self):
        return self.port + 1

    hub_port = Integer(8081, config=True, help="The port for this process")
    hub_ip = Unicode('localhost', config=True, help="The ip for this process")

    hub_prefix = URLPrefix(
        '/hub/',
        config=True,
        help="The prefix for the hub server. Must not be '/'")

    def _hub_prefix_default(self):
        return url_path_join(self.base_url, '/hub/')

    def _hub_prefix_changed(self, name, old, new):
        if new == '/':
            raise TraitError("'/' is not a valid hub prefix")
        if not new.startswith(self.base_url):
            self.hub_prefix = url_path_join(self.base_url, new)

    cookie_secret = Bytes(config=True,
                          env='JPY_COOKIE_SECRET',
                          help="""The cookie secret to use to encrypt cookies.

        Loaded from the JPY_COOKIE_SECRET env variable by default.
        """)

    cookie_secret_file = Unicode(
        'jupyterhub_cookie_secret',
        config=True,
        help="""File in which to store the cookie secret.""")

    authenticator_class = Type(PAMAuthenticator,
                               Authenticator,
                               config=True,
                               help="""Class for authenticating users.
        
        This should be a class with the following form:
        
        - constructor takes one kwarg: `config`, the IPython config object.
        
        - an `authenticate` method that:

          - is a tornado.gen.coroutine
          - returns username on success, None on failure
          - takes two arguments: (handler, data),
            where `handler` is the calling web.RequestHandler,
            and `data` is the POST form data from the login page.
        """)

    authenticator = Instance(Authenticator)

    def _authenticator_default(self):
        return self.authenticator_class(parent=self, db=self.db)

    # class for spawning single-user servers
    spawner_class = Type(
        LocalProcessSpawner,
        Spawner,
        config=True,
        help="""The class to use for spawning single-user servers.
        
        Should be a subclass of Spawner.
        """)

    db_url = Unicode(
        'sqlite:///jupyterhub.sqlite',
        config=True,
        help="url for the database. e.g. `sqlite:///jupyterhub.sqlite`")

    def _db_url_changed(self, name, old, new):
        if '://' not in new:
            # assume sqlite, if given as a plain filename
            self.db_url = 'sqlite:///%s' % new
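            # e.g. db_url = 'jupyterhub.sqlite' -> 'sqlite:///jupyterhub.sqlite'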

    db_kwargs = Dict(
        config=True,
        help="""Include any kwargs to pass to the database connection.
        See sqlalchemy.create_engine for details.
        """)

    reset_db = Bool(False, config=True, help="Purge and reset the database.")
    debug_db = Bool(
        False,
        config=True,
        help="log all database transactions. This has A LOT of output")
    db = Any()
    session_factory = Any()

    admin_access = Bool(
        False,
        config=True,
        help="""Grant admin users permission to access single-user servers.
        
        Users should be properly informed if this is enabled.
        """)
    admin_users = Set(config=True,
                      help="""set of usernames of admin users

        If unspecified, only the user that launches the server will be admin.
        """)
    tornado_settings = Dict(config=True)

    cleanup_servers = Bool(
        True,
        config=True,
        help="""Whether to shutdown single-user servers when the Hub shuts down.
        
        Disable if you want to be able to teardown the Hub while leaving the single-user servers running.
        
        If both this and cleanup_proxy are False, sending SIGINT to the Hub will
        only shutdown the Hub, leaving everything else running.
        
        The Hub should be able to resume from database state.
        """)

    cleanup_proxy = Bool(
        True,
        config=True,
        help="""Whether to shutdown the proxy when the Hub shuts down.
        
        Disable if you want to be able to teardown the Hub while leaving the proxy running.
        
        Only valid if the proxy was started by the Hub process.
        
        If both this and cleanup_servers are False, sending SIGINT to the Hub will
        only shutdown the Hub, leaving everything else running.
        
        The Hub should be able to resume from database state.
        """)

    handlers = List()

    _log_formatter_cls = CoroutineLogFormatter
    http_server = None
    proxy_process = None
    io_loop = None

    def _log_level_default(self):
        return logging.INFO

    def _log_datefmt_default(self):
        """Exclude date from default date format"""
        return "%Y-%m-%d %H:%M:%S"

    def _log_format_default(self):
        """override default log format to include time"""
        return "%(color)s[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s %(module)s:%(lineno)d]%(end_color)s %(message)s"

    extra_log_file = Unicode("",
                             config=True,
                             help="Set a logging.FileHandler on this file.")
    extra_log_handlers = List(
        Instance(logging.Handler),
        config=True,
        help="Extra log handlers to set on JupyterHub logger",
    )

    def init_logging(self):
        # This prevents double log messages because tornado uses a root logger that
        # self.log is a child of. The logging module dispatches log messages to a
        # logger and all of its ancestors until propagate is set to False.
        self.log.propagate = False

        if self.extra_log_file:
            self.extra_log_handlers.append(
                logging.FileHandler(self.extra_log_file))

        _formatter = self._log_formatter_cls(
            fmt=self.log_format,
            datefmt=self.log_datefmt,
        )
        for handler in self.extra_log_handlers:
            if handler.formatter is None:
                handler.setFormatter(_formatter)
            self.log.addHandler(handler)

        # hook up tornado 3's loggers to our app handlers
        for log in (app_log, access_log, gen_log):
            # ensure all log statements identify the application they come from
            log.name = self.log.name
        logger = logging.getLogger('tornado')
        logger.propagate = True
        logger.parent = self.log
        logger.setLevel(self.log.level)

    def init_ports(self):
        if self.hub_port == self.port:
            raise TraitError(
                "The hub and proxy cannot both listen on port %i" % self.port)
        if self.hub_port == self.proxy_api_port:
            raise TraitError(
                "The hub and proxy API cannot both listen on port %i" %
                self.hub_port)
        if self.proxy_api_port == self.port:
            raise TraitError(
                "The proxy's public and API ports cannot both be %i" %
                self.port)

    @staticmethod
    def add_url_prefix(prefix, handlers):
        """add a url prefix to handlers"""
        for i, tup in enumerate(handlers):
            lis = list(tup)
            lis[0] = url_path_join(prefix, tup[0])
            handlers[i] = tuple(lis)
        return handlers
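
    # e.g. add_url_prefix('/hub/', [('/login', LoginHandler)])
    #      returns [('/hub/login', LoginHandler)]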

    def init_handlers(self):
        h = []
        h.extend(handlers.default_handlers)
        h.extend(apihandlers.default_handlers)
        # load handlers from the authenticator
        h.extend(self.authenticator.get_handlers(self))

        self.handlers = self.add_url_prefix(self.hub_prefix, h)

        # some extra handlers, outside hub_prefix
        self.handlers.extend([
            (r"%s" % self.hub_prefix.rstrip('/'), web.RedirectHandler, {
                "url": self.hub_prefix,
                "permanent": False,
            }),
            (r"(?!%s).*" % self.hub_prefix, handlers.PrefixRedirectHandler),
            (r'(.*)', handlers.Template404),
        ])

    def _check_db_path(self, path):
        """More informative log messages for failed filesystem access"""
        path = os.path.abspath(path)
        parent, fname = os.path.split(path)
        user = getuser()
        if not os.path.isdir(parent):
            self.log.error("Directory %s does not exist", parent)
        if os.path.exists(parent) and not os.access(parent, os.W_OK):
            self.log.error("%s cannot create files in %s", user, parent)
        if os.path.exists(path) and not os.access(path, os.W_OK):
            self.log.error("%s cannot edit %s", user, path)

    def init_secrets(self):
        trait_name = 'cookie_secret'
        trait = self.traits()[trait_name]
        env_name = trait.get_metadata('env')
        secret_file = os.path.abspath(
            os.path.expanduser(self.cookie_secret_file))
        secret = self.cookie_secret
        secret_from = 'config'
        # load priority: 1. config, 2. env, 3. file
        if not secret and os.environ.get(env_name):
            secret_from = 'env'
            self.log.info("Loading %s from env[%s]", trait_name, env_name)
            secret = binascii.a2b_hex(os.environ[env_name])
        if not secret and os.path.exists(secret_file):
            secret_from = 'file'
            perm = os.stat(secret_file).st_mode
            if perm & 0o077:
                self.log.error("Bad permissions on %s", secret_file)
            else:
                self.log.info("Loading %s from %s", trait_name, secret_file)
                with open(secret_file) as f:
                    b64_secret = f.read()
                try:
                    secret = binascii.a2b_base64(b64_secret)
                except Exception as e:
                    self.log.error("%s does not contain b64 key: %s",
                                   secret_file, e)
        if not secret:
            secret_from = 'new'
            self.log.debug("Generating new %s", trait_name)
            secret = os.urandom(SECRET_BYTES)

        if secret_file and secret_from == 'new':
            # if we generated a new secret, store it in the secret_file
            self.log.info("Writing %s to %s", trait_name, secret_file)
            b64_secret = binascii.b2a_base64(secret).decode('ascii')
            with open(secret_file, 'w') as f:
                f.write(b64_secret)
            try:
                os.chmod(secret_file, 0o600)
            except OSError:
                self.log.warn("Failed to set permissions on %s", secret_file)
        # store the loaded trait value
        self.cookie_secret = secret
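
    # The secret file can also be pre-generated outside the Hub, e.g. (a
    # sketch, assuming SECRET_BYTES is 32):
    #
    #   import binascii, os
    #   with open('jupyterhub_cookie_secret', 'w') as f:
    #       f.write(binascii.b2a_base64(os.urandom(32)).decode('ascii'))
    #   os.chmod('jupyterhub_cookie_secret', 0o600)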

    def init_db(self):
        """Create the database connection"""
        self.log.debug("Connecting to db: %s", self.db_url)
        try:
            self.session_factory = orm.new_session_factory(self.db_url,
                                                           reset=self.reset_db,
                                                           echo=self.debug_db,
                                                           **self.db_kwargs)
            self.db = scoped_session(self.session_factory)()
        except OperationalError as e:
            self.log.error("Failed to connect to db: %s", self.db_url)
            self.log.debug("Database error was:", exc_info=True)
            if self.db_url.startswith('sqlite:///'):
                self._check_db_path(self.db_url.split(':///', 1)[1])
            self.exit(1)

    def init_hub(self):
        """Load the Hub config into the database"""
        self.hub = self.db.query(orm.Hub).first()
        if self.hub is None:
            self.hub = orm.Hub(server=orm.Server(
                ip=self.hub_ip,
                port=self.hub_port,
                base_url=self.hub_prefix,
                cookie_name='jupyter-hub-token',
            ))
            self.db.add(self.hub)
        else:
            server = self.hub.server
            server.ip = self.hub_ip
            server.port = self.hub_port
            server.base_url = self.hub_prefix

        self.db.commit()

    @gen.coroutine
    def init_users(self):
        """Load users into and from the database"""
        db = self.db

        if not self.admin_users:
            # add current user as admin if there aren't any others
            admins = db.query(orm.User).filter(orm.User.admin == True)
            if admins.first() is None:
                self.admin_users.add(getuser())

        new_users = []

        for name in self.admin_users:
            # ensure anyone specified as admin in config is admin in db
            user = orm.User.find(db, name)
            if user is None:
                user = orm.User(name=name, admin=True)
                new_users.append(user)
                db.add(user)
            else:
                user.admin = True

        # the admin_users config variable will never be used after this point.
        # only the database values will be referenced.

        whitelist = self.authenticator.whitelist

        if not whitelist:
            self.log.info(
                "Not using whitelist. Any authenticated user will be allowed.")

        # add whitelisted users to the db
        for name in whitelist:
            user = orm.User.find(db, name)
            if user is None:
                user = orm.User(name=name)
                new_users.append(user)
                db.add(user)

        if whitelist:
            # fill the whitelist with any users loaded from the db,
            # so we are consistent in both directions.
            # This lets whitelist be used to set up initial list,
            # but changes to the whitelist can occur in the database,
            # and persist across sessions.
            for user in db.query(orm.User):
                whitelist.add(user.name)

        # The whitelist set and the users in the db are now the same.
        # From this point on, any user changes should be done simultaneously
        # to the whitelist set and user db, unless the whitelist is empty (all users allowed).

        db.commit()

        for user in new_users:
            yield gen.maybe_future(self.authenticator.add_user(user))
        db.commit()

        user_summaries = ['']

        def _user_summary(user):
            parts = ['{0: >8}'.format(user.name)]
            if user.admin:
                parts.append('admin')
            if user.server:
                parts.append('running at %s' % user.server)
            return ' '.join(parts)

        @gen.coroutine
        def user_stopped(user):
            status = yield user.spawner.poll()
            self.log.warn(
                "User %s server stopped with exit code: %s",
                user.name,
                status,
            )
            yield self.proxy.delete_user(user)
            yield user.stop()

        for user in db.query(orm.User):
            if not user.state:
                # without spawner state, server isn't valid
                user.server = None
                user_summaries.append(_user_summary(user))
                continue
            self.log.debug("Loading state for %s from db", user.name)
            user.spawner = spawner = self.spawner_class(
                user=user,
                hub=self.hub,
                config=self.config,
                db=self.db,
            )
            status = yield spawner.poll()
            if status is None:
                self.log.info("%s still running", user.name)
                spawner.add_poll_callback(user_stopped, user)
                spawner.start_polling()
            else:
                # user not running. This is expected if server is None,
                # but indicates the user's server died while the Hub wasn't running
                # if user.server is defined.
                log = self.log.warn if user.server else self.log.debug
                log("%s not running.", user.name)
                user.server = None

            user_summaries.append(_user_summary(user))

        self.log.debug("Loaded users: %s", '\n'.join(user_summaries))
        db.commit()

    def init_proxy(self):
        """Load the Proxy config into the database"""
        self.proxy = self.db.query(orm.Proxy).first()
        if self.proxy is None:
            self.proxy = orm.Proxy(
                public_server=orm.Server(),
                api_server=orm.Server(),
            )
            self.db.add(self.proxy)
            self.db.commit()
        self.proxy.auth_token = self.proxy_auth_token  # not persisted
        self.proxy.log = self.log
        self.proxy.public_server.ip = self.ip
        self.proxy.public_server.port = self.port
        self.proxy.api_server.ip = self.proxy_api_ip
        self.proxy.api_server.port = self.proxy_api_port
        self.proxy.api_server.base_url = '/api/routes/'
        self.db.commit()

    @gen.coroutine
    def start_proxy(self):
        """Actually start the configurable-http-proxy"""
        # check for proxy
        if self.proxy.public_server.is_up() or self.proxy.api_server.is_up():
            # check for *authenticated* access to the proxy (auth token can change)
            try:
                yield self.proxy.get_routes()
            except (HTTPError, OSError, socket.error) as e:
                if isinstance(e, HTTPError) and e.code == 403:
                    msg = "Did CONFIGPROXY_AUTH_TOKEN change?"
                else:
                    msg = "Is something else using %s?" % self.proxy.public_server.url
                self.log.error(
                    "Proxy appears to be running at %s, but I can't access it (%s)\n%s",
                    self.proxy.public_server.url, e, msg)
                self.exit(1)
                return
            else:
                self.log.info("Proxy already running at: %s",
                              self.proxy.public_server.url)
            self.proxy_process = None
            return

        env = os.environ.copy()
        env['CONFIGPROXY_AUTH_TOKEN'] = self.proxy.auth_token
        cmd = [
            self.proxy_cmd,
            '--ip',
            self.proxy.public_server.ip,
            '--port',
            str(self.proxy.public_server.port),
            '--api-ip',
            self.proxy.api_server.ip,
            '--api-port',
            str(self.proxy.api_server.port),
            '--default-target',
            self.hub.server.host,
        ]
        if self.debug_proxy:
            cmd.extend(['--log-level', 'debug'])
        if self.ssl_key:
            cmd.extend(['--ssl-key', self.ssl_key])
        if self.ssl_cert:
            cmd.extend(['--ssl-cert', self.ssl_cert])
        self.log.info("Starting proxy @ %s", self.proxy.public_server.url)
        self.log.debug("Proxy cmd: %s", cmd)
        try:
            self.proxy_process = Popen(cmd, env=env)
        except FileNotFoundError as e:
            self.log.error(
                "Failed to find proxy %r\n"
                "The proxy can be installed with `npm install -g configurable-http-proxy`"
                % self.proxy_cmd)
            self.exit(1)

        def _check():
            status = self.proxy_process.poll()
            if status is not None:
                e = RuntimeError("Proxy failed to start with exit code %i" %
                                 status)
                # py2-compatible `raise e from None`
                e.__cause__ = None
                raise e

        for server in (self.proxy.public_server, self.proxy.api_server):
            for i in range(10):
                _check()
                try:
                    yield server.wait_up(1)
                except TimeoutError:
                    continue
                else:
                    break
            yield server.wait_up(1)
        self.log.debug("Proxy started and appears to be up")

    @gen.coroutine
    def check_proxy(self):
        if self.proxy_process.poll() is None:
            return
        self.log.error(
            "Proxy stopped with exit code %r", 'unknown'
            if self.proxy_process is None else self.proxy_process.poll())
        yield self.start_proxy()
        self.log.info("Setting up routes on new proxy")
        yield self.proxy.add_all_users()
        self.log.info("New proxy back up, and good to go")

    def init_tornado_settings(self):
        """Set up the tornado settings dict."""
        base_url = self.hub.server.base_url
        template_path = os.path.join(self.data_files_path, 'templates')
        jinja_env = Environment(loader=FileSystemLoader(template_path),
                                **self.jinja_environment_options)

        login_url = self.authenticator.login_url(base_url)
        logout_url = self.authenticator.logout_url(base_url)

        # if running from git, disable caching of require.js
        # otherwise cache based on server start time
        parent = os.path.dirname(os.path.dirname(jupyterhub.__file__))
        if os.path.isdir(os.path.join(parent, '.git')):
            version_hash = ''
        else:
            version_hash = datetime.now().strftime("%Y%m%d%H%M%S")

        settings = dict(
            config=self.config,
            log=self.log,
            db=self.db,
            proxy=self.proxy,
            hub=self.hub,
            admin_users=self.admin_users,
            admin_access=self.admin_access,
            authenticator=self.authenticator,
            spawner_class=self.spawner_class,
            base_url=self.base_url,
            cookie_secret=self.cookie_secret,
            login_url=login_url,
            logout_url=logout_url,
            static_path=os.path.join(self.data_files_path, 'static'),
            static_url_prefix=url_path_join(self.hub.server.base_url,
                                            'static/'),
            static_handler_class=CacheControlStaticFilesHandler,
            template_path=template_path,
            jinja2_env=jinja_env,
            version_hash=version_hash,
        )
        # allow configured settings to have priority
        settings.update(self.tornado_settings)
        self.tornado_settings = settings

    def init_tornado_application(self):
        """Instantiate the tornado Application object"""
        self.tornado_application = web.Application(self.handlers,
                                                   **self.tornado_settings)

    def write_pid_file(self):
        pid = os.getpid()
        if self.pid_file:
            self.log.debug("Writing PID %i to %s", pid, self.pid_file)
            with open(self.pid_file, 'w') as f:
                f.write('%i' % pid)

    @gen.coroutine
    @catch_config_error
    def initialize(self, *args, **kwargs):
        super().initialize(*args, **kwargs)
        if self.generate_config or self.subapp:
            return
        self.load_config_file(self.config_file)
        self.init_logging()
        if 'JupyterHubApp' in self.config:
            self.log.warn(
                "Use JupyterHub in config, not JupyterHubApp. Outdated config:\n%s",
                '\n'.join('JupyterHubApp.{key} = {value!r}'.format(key=key,
                                                                   value=value)
                          for key, value in self.config.JupyterHubApp.items()))
            cfg = self.config.copy()
            cfg.JupyterHub.merge(cfg.JupyterHubApp)
            self.update_config(cfg)
        self.write_pid_file()
        self.init_ports()
        self.init_secrets()
        self.init_db()
        self.init_hub()
        self.init_proxy()
        yield self.init_users()
        self.init_handlers()
        self.init_tornado_settings()
        self.init_tornado_application()

    @gen.coroutine
    def cleanup(self):
        """Shutdown our various subprocesses and cleanup runtime files."""

        futures = []
        if self.cleanup_servers:
            self.log.info("Cleaning up single-user servers...")
            # request (async) process termination
            for user in self.db.query(orm.User):
                if user.spawner is not None:
                    futures.append(user.stop())
        else:
            self.log.info("Leaving single-user servers running")

        # clean up proxy while SUS are shutting down
        if self.cleanup_proxy:
            if self.proxy_process:
                self.log.info("Cleaning up proxy[%i]...",
                              self.proxy_process.pid)
                if self.proxy_process.poll() is None:
                    try:
                        self.proxy_process.terminate()
                    except Exception as e:
                        self.log.error("Failed to terminate proxy process: %s",
                                       e)
            else:
                self.log.info("I didn't start the proxy, I can't clean it up")
        else:
            self.log.info("Leaving proxy running")

        # wait for the stop requests to finish:
        for f in futures:
            try:
                yield f
            except Exception as e:
                self.log.error("Failed to stop user: %s", e)

        self.db.commit()

        if self.pid_file and os.path.exists(self.pid_file):
            self.log.info("Cleaning up PID file %s", self.pid_file)
            os.remove(self.pid_file)

        # finally stop the loop once we are all cleaned up
        self.log.info("...done")

    def write_config_file(self):
        """Write our default config to a .py config file"""
        if os.path.exists(self.config_file) and not self.answer_yes:
            answer = ''

            def ask():
                prompt = "Overwrite %s with default config? [y/N]" % self.config_file
                try:
                    return input(prompt).lower() or 'n'
                except KeyboardInterrupt:
                    print('')  # empty line
                    return 'n'

            answer = ask()
            while not answer.startswith(('y', 'n')):
                print("Please answer 'yes' or 'no'")
                answer = ask()
            if answer.startswith('n'):
                return

        config_text = self.generate_config_file()
        if isinstance(config_text, bytes):
            config_text = config_text.decode('utf8')
        print("Writing default config to: %s" % self.config_file)
        with open(self.config_file, mode='w') as f:
            f.write(config_text)

    @gen.coroutine
    def update_last_activity(self):
        """Update User.last_activity timestamps from the proxy"""
        routes = yield self.proxy.get_routes()
        for prefix, route in routes.items():
            if 'user' not in route:
                # not a user route, ignore it
                continue
            user = orm.User.find(self.db, route['user'])
            if user is None:
                self.log.warn("Found no user for route: %s", route)
                continue
            try:
                dt = datetime.strptime(route['last_activity'], ISO8601_ms)
            except Exception:
                dt = datetime.strptime(route['last_activity'], ISO8601_s)
            user.last_activity = max(user.last_activity, dt)

        self.db.commit()
        yield self.proxy.check_routes(routes)

    @gen.coroutine
    def start(self):
        """Start the whole thing"""
        self.io_loop = loop = IOLoop.current()

        if self.subapp:
            self.subapp.start()
            loop.stop()
            return

        if self.generate_config:
            self.write_config_file()
            loop.stop()
            return

        # start the proxy
        try:
            yield self.start_proxy()
        except Exception as e:
            self.log.critical("Failed to start proxy", exc_info=True)
            self.exit(1)
            return

        loop.add_callback(self.proxy.add_all_users)

        if self.proxy_process:
            # only check / restart the proxy if we started it in the first place.
            # this means a restarted Hub cannot restart a Proxy that its
            # predecessor started.
            pc = PeriodicCallback(self.check_proxy,
                                  1e3 * self.proxy_check_interval)
            pc.start()

        if self.last_activity_interval:
            pc = PeriodicCallback(self.update_last_activity,
                                  1e3 * self.last_activity_interval)
            pc.start()

        # start the webserver
        self.http_server = tornado.httpserver.HTTPServer(
            self.tornado_application, xheaders=True)
        self.http_server.listen(self.hub_port)

        # register cleanup on both TERM and INT
        atexit.register(self.atexit)
        signal.signal(signal.SIGTERM, self.sigterm)

    def sigterm(self, signum, frame):
        self.log.critical("Received SIGTERM, shutting down")
        self.io_loop.stop()
        self.atexit()

    _atexit_ran = False

    def atexit(self):
        """atexit callback"""
        if self._atexit_ran:
            return
        self._atexit_ran = True
        # run the cleanup step (in a new loop, because the interrupted one is unclean)
        IOLoop.clear_current()
        loop = IOLoop()
        loop.make_current()
        loop.run_sync(self.cleanup)

    def stop(self):
        if not self.io_loop:
            return
        if self.http_server:
            self.io_loop.add_callback(self.http_server.stop)
        self.io_loop.add_callback(self.io_loop.stop)

    @gen.coroutine
    def launch_instance_async(self, argv=None):
        try:
            yield self.initialize(argv)
            yield self.start()
        except Exception as e:
            self.log.exception("")
            self.exit(1)

    @classmethod
    def launch_instance(cls, argv=None):
        self = cls.instance(argv=argv)
        loop = IOLoop.current()
        loop.add_callback(self.launch_instance_async, argv)
        try:
            loop.start()
        except KeyboardInterrupt:
            print("\nInterrupted")
Example #17
class Session(Configurable):
    """Object for handling serialization and sending of messages.

    The Session object handles building messages and sending them
    with ZMQ sockets or ZMQStream objects.  Objects can communicate with each
    other over the network via Session objects, and only need to work with the
    dict-based IPython message spec. The Session will handle
    serialization/deserialization, security, and metadata.

    Sessions support configurable serialization via packer/unpacker traits,
    and signing with HMAC digests via the key/keyfile traits.

    Parameters
    ----------

    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        importstrings for methods to serialize message parts.  If just
        'json' or 'pickle', predefined JSON and pickle packers will be used.
        Otherwise, the entire importstring must be used.

        The functions must accept at least valid JSON input, and output *bytes*.

        For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        You can also set the pack/unpack callables for serialization directly.
    session : bytes
        the ID of this Session object.  The default is to generate a new UUID.
    username : unicode
        username added to message headers.  The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature.  If unset, messages
        will not be signed or checked.
    keyfile : filepath
        The file containing a key.  If this is set, `key` will be initialized
        to the contents of the file.

    """

    debug = Bool(False, config=True, help="""Debug output in the Session""")

    packer = DottedObjectName(
        'json',
        config=True,
        help="""The name of the packer for serializing messages.
            Should be one of 'json', 'pickle', or an import name
            for a custom callable serializer.""")

    def _packer_changed(self, name, old, new):
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.unpacker = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.unpacker = new
        else:
            self.pack = import_item(str(new))

    unpacker = DottedObjectName(
        'json',
        config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""")

    def _unpacker_changed(self, name, old, new):
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.packer = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.packer = new
        else:
            self.unpack = import_item(str(new))

    session = CUnicode(u'',
                       config=True,
                       help="""The UUID identifying this session.""")

    def _session_default(self):
        u = unicode_type(uuid.uuid4())
        self.bsession = u.encode('ascii')
        return u

    def _session_changed(self, name, old, new):
        self.bsession = self.session.encode('ascii')

    # bsession is the session as bytes
    bsession = CBytes(b'')

    username = Unicode(
        str_to_unicode(os.environ.get('USER', 'username')),
        help="""Username for the Session. Default is your system username.""",
        config=True)

    metadata = Dict(
        {},
        config=True,
        help=
        """Metadata dictionary, which serves as the default top-level metadata dict for each message."""
    )

    # if 0, no adapting to do.
    adapt_version = Integer(0)

    # message signature related traits:

    key = CBytes(b'',
                 config=True,
                 help="""execution key, for extra authentication.""")

    def _key_changed(self):
        self._new_auth()

    signature_scheme = Unicode(
        'hmac-sha256',
        config=True,
        help="""The digest scheme used to construct the message signatures.
        Must have the form 'hmac-HASH'.""")

    def _signature_scheme_changed(self, name, old, new):
        if not new.startswith('hmac-'):
            raise TraitError(
                "signature_scheme must start with 'hmac-', got %r" % new)
        hash_name = new.split('-', 1)[1]
        try:
            self.digest_mod = getattr(hashlib, hash_name)
        except AttributeError:
            raise TraitError("hashlib has no such attribute: %s" % hash_name)
        self._new_auth()

    digest_mod = Any()

    def _digest_mod_default(self):
        return hashlib.sha256

    auth = Instance(hmac.HMAC)

    def _new_auth(self):
        if self.key:
            self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
        else:
            self.auth = None

    digest_history = Set()
    digest_history_size = Integer(
        2**16,
        config=True,
        help="""The maximum number of digests to remember.
        
        The digest history will be culled when it exceeds this value.
        """)

    keyfile = Unicode('',
                      config=True,
                      help="""path to file containing execution key.""")

    def _keyfile_changed(self, name, old, new):
        with open(new, 'rb') as f:
            self.key = f.read().strip()

    # for protecting against sends from forks
    pid = Integer()

    # serialization traits:

    pack = Any(default_packer)  # the actual packer function

    def _pack_changed(self, name, old, new):
        if not callable(new):
            raise TypeError("packer must be callable, not %s" % type(new))

    unpack = Any(default_unpacker)  # the actual unpacker function

    def _unpack_changed(self, name, old, new):
        # unpacker is not checked - it is only required to be callable
        if not callable(new):
            raise TypeError("unpacker must be callable, not %s" % type(new))

    # thresholds:
    copy_threshold = Integer(
        2**16,
        config=True,
        help=
        "Threshold (in bytes) beyond which a buffer should be sent without copying."
    )
    buffer_threshold = Integer(
        MAX_BYTES,
        config=True,
        help=
        "Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling."
    )
    item_threshold = Integer(
        MAX_ITEMS,
        config=True,
        help=
        """The maximum number of items for a container to be introspected for custom serialization.
        Containers larger than this are pickled outright.
        """)

    def __init__(self, **kwargs):
        """create a Session object

        Parameters
        ----------

        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts.  If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : unicode (must be ascii)
            the ID of this Session object.  The default is to generate a new
            UUID.
        bsession : bytes
            The session as bytes
        username : unicode
            username added to message headers.  The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature.  If unset, messages
            will not be signed or checked.
        signature_scheme : str
            The message digest scheme. Currently must be of the form 'hmac-HASH',
            where 'HASH' is a hashing function available in Python's hashlib.
            The default is 'hmac-sha256'.
            This is ignored if 'key' is empty.
        keyfile : filepath
            The file containing a key.  If this is set, `key` will be
            initialized to the contents of the file.
        """
        super(Session, self).__init__(**kwargs)
        self._check_packers()
        self.none = self.pack({})
        # ensure self._session_default() if necessary, so bsession is defined:
        self.session
        self.pid = os.getpid()

    @property
    def msg_id(self):
        """always return new uuid"""
        return str(uuid.uuid4())

    def _check_packers(self):
        """check packers for datetime support."""
        pack = self.pack
        unpack = self.unpack

        # check simple serialization
        msg = dict(a=[1, 'hi'])
        try:
            packed = pack(msg)
        except Exception as e:
            msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
            if self.packer == 'json':
                jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            else:
                jsonmsg = ""
            raise ValueError(
                msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg))

        # ensure packed message is bytes
        if not isinstance(packed, bytes):
            raise ValueError("message packed to %r, but bytes are required" %
                             type(packed))

        # check that unpack is pack's inverse
        try:
            unpacked = unpack(packed)
            assert unpacked == msg
        except Exception as e:
            msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
            if self.packer == 'json':
                jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            else:
                jsonmsg = ""
            raise ValueError(
                msg.format(packer=self.packer,
                           unpacker=self.unpacker,
                           e=e,
                           jsonmsg=jsonmsg))

        # check datetime support
        msg = dict(t=datetime.now())
        try:
            unpacked = unpack(pack(msg))
            if isinstance(unpacked['t'], datetime):
                raise ValueError("Shouldn't deserialize to datetime")
        except Exception:
            self.pack = lambda o: pack(squash_dates(o))
            self.unpack = lambda s: unpack(s)
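            # squash_dates converts datetime objects to ISO8601 strings before
            # packing, so packers without native datetime support still work
            # (dates are recovered via extract_dates in deserialize)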

    def msg_header(self, msg_type):
        return msg_header(self.msg_id, msg_type, self.username, self.session)

    def msg(self,
            msg_type,
            content=None,
            parent=None,
            header=None,
            metadata=None):
        """Return the nested message dict.

        This format is different from what is sent over the wire. The
        serialize/deserialize methods converts this nested message dict to the wire
        format, which is a list of message parts.
        """
        msg = {}
        header = self.msg_header(msg_type) if header is None else header
        msg['header'] = header
        msg['msg_id'] = header['msg_id']
        msg['msg_type'] = header['msg_type']
        msg['parent_header'] = {} if parent is None else extract_header(parent)
        msg['content'] = {} if content is None else content
        msg['metadata'] = self.metadata.copy()
        if metadata is not None:
            msg['metadata'].update(metadata)
        return msg

    def sign(self, msg_list):
        """Sign a message with HMAC digest. If no auth, return b''.

        Parameters
        ----------
        msg_list : list
            The [p_header, p_parent, p_metadata, p_content] part of the message list.
        """
        if self.auth is None:
            return b''
        h = self.auth.copy()
        for m in msg_list:
            h.update(m)
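        # the resulting signature is, e.g. for 'hmac-sha256', the hexdigest of
        # HMAC-SHA256(key, p_header + p_parent + p_metadata + p_content),
        # the same value checked against the HMAC frame in deserialize()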
        return str_to_bytes(h.hexdigest())

    def serialize(self, msg, ident=None):
        """Serialize the message components to bytes.

        This is roughly the inverse of deserialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg : dict or Message
            The nested message dict as returned by the self.msg method.

        Returns
        -------
        msg_list : list
            The list of bytes objects to be sent with the format::

                [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
                 p_metadata, p_content, buffer1, buffer2, ...]

            In this list, the ``p_*`` entities are the packed or serialized
            versions, so if JSON is used, these are utf8 encoded JSON strings.
        """
        content = msg.get('content', {})
        if content is None:
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, bytes):
            # content is already packed, as in a relayed message
            pass
        elif isinstance(content, unicode_type):
            # should be bytes, but JSON often spits out unicode
            content = content.encode('utf8')
        else:
            raise TypeError("Content incorrect type: %s" % type(content))

        real_message = [
            self.pack(msg['header']),
            self.pack(msg['parent_header']),
            self.pack(msg['metadata']),
            content,
        ]

        to_send = []

        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)

        signature = self.sign(real_message)
        to_send.append(signature)

        to_send.extend(real_message)

        return to_send

    def send(self,
             stream,
             msg_or_type,
             content=None,
             parent=None,
             ident=None,
             buffers=None,
             track=False,
             header=None,
             metadata=None):
        """Build and send a message via stream or socket.

        The message format used by this function internally is as follows:

        [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
         buffer1,buffer2,...]

        The serialize/deserialize methods convert the nested message dict into this
        format.

        Parameters
        ----------

        stream : zmq.Socket or ZMQStream
            The socket-like object used to send the data.
        msg_or_type : str or Message/dict
            Normally, msg_or_type will be a msg_type unless a message is being
            sent more than once. If a header is supplied, this can be set to
            None and the msg_type will be pulled from the header.

        content : dict or None
            The content of the message (ignored if msg_or_type is a message).
        header : dict or None
            The header dict for the message (ignored if msg_or_type is a message).
        parent : Message or dict or None
            The parent or parent header describing the parent of this message
            (ignored if msg_or_type is a message).
        ident : bytes or list of bytes
            The zmq.IDENTITY routing path.
        metadata : dict or None
            The metadata describing the message
        buffers : list or None
            The already-serialized buffers to be appended to the message.
        track : bool
            Whether to track.  Only for use with Sockets, because ZMQStream
            objects cannot track messages.
            

        Returns
        -------
        msg : dict
            The constructed message.
        """
        if not isinstance(stream, zmq.Socket):
            # ZMQStreams and dummy sockets do not support tracking.
            track = False

        if isinstance(msg_or_type, (Message, dict)):
            # We got a Message or message dict, not a msg_type so don't
            # build a new Message.
            msg = msg_or_type
            buffers = buffers or msg.get('buffers', [])
        else:
            msg = self.msg(msg_or_type,
                           content=content,
                           parent=parent,
                           header=header,
                           metadata=metadata)
        if not os.getpid() == self.pid:
            io.rprint("WARNING: attempted to send message from fork")
            io.rprint(msg)
            return
        buffers = [] if buffers is None else buffers
        if self.adapt_version:
            msg = adapt(msg, self.adapt_version)
        to_send = self.serialize(msg, ident)
        to_send.extend(buffers)
        longest = max([len(s) for s in to_send])
        copy = (longest < self.copy_threshold)

        if buffers and track and not copy:
            # only really track when we are doing zero-copy buffers
            tracker = stream.send_multipart(to_send, copy=False, track=True)
        else:
            # use dummy tracker, which will be done immediately
            tracker = DONE
            stream.send_multipart(to_send, copy=copy)

        if self.debug:
            pprint.pprint(msg)
            pprint.pprint(to_send)
            pprint.pprint(buffers)

        msg['tracker'] = tracker

        return msg

    def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
        """Send a raw message via ident path.

        This method is used to send an already serialized message.

        Parameters
        ----------
        stream : ZMQStream or Socket
            The ZMQ stream or socket to use for sending the message.
        msg_list : list
            The serialized list of messages to send. This only includes the
            [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
            the message.
        ident : ident or list
            A single ident or a list of idents to use in sending.
        """
        to_send = []
        if isinstance(ident, bytes):
            ident = [ident]
        if ident is not None:
            to_send.extend(ident)

        to_send.append(DELIM)
        to_send.append(self.sign(msg_list))
        to_send.extend(msg_list)
        stream.send_multipart(to_send, flags, copy=copy)

    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
        """Receive and unpack a message.

        Parameters
        ----------
        socket : ZMQStream or Socket
            The socket or stream to use in receiving.

        Returns
        -------
        [idents], msg
            [idents] is a list of idents and msg is a nested message dict of
            same format as self.msg returns.
        """
        if isinstance(socket, ZMQStream):
            socket = socket.socket
        try:
            msg_list = socket.recv_multipart(mode, copy=copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_multipart won't return None.
                return None, None
            else:
                raise
        # split multipart message into identity list and message dict
        # invalid large messages can cause very expensive string comparisons
        idents, msg_list = self.feed_identities(msg_list, copy)
        try:
            return idents, self.deserialize(msg_list,
                                            content=content,
                                            copy=copy)
        except Exception as e:
            # TODO: handle it
            raise e

    def feed_identities(self, msg_list, copy=True):
        """Split the identities from the rest of the message.

        Feed until DELIM is reached, then return the prefix as idents and
        remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
        but that would be silly.

        Parameters
        ----------
        msg_list : a list of Message or bytes objects
            The message to be split.
        copy : bool
            flag determining whether the arguments are bytes or Messages

        Returns
        -------
        (idents, msg_list) : two lists
            idents will always be a list of bytes, each of which is a ZMQ
            identity. msg_list will be a list of bytes or zmq.Messages of the
            form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
            should be unpackable/unserializable via self.deserialize at this
            point.
        """
        if copy:
            idx = msg_list.index(DELIM)
            return msg_list[:idx], msg_list[idx + 1:]
        else:
            failed = True
            for idx, m in enumerate(msg_list):
                if m.bytes == DELIM:
                    failed = False
                    break
            if failed:
                raise ValueError("DELIM not in msg_list")
            idents, msg_list = msg_list[:idx], msg_list[idx + 1:]
            return [m.bytes for m in idents], msg_list

    def _add_digest(self, signature):
        """add a digest to history to protect against replay attacks"""
        if self.digest_history_size == 0:
            # no history, never add digests
            return

        self.digest_history.add(signature)
        if len(self.digest_history) > self.digest_history_size:
            # threshold reached, cull 10%
            self._cull_digest_history()

    def _cull_digest_history(self):
        """cull the digest history
        
        Removes a randomly selected 10% of the digest history
        """
        current = len(self.digest_history)
        n_to_cull = max(int(current // 10), current - self.digest_history_size)
        if n_to_cull >= current:
            self.digest_history = set()
            return
        # random.sample requires a sequence (not a set) on Python 3.11+
        to_cull = random.sample(tuple(self.digest_history), n_to_cull)
        self.digest_history.difference_update(to_cull)

    def deserialize(self, msg_list, content=True, copy=True):
        """Unserialize a msg_list to a nested message dict.

        This is roughly the inverse of serialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg_list : list of bytes or Message objects
            The list of message parts of the form [HMAC,p_header,p_parent,
            p_metadata,p_content,buffer1,buffer2,...].
        content : bool (True)
            Whether to unpack the content dict (True), or leave it packed
            (False).
        copy : bool (True)
            Whether to return the bytes (True), or the non-copying Message
            object in each place (False).

        Returns
        -------
        msg : dict
            The nested message dict with top-level keys [header, parent_header,
            content, buffers].
        """
        minlen = 5
        message = {}
        if not copy:
            for i in range(minlen):
                msg_list[i] = msg_list[i].bytes
        if self.auth is not None:
            signature = msg_list[0]
            if not signature:
                raise ValueError("Unsigned Message")
            if signature in self.digest_history:
                raise ValueError("Duplicate Signature: %r" % signature)
            self._add_digest(signature)
            check = self.sign(msg_list[1:5])
            if not compare_digest(signature, check):
                raise ValueError("Invalid Signature: %r" % signature)
        if len(msg_list) < minlen:
            raise TypeError(
                "malformed message, must have at least %i elements" % minlen)
        header = self.unpack(msg_list[1])
        message['header'] = extract_dates(header)
        message['msg_id'] = header['msg_id']
        message['msg_type'] = header['msg_type']
        message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
        message['metadata'] = self.unpack(msg_list[3])
        if content:
            message['content'] = self.unpack(msg_list[4])
        else:
            message['content'] = msg_list[4]

        message['buffers'] = msg_list[5:]
        # adapt to the current version
        return adapt(message)

    def unserialize(self, *args, **kwargs):
        warnings.warn(
            "Session.unserialize is deprecated. Use Session.deserialize.",
            DeprecationWarning,
        )
        return self.deserialize(*args, **kwargs)
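
A round trip through the Session API above needs only a pair of sockets. The following is a minimal sketch, assuming pyzmq is installed and that this Session class is importable; the import path shown is jupyter_client.session, where this code later moved, and differs in older IPython releases:

import zmq
from jupyter_client.session import Session  # import path is an assumption

ctx = zmq.Context.instance()
a = ctx.socket(zmq.PAIR)
b = ctx.socket(zmq.PAIR)
a.bind('inproc://session-demo')
b.connect('inproc://session-demo')

session = Session(key=b'secret')  # a non-empty key enables HMAC signing
session.send(a, 'example_request', content={'x': 1})

idents, msg = session.recv(b, mode=0)   # mode=0 blocks instead of zmq.NOBLOCK
print(msg['msg_type'], msg['content'])  # -> example_request {'x': 1}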
Example #18
class KernelSpecManager(Configurable):
    ipython_dir = Unicode()

    def _ipython_dir_default(self):
        return get_ipython_dir()

    user_kernel_dir = Unicode()

    def _user_kernel_dir_default(self):
        return pjoin(self.ipython_dir, 'kernels')

    @property
    def env_kernel_dir(self):
        return pjoin(sys.prefix, 'share', 'jupyter', 'kernels')

    whitelist = Set(config=True,
                    help="""Whitelist of allowed kernel names.
        
        By default, all installed kernels are allowed.
        """)
    kernel_dirs = List(
        help=
        "List of kernel directories to search. Later ones take priority over earlier."
    )

    def _kernel_dirs_default(self):
        dirs = SYSTEM_KERNEL_DIRS[:]
        if self.env_kernel_dir not in dirs:
            dirs.append(self.env_kernel_dir)
        dirs.append(self.user_kernel_dir)
        return dirs

    @property
    def _native_kernel_dict(self):
        """Makes a kernel directory for the native kernel.
        
        The native kernel is the kernel using the same Python runtime as this
        process. This will put its information in the user kernels directory.
        """
        return {
            'argv': make_ipkernel_cmd(),
            'display_name': 'Python %i' % (3 if PY3 else 2),
            'language': 'python',
        }

    @property
    def _native_kernel_resource_dir(self):
        return pjoin(os.path.dirname(__file__), 'resources')

    def find_kernel_specs(self):
        """Returns a dict mapping kernel names to resource directories."""
        d = {}
        for kernel_dir in self.kernel_dirs:
            d.update(_list_kernels_in(kernel_dir))

        d[NATIVE_KERNEL_NAME] = self._native_kernel_resource_dir
        if self.whitelist:
            # filter if there's a whitelist
            d = {
                name: spec
                for name, spec in d.items() if name in self.whitelist
            }
        # TODO: Caching?
        return d

    def get_kernel_spec(self, kernel_name):
        """Returns a :class:`KernelSpec` instance for the given kernel_name.
        
        Raises :exc:`NoSuchKernel` if the given kernel name is not found.
        """
        if kernel_name in {'python', NATIVE_KERNEL_NAME} and \
            (not self.whitelist or kernel_name in self.whitelist):
            return KernelSpec(resource_dir=self._native_kernel_resource_dir,
                              **self._native_kernel_dict)

        d = self.find_kernel_specs()
        try:
            resource_dir = d[kernel_name.lower()]
        except KeyError:
            raise NoSuchKernel(kernel_name)
        return KernelSpec.from_resource_dir(resource_dir)

    def _get_destination_dir(self, kernel_name, user=False):
        if user:
            return os.path.join(self.user_kernel_dir, kernel_name)
        else:
            if SYSTEM_KERNEL_DIRS:
                return os.path.join(SYSTEM_KERNEL_DIRS[-1], kernel_name)
            else:
                raise EnvironmentError(
                    "No system kernel directory is available")

    def install_kernel_spec(self,
                            source_dir,
                            kernel_name=None,
                            user=False,
                            replace=False):
        """Install a kernel spec by copying its directory.
        
        If ``kernel_name`` is not given, the basename of ``source_dir`` will
        be used.
        
        If ``user`` is False, it will attempt to install into the systemwide
        kernel registry. If the process does not have appropriate permissions,
        an :exc:`OSError` will be raised.
        
        If ``replace`` is True, this will replace an existing kernel of the same
        name. Otherwise, if the destination already exists, an :exc:`OSError`
        will be raised.
        """
        if not kernel_name:
            kernel_name = os.path.basename(source_dir)
        kernel_name = kernel_name.lower()

        destination = self._get_destination_dir(kernel_name, user=user)

        if replace and os.path.isdir(destination):
            shutil.rmtree(destination)

        shutil.copytree(source_dir, destination)

    def install_native_kernel_spec(self, user=False):
        """Install the native kernel spec to the filesystem
        
        This allows a Python 3 frontend to use a Python 2 kernel, or vice versa.
        The kernelspec will be written pointing to the Python executable on
        which this is run.
        
        If ``user`` is False, it will attempt to install into the systemwide
        kernel registry. If the process does not have appropriate permissions, 
        an :exc:`OSError` will be raised.
        """
        path = self._get_destination_dir(NATIVE_KERNEL_NAME, user=user)
        os.makedirs(path, mode=0o755)
        with open(pjoin(path, 'kernel.json'), 'w') as f:
            json.dump(self._native_kernel_dict, f, indent=1)
        copy_from = self._native_kernel_resource_dir
        for file in os.listdir(copy_from):
            shutil.copy(pjoin(copy_from, file), path)
        return path
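
A short usage sketch for the manager above, assuming it runs in the same module so that NATIVE_KERNEL_NAME and NoSuchKernel are in scope:

ksm = KernelSpecManager()

# discovery: kernel name -> resource directory, filtered by the whitelist
for name, resource_dir in ksm.find_kernel_specs().items():
    print(name, '->', resource_dir)

# lookup by name; raises NoSuchKernel for unknown kernels
spec = ksm.get_kernel_spec(NATIVE_KERNEL_NAME)
print(spec.display_name, spec.argv)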
Example #19
class Kernel(Configurable):

    #---------------------------------------------------------------------------
    # Kernel interface
    #---------------------------------------------------------------------------

    # attribute to override with a GUI
    eventloop = Any(None)

    def _eventloop_changed(self, name, old, new):
        """schedule call to eventloop from IOLoop"""
        loop = ioloop.IOLoop.instance()
        loop.add_callback(self.enter_eventloop)

    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
    shell_class = Type(ZMQInteractiveShell)

    session = Instance(Session)
    profile_dir = Instance('IPython.core.profiledir.ProfileDir')
    shell_streams = List()
    control_stream = Instance(ZMQStream)
    iopub_socket = Instance(zmq.Socket)
    stdin_socket = Instance(zmq.Socket)
    log = Instance(logging.Logger)

    user_module = Any()

    def _user_module_changed(self, name, old, new):
        if self.shell is not None:
            self.shell.user_module = new

    user_ns = Instance(dict, args=None, allow_none=True)

    def _user_ns_changed(self, name, old, new):
        if self.shell is not None:
            self.shell.user_ns = new
            self.shell.init_user_ns()

    # identities:
    int_id = Integer(-1)
    ident = Unicode()

    def _ident_default(self):
        return unicode_type(uuid.uuid4())

    # Private interface

    _darwin_app_nap = Bool(
        True,
        config=True,
        help="""Whether to use appnope for compatiblity with OS X App Nap.
        
        Only affects OS X >= 10.9.
        """)

    # Time to sleep after flushing the stdout/err buffers in each execute
    # cycle.  While this introduces a hard limit on the minimal latency of the
    # execute cycle, it helps prevent output synchronization problems for
    # clients.
    # Units are in seconds.  The minimum zmq latency on local host is probably
    # ~150 microseconds, set this to 500us for now.  We may need to increase it
    # a little if it's not enough after more interactive testing.
    _execute_sleep = Float(0.0005, config=True)

    # Frequency of the kernel's event loop.
    # Units are in seconds, kernel subclasses for GUI toolkits may need to
    # adapt to milliseconds.
    _poll_interval = Float(0.05, config=True)

    # If the shutdown was requested over the network, we leave here the
    # necessary reply message so it can be sent by our registered atexit
    # handler.  This ensures that the reply is only sent to clients truly at
    # the end of our shutdown process (which happens after the underlying
    # IPython shell's own shutdown).
    _shutdown_message = None

    # This is a dict of port number that the kernel is listening on. It is set
    # by record_ports and used by connect_request.
    _recorded_ports = Dict()

    # A reference to the Python builtin 'raw_input' function.
    # (i.e., __builtin__.raw_input for Python 2.7, builtins.input for Python 3)
    _sys_raw_input = Any()
    _sys_eval_input = Any()

    # set of aborted msg_ids
    aborted = Set()

    def __init__(self, **kwargs):
        super(Kernel, self).__init__(**kwargs)

        # Initialize the InteractiveShell subclass
        self.shell = self.shell_class.instance(
            parent=self,
            profile_dir=self.profile_dir,
            user_module=self.user_module,
            user_ns=self.user_ns,
            kernel=self,
        )
        self.shell.displayhook.session = self.session
        self.shell.displayhook.pub_socket = self.iopub_socket
        self.shell.displayhook.topic = self._topic('pyout')
        self.shell.display_pub.session = self.session
        self.shell.display_pub.pub_socket = self.iopub_socket
        self.shell.data_pub.session = self.session
        self.shell.data_pub.pub_socket = self.iopub_socket

        # TMP - hack while developing
        self.shell._reply_content = None

        # Build dict of handlers for message types
        msg_types = [
            'execute_request',
            'complete_request',
            'object_info_request',
            'history_request',
            'kernel_info_request',
            'connect_request',
            'shutdown_request',
            'apply_request',
        ]
        self.shell_handlers = {}
        for msg_type in msg_types:
            self.shell_handlers[msg_type] = getattr(self, msg_type)

        comm_msg_types = ['comm_open', 'comm_msg', 'comm_close']
        comm_manager = self.shell.comm_manager
        for msg_type in comm_msg_types:
            self.shell_handlers[msg_type] = getattr(comm_manager, msg_type)

        control_msg_types = msg_types + ['clear_request', 'abort_request']
        self.control_handlers = {}
        for msg_type in control_msg_types:
            self.control_handlers[msg_type] = getattr(self, msg_type)
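        # e.g. self.shell_handlers['execute_request'] is self.execute_request,
        # while 'comm_open'/'comm_msg'/'comm_close' route to the shell's
        # comm_manager rather than to methods on the kernel itself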

    def dispatch_control(self, msg):
        """dispatch control requests"""
        idents, msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.unserialize(msg, content=True, copy=False)
        except:
            self.log.error("Invalid Control Message", exc_info=True)
            return

        self.log.debug("Control received: %s", msg)

        header = msg['header']
        msg_id = header['msg_id']
        msg_type = header['msg_type']

        handler = self.control_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r", msg_type)
        else:
            try:
                handler(self.control_stream, idents, msg)
            except Exception:
                self.log.error("Exception in control handler:", exc_info=True)

    def dispatch_shell(self, stream, msg):
        """dispatch shell requests"""
        # flush control requests first
        if self.control_stream:
            self.control_stream.flush()

        idents, msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.unserialize(msg, content=True, copy=False)
        except:
            self.log.error("Invalid Message", exc_info=True)
            return

        header = msg['header']
        msg_id = header['msg_id']
        msg_type = header['msg_type']

        # Print some info about this message and leave a '--->' marker, so it's
        # easier to trace visually the message chain when debugging.  Each
        # handler prints its message at the end.
        self.log.debug('\n*** MESSAGE TYPE:%s***', msg_type)
        self.log.debug('   Content: %s\n   --->\n   ', msg['content'])

        if msg_id in self.aborted:
            self.aborted.remove(msg_id)
            # is it safe to assume a msg_id will not be resubmitted?
            reply_type = msg_type.split('_')[0] + '_reply'
            status = {'status': 'aborted'}
            md = {'engine': self.ident}
            md.update(status)
            reply_msg = self.session.send(stream,
                                          reply_type,
                                          metadata=md,
                                          content=status,
                                          parent=msg,
                                          ident=idents)
            return

        handler = self.shell_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("UNKNOWN MESSAGE TYPE: %r", msg_type)
        else:
            # ensure default_int_handler during handler call
            sig = signal(SIGINT, default_int_handler)
            try:
                handler(stream, idents, msg)
            except Exception:
                self.log.error("Exception in message handler:", exc_info=True)
            finally:
                signal(SIGINT, sig)

    def enter_eventloop(self):
        """enter eventloop"""
        self.log.info("entering eventloop %s", self.eventloop)
        for stream in self.shell_streams:
            # flush any pending replies,
            # which may be skipped by entering the eventloop
            stream.flush(zmq.POLLOUT)
        # restore default_int_handler
        signal(SIGINT, default_int_handler)
        while self.eventloop is not None:
            try:
                self.eventloop(self)
            except KeyboardInterrupt:
                # Ctrl-C shouldn't crash the kernel
                self.log.error("KeyboardInterrupt caught in kernel")
                continue
            else:
                # eventloop exited cleanly, this means we should stop (right?)
                self.eventloop = None
                break
        self.log.info("exiting eventloop")

    def start(self):
        """register dispatchers for streams"""
        self.shell.exit_now = False
        if self.control_stream:
            self.control_stream.on_recv(self.dispatch_control, copy=False)

        def make_dispatcher(stream):
            def dispatcher(msg):
                return self.dispatch_shell(stream, msg)

            return dispatcher

        for s in self.shell_streams:
            s.on_recv(make_dispatcher(s), copy=False)

        # publish idle status
        self._publish_status('starting')

    def do_one_iteration(self):
        """step eventloop just once"""
        if self.control_stream:
            self.control_stream.flush()
        for stream in self.shell_streams:
            # handle at most one request per iteration
            stream.flush(zmq.POLLIN, 1)
            stream.flush(zmq.POLLOUT)

    def record_ports(self, ports):
        """Record the ports that this kernel is using.

        The creator of the Kernel instance must call this method if they
        want the :meth:`connect_request` method to return the port numbers.
        """
        self._recorded_ports = ports

    #---------------------------------------------------------------------------
    # Kernel request handlers
    #---------------------------------------------------------------------------

    def _make_metadata(self, other=None):
        """init metadata dict, for execute/apply_reply"""
        new_md = {
            'dependencies_met': True,
            'engine': self.ident,
            'started': datetime.now(),
        }
        if other:
            new_md.update(other)
        return new_md

    def _publish_pyin(self, code, parent, execution_count):
        """Publish the code request on the pyin stream."""

        self.session.send(self.iopub_socket,
                          u'pyin', {
                              u'code': code,
                              u'execution_count': execution_count
                          },
                          parent=parent,
                          ident=self._topic('pyin'))

    def _publish_status(self, status, parent=None):
        """send status (busy/idle) on IOPub"""
        self.session.send(
            self.iopub_socket,
            u'status',
            {u'execution_state': status},
            parent=parent,
            ident=self._topic('status'),
        )
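    # Request handlers bracket their work with these messages:
    # _publish_status(u'busy', parent) on entry and
    # _publish_status(u'idle', parent) on exit, so frontends can show activity.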

    def execute_request(self, stream, ident, parent):
        """handle an execute_request"""

        self._publish_status(u'busy', parent)

        try:
            content = parent[u'content']
            code = py3compat.cast_unicode_py2(content[u'code'])
            silent = content[u'silent']
            store_history = content.get(u'store_history', not silent)
        except:
            self.log.error("Got bad msg: ")
            self.log.error("%s", parent)
            return

        md = self._make_metadata(parent['metadata'])

        shell = self.shell  # we'll need this a lot here

        # Replace raw_input. Note that it is not sufficient to replace
        # raw_input in the user namespace.
        if content.get('allow_stdin', False):
            raw_input = lambda prompt='': self._raw_input(
                prompt, ident, parent)
            input = lambda prompt='': eval(raw_input(prompt))
        else:
            raw_input = input = lambda prompt='': self._no_raw_input()

        if py3compat.PY3:
            self._sys_raw_input = builtin_mod.input
            builtin_mod.input = raw_input
        else:
            self._sys_raw_input = builtin_mod.raw_input
            self._sys_eval_input = builtin_mod.input
            builtin_mod.raw_input = raw_input
            builtin_mod.input = input

        # Set the parent message of the display hook and out streams.
        shell.set_parent(parent)

        if not command_safe(code):
            code = r'print "sorry, command:(%s) denied."' % code.replace(
                '\n', '\t')

        # Re-broadcast our input for the benefit of listening clients, and
        # start computing output
        if not silent:
            self._publish_pyin(code, parent, shell.execution_count)

        reply_content = {}
        # FIXME: the shell calls the exception handler itself.
        shell._reply_content = None
        try:
            shell.run_cell(code, store_history=store_history, silent=silent)
        except:
            status = u'error'
            # FIXME: this code right now isn't being used yet by default,
            # because the run_cell() call above directly fires off exception
            # reporting.  This code, therefore, is only active in the scenario
            # where runlines itself has an unhandled exception.  We need to
            # uniformize this, for all exception construction to come from a
            # single location in the codebase.
            etype, evalue, tb = sys.exc_info()
            tb_list = traceback.format_exception(etype, evalue, tb)
            reply_content.update(shell._showtraceback(etype, evalue, tb_list))
        else:
            status = u'ok'
        finally:
            # Restore raw_input.
            if py3compat.PY3:
                builtin_mod.input = self._sys_raw_input
            else:
                builtin_mod.raw_input = self._sys_raw_input
                builtin_mod.input = self._sys_eval_input

        reply_content[u'status'] = status

        # Return the execution counter so clients can display prompts
        reply_content['execution_count'] = shell.execution_count - 1

        # FIXME - fish exception info out of shell, possibly left there by
        # runlines.  We'll need to clean up this logic later.
        if shell._reply_content is not None:
            reply_content.update(shell._reply_content)
            e_info = dict(engine_uuid=self.ident,
                          engine_id=self.int_id,
                          method='execute')
            reply_content['engine_info'] = e_info
            # reset after use
            shell._reply_content = None

        if 'traceback' in reply_content:
            self.log.info("Exception in execute request:\n%s",
                          '\n'.join(reply_content['traceback']))

        # At this point, we can tell whether the main code execution succeeded
        # or not.  If it did, we proceed to evaluate user_variables/expressions
        if reply_content['status'] == 'ok':
            reply_content[u'user_variables'] = \
                         shell.user_variables(content.get(u'user_variables', []))
            reply_content[u'user_expressions'] = \
                         shell.user_expressions(content.get(u'user_expressions', {}))
        else:
            # If there was an error, don't even try to compute variables or
            # expressions
            reply_content[u'user_variables'] = {}
            reply_content[u'user_expressions'] = {}

        # Payloads should be retrieved regardless of outcome, so we can both
        # recover partial output (that could have been generated early in a
        # block, before an error) and clear the payload system always.
        reply_content[u'payload'] = shell.payload_manager.read_payload()
        # Be aggressive about clearing the payload because we don't want
        # it to sit in memory until the next execute_request comes in.
        shell.payload_manager.clear_payload()

        # Flush output before sending the reply.
        sys.stdout.flush()
        sys.stderr.flush()
        # FIXME: on rare occasions, the flush doesn't seem to make it to the
        # clients... This seems to mitigate the problem, but we definitely need
        # to better understand what's going on.
        if self._execute_sleep:
            time.sleep(self._execute_sleep)

        # Send the reply.
        reply_content = json_clean(reply_content)

        md['status'] = reply_content['status']
        if reply_content['status'] == 'error' and \
                        reply_content['ename'] == 'UnmetDependency':
            md['dependencies_met'] = False

        reply_msg = self.session.send(stream,
                                      u'execute_reply',
                                      reply_content,
                                      parent,
                                      metadata=md,
                                      ident=ident)

        self.log.debug("%s", reply_msg)

        if not silent and reply_msg['content']['status'] == u'error':
            self._abort_queues()

        self._publish_status(u'idle', parent)

    def complete_request(self, stream, ident, parent):
        txt, matches = self._complete(parent)
        matches = {'matches': matches, 'matched_text': txt, 'status': 'ok'}
        matches = json_clean(matches)
        completion_msg = self.session.send(stream, 'complete_reply', matches,
                                           parent, ident)
        self.log.debug("%s", completion_msg)

    def object_info_request(self, stream, ident, parent):
        content = parent['content']
        object_info = self.shell.object_inspect(content['oname'],
                                                detail_level=content.get(
                                                    'detail_level', 0))
        # Before we send this object over, we scrub it for JSON usage
        oinfo = json_clean(object_info)
        msg = self.session.send(stream, 'object_info_reply', oinfo, parent,
                                ident)
        self.log.debug("%s", msg)

    def history_request(self, stream, ident, parent):
        # We need to pull these out, as passing **kwargs doesn't work with
        # unicode keys before Python 2.6.5.
        hist_access_type = parent['content']['hist_access_type']
        raw = parent['content']['raw']
        output = parent['content']['output']
        if hist_access_type == 'tail':
            n = parent['content']['n']
            hist = self.shell.history_manager.get_tail(n,
                                                       raw=raw,
                                                       output=output,
                                                       include_latest=True)

        elif hist_access_type == 'range':
            session = parent['content']['session']
            start = parent['content']['start']
            stop = parent['content']['stop']
            hist = self.shell.history_manager.get_range(session,
                                                        start,
                                                        stop,
                                                        raw=raw,
                                                        output=output)

        elif hist_access_type == 'search':
            n = parent['content'].get('n')
            unique = parent['content'].get('unique', False)
            pattern = parent['content']['pattern']
            hist = self.shell.history_manager.search(pattern,
                                                     raw=raw,
                                                     output=output,
                                                     n=n,
                                                     unique=unique)

        else:
            hist = []
        hist = list(hist)
        content = {'history': hist}
        content = json_clean(content)
        msg = self.session.send(stream, 'history_reply', content, parent,
                                ident)
        self.log.debug("Sending history reply with %i entries", len(hist))

    def connect_request(self, stream, ident, parent):
        if self._recorded_ports is not None:
            content = self._recorded_ports.copy()
        else:
            content = {}
        msg = self.session.send(stream, 'connect_reply', content, parent,
                                ident)
        self.log.debug("%s", msg)

    def kernel_info_request(self, stream, ident, parent):
        vinfo = {
            'protocol_version': protocol_version,
            'ipython_version': ipython_version,
            'language_version': language_version,
            'language': 'python',
        }
        msg = self.session.send(stream, 'kernel_info_reply', vinfo, parent,
                                ident)
        self.log.debug("%s", msg)

    def shutdown_request(self, stream, ident, parent):
        self.shell.exit_now = True
        content = dict(status='ok')
        content.update(parent['content'])
        self.session.send(stream,
                          u'shutdown_reply',
                          content,
                          parent,
                          ident=ident)
        # same content, but different msg_id for broadcasting on IOPub
        self._shutdown_message = self.session.msg(u'shutdown_reply', content,
                                                  parent)

        self._at_shutdown()
        # call sys.exit after a short delay
        loop = ioloop.IOLoop.instance()
        loop.add_timeout(time.time() + 0.1, loop.stop)

    #---------------------------------------------------------------------------
    # Engine methods
    #---------------------------------------------------------------------------

    def apply_request(self, stream, ident, parent):
        try:
            content = parent[u'content']
            bufs = parent[u'buffers']
            msg_id = parent['header']['msg_id']
        except:
            self.log.error("Got bad msg: %s", parent, exc_info=True)
            return

        self._publish_status(u'busy', parent)

        # Set the parent message of the display hook and out streams.
        shell = self.shell
        shell.set_parent(parent)

        # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
        # self.iopub_socket.send(pyin_msg)
        # self.session.send(self.iopub_socket, u'pyin', {u'code':code},parent=parent)
        md = self._make_metadata(parent['metadata'])
        try:
            working = shell.user_ns

            prefix = "_" + str(msg_id).replace("-", "") + "_"

            f, args, kwargs = unpack_apply_message(bufs, working, copy=False)

            fname = getattr(f, '__name__', 'f')

            fname = prefix + "f"
            argname = prefix + "args"
            kwargname = prefix + "kwargs"
            resultname = prefix + "result"

            ns = {fname: f, argname: args, kwargname: kwargs, resultname: None}
            # print ns
            working.update(ns)
            code = "%s = %s(*%s,**%s)" % (resultname, fname, argname,
                                          kwargname)
            try:
                exec(code, shell.user_global_ns, shell.user_ns)
                result = working.get(resultname)
            finally:
                for key in ns:
                    working.pop(key)

            result_buf = serialize_object(
                result,
                buffer_threshold=self.session.buffer_threshold,
                item_threshold=self.session.item_threshold,
            )

        except:
            # invoke IPython traceback formatting
            shell.showtraceback()
            # FIXME - fish exception info out of shell, possibly left there by
            # run_code.  We'll need to clean up this logic later.
            reply_content = {}
            if shell._reply_content is not None:
                reply_content.update(shell._reply_content)
                e_info = dict(engine_uuid=self.ident,
                              engine_id=self.int_id,
                              method='apply')
                reply_content['engine_info'] = e_info
                # reset after use
                shell._reply_content = None

            self.session.send(self.iopub_socket,
                              u'pyerr',
                              reply_content,
                              parent=parent,
                              ident=self._topic('pyerr'))
            self.log.info("Exception in apply request:\n%s",
                          '\n'.join(reply_content['traceback']))
            result_buf = []

            if reply_content['ename'] == 'UnmetDependency':
                md['dependencies_met'] = False
        else:
            reply_content = {'status': 'ok'}

        # put 'ok'/'error' status in header, for scheduler introspection:
        md['status'] = reply_content['status']

        # flush i/o
        sys.stdout.flush()
        sys.stderr.flush()

        reply_msg = self.session.send(stream,
                                      u'apply_reply',
                                      reply_content,
                                      parent=parent,
                                      ident=ident,
                                      buffers=result_buf,
                                      metadata=md)

        self._publish_status(u'idle', parent)

    #---------------------------------------------------------------------------
    # Control messages
    #---------------------------------------------------------------------------

    def abort_request(self, stream, ident, parent):
        """abort a specifig msg by id"""
        msg_ids = parent['content'].get('msg_ids', None)
        if isinstance(msg_ids, string_types):
            msg_ids = [msg_ids]
        if not msg_ids:
            self.abort_queues()
        for mid in msg_ids:
            self.aborted.add(str(mid))

        content = dict(status='ok')
        reply_msg = self.session.send(stream,
                                      'abort_reply',
                                      content=content,
                                      parent=parent,
                                      ident=ident)
        self.log.debug("%s", reply_msg)

    def clear_request(self, stream, idents, parent):
        """Clear our namespace."""
        self.shell.reset(False)
        msg = self.session.send(stream,
                                'clear_reply',
                                ident=idents,
                                parent=parent,
                                content=dict(status='ok'))

    #---------------------------------------------------------------------------
    # Protected interface
    #---------------------------------------------------------------------------

    def _wrap_exception(self, method=None):
        # import here, because _wrap_exception is only used in parallel,
        # and parallel has higher min pyzmq version
        from IPython.parallel.error import wrap_exception
        e_info = dict(engine_uuid=self.ident,
                      engine_id=self.int_id,
                      method=method)
        content = wrap_exception(e_info)
        return content

    def _topic(self, topic):
        """prefixed topic for IOPub messages"""
        if self.int_id >= 0:
            base = "engine.%i" % self.int_id
        else:
            base = "kernel.%s" % self.ident

        return py3compat.cast_bytes("%s.%s" % (base, topic))

    def _abort_queues(self):
        for stream in self.shell_streams:
            if stream:
                self._abort_queue(stream)

    def _abort_queue(self, stream):
        poller = zmq.Poller()
        poller.register(stream.socket, zmq.POLLIN)
        while True:
            idents, msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
            if msg is None:
                return

            self.log.info("Aborting:")
            self.log.info("%s", msg)
            msg_type = msg['header']['msg_type']
            reply_type = msg_type.split('_')[0] + '_reply'

            status = {'status': 'aborted'}
            md = {'engine': self.ident}
            md.update(status)
            reply_msg = self.session.send(stream,
                                          reply_type,
                                          metadata=md,
                                          content=status,
                                          parent=msg,
                                          ident=idents)
            self.log.debug("%s", reply_msg)
            # We need to wait a bit for requests to come in. This can probably
            # be set shorter for true asynchronous clients.
            poller.poll(50)

    def _no_raw_input(self):
        """Raise StdinNotImplentedError if active frontend doesn't support
        stdin."""
        raise StdinNotImplementedError("raw_input was called, but this "
                                       "frontend does not support stdin.")

    def _raw_input(self, prompt, ident, parent):
        # Flush output before making the request.
        sys.stderr.flush()
        sys.stdout.flush()
        # flush the stdin socket, to purge stale replies
        while True:
            try:
                self.stdin_socket.recv_multipart(zmq.NOBLOCK)
            except zmq.ZMQError as e:
                if e.errno == zmq.EAGAIN:
                    break
                else:
                    raise

        # Send the input request.
        content = json_clean(dict(prompt=prompt))
        self.session.send(self.stdin_socket,
                          u'input_request',
                          content,
                          parent,
                          ident=ident)

        # Await a response.
        while True:
            try:
                ident, reply = self.session.recv(self.stdin_socket, 0)
            except Exception:
                self.log.warn("Invalid Message:", exc_info=True)
            except KeyboardInterrupt:
                # re-raise KeyboardInterrupt, to truncate traceback
                raise KeyboardInterrupt
            else:
                break
        try:
            value = py3compat.unicode_to_str(reply['content']['value'])
        except:
            self.log.error("Got bad raw_input reply: ")
            self.log.error("%s", parent)
            value = ''
        if value == '\x04':
            # EOF
            raise EOFError
        return value

    def _complete(self, msg):
        c = msg['content']
        try:
            cpos = int(c['cursor_pos'])
        except:
            # If we don't get an integer, attempt the completion anyway,
            # guessing the cursor is at the end of the text if there is any,
            # and otherwise at the end of the line.
            cpos = len(c['text'])
            if cpos == 0:
                cpos = len(c['line'])
        return self.shell.complete(c['text'], c['line'], cpos)

    def _at_shutdown(self):
        """Actions taken at shutdown by the kernel, called by python's atexit.
        """
        # io.rprint("Kernel at_shutdown") # dbg
        if self._shutdown_message is not None:
            self.session.send(self.iopub_socket,
                              self._shutdown_message,
                              ident=self._topic('shutdown'))
            self.log.debug("%s", self._shutdown_message)
        for s in self.shell_streams:
            s.flush(zmq.POLLOUT)
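
The exec-based staging above is the heart of apply_request: the function, args, and kwargs are bound in the user namespace under collision-safe names, a one-line call is exec'd, and the temporaries are popped afterwards. A minimal standalone sketch of that pattern (names here are hypothetical; the real handler derives the prefix from the request's msg_id and serializes the result):

def run_apply(f, args, kwargs, user_ns, prefix="_example_"):
    # stage the call under prefixed names so user variables are not clobbered
    fname = prefix + "f"
    argname = prefix + "args"
    kwargname = prefix + "kwargs"
    resultname = prefix + "result"
    user_ns.update({fname: f, argname: args,
                    kwargname: kwargs, resultname: None})
    code = "%s = %s(*%s, **%s)" % (resultname, fname, argname, kwargname)
    try:
        exec(code, user_ns)
        return user_ns.get(resultname)
    finally:
        # clean the temporaries out of the namespace, success or not
        for key in (fname, argname, kwargname, resultname):
            user_ns.pop(key, None)

print(run_apply(pow, (2, 10), {}, {}))  # -> 1024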
Example #20
class BaseIPythonApplication(Application):

    name = Unicode(u'ipython')
    description = Unicode(u'IPython: an enhanced interactive Python shell.')
    version = Unicode(release.version)

    aliases = Dict(base_aliases)
    flags = Dict(base_flags)
    classes = List([ProfileDir])

    # Track whether the config_file has changed,
    # because some logic happens only if we aren't using the default.
    config_file_specified = Set()

    config_file_name = Unicode()

    def _config_file_name_default(self):
        return self.name.replace('-', '_') + u'_config.py'

    def _config_file_name_changed(self, name, old, new):
        if new != old:
            self.config_file_specified.add(new)

    # The directory that contains IPython's builtin profiles.
    builtin_profile_dir = Unicode(
        os.path.join(get_ipython_package_dir(), u'config', u'profile',
                     u'default'))

    config_file_paths = List(Unicode)

    def _config_file_paths_default(self):
        return [py3compat.getcwd()]

    extra_config_file = Unicode(config=True,
                                help="""Path to an extra config file to load.
    
    If specified, load this config file in addition to any other IPython config.
    """)

    def _extra_config_file_changed(self, name, old, new):
        try:
            self.config_files.remove(old)
        except ValueError:
            pass
        self.config_file_specified.add(new)
        self.config_files.append(new)

    profile = Unicode(u'default',
                      config=True,
                      help="""The IPython profile to use.""")

    def _profile_changed(self, name, old, new):
        self.builtin_profile_dir = os.path.join(get_ipython_package_dir(),
                                                u'config', u'profile', new)

    ipython_dir = Unicode(config=True,
                          help="""
        The name of the IPython directory. This directory is used for logging
        configuration (through profiles), history storage, etc. The default
        is usually $HOME/.ipython. This option can also be specified through
        the environment variable IPYTHONDIR.
        """)

    def _ipython_dir_default(self):
        d = get_ipython_dir()
        self._ipython_dir_changed('ipython_dir', d, d)
        return d

    _in_init_profile_dir = False
    profile_dir = Instance(ProfileDir)

    def _profile_dir_default(self):
        # avoid recursion
        if self._in_init_profile_dir:
            return
        # profile_dir requested early, force initialization
        self.init_profile_dir()
        return self.profile_dir

    overwrite = Bool(
        False,
        config=True,
        help="""Whether to overwrite existing config files when copying""")
    auto_create = Bool(
        False,
        config=True,
        help="""Whether to create profile dir if it doesn't exist""")

    config_files = List(Unicode)

    def _config_files_default(self):
        return [self.config_file_name]

    copy_config_files = Bool(
        False,
        config=True,
        help="""Whether to install the default config files into the profile dir.
        If a new profile is being created, and IPython contains config files for that
        profile, then they will be staged into the new directory.  Otherwise,
        default config files will be automatically generated.
        """)

    verbose_crash = Bool(
        False,
        config=True,
        help=
        """Create a massive crash report when IPython encounters what may be an
        internal error.  The default is to append a short message to the
        usual traceback""")

    # The class to use as the crash handler.
    crash_handler_class = Type(crashhandler.CrashHandler)

    @catch_config_error
    def __init__(self, **kwargs):
        super(BaseIPythonApplication, self).__init__(**kwargs)
        # ensure current working directory exists
        try:
            directory = py3compat.getcwd()
        except:
            # log a clear error, then re-raise
            self.log.error("Current working directory doesn't exist.")
            raise

    #-------------------------------------------------------------------------
    # Various stages of Application creation
    #-------------------------------------------------------------------------

    def init_crash_handler(self):
        """Create a crash handler, typically setting sys.excepthook to it."""
        self.crash_handler = self.crash_handler_class(self)
        sys.excepthook = self.excepthook

        def unset_crashhandler():
            sys.excepthook = sys.__excepthook__

        atexit.register(unset_crashhandler)

    def excepthook(self, etype, evalue, tb):
        """this is sys.excepthook after init_crashhandler
        
        set self.verbose_crash=True to use our full crashhandler, instead of
        a regular traceback with a short message (crash_handler_lite)
        """

        if self.verbose_crash:
            return self.crash_handler(etype, evalue, tb)
        else:
            return crashhandler.crash_handler_lite(etype, evalue, tb)

    def _ipython_dir_changed(self, name, old, new):
        if old in sys.path:
            sys.path.remove(old)
        sys.path.append(os.path.abspath(new))
        if not os.path.isdir(new):
            os.makedirs(new, mode=0o777)
        readme = os.path.join(new, 'README')
        readme_src = os.path.join(get_ipython_package_dir(), u'config',
                                  u'profile', 'README')
        if not os.path.exists(readme) and os.path.exists(readme_src):
            shutil.copy(readme_src, readme)
        for d in ('extensions', 'nbextensions'):
            path = os.path.join(new, d)
            if not os.path.exists(path):
                try:
                    os.mkdir(path)
                except OSError as e:
                    if e.errno != errno.EEXIST:
                        self.log.error("couldn't create path %s: %s", path, e)
        self.log.debug("IPYTHONDIR set to: %s" % new)

    def load_config_file(self, suppress_errors=True):
        """Load the config file.

        By default, errors in loading config are handled, and a warning
        printed on screen. For testing, the suppress_errors option is set
        to False, so errors will make tests fail.
        """
        self.log.debug("Searching path %s for config files",
                       self.config_file_paths)
        base_config = 'ipython_config.py'
        self.log.debug("Attempting to load config file: %s" % base_config)
        try:
            Application.load_config_file(self,
                                         base_config,
                                         path=self.config_file_paths)
        except ConfigFileNotFound:
            # ignore errors loading parent
            self.log.debug("Config file %s not found", base_config)

        for config_file_name in self.config_files:
            if not config_file_name or config_file_name == base_config:
                continue
            self.log.debug("Attempting to load config file: %s" %
                           self.config_file_name)
            try:
                Application.load_config_file(self,
                                             config_file_name,
                                             path=self.config_file_paths)
            except ConfigFileNotFound:
                # Only warn if the default config file was NOT being used.
                if config_file_name in self.config_file_specified:
                    msg = self.log.warn
                else:
                    msg = self.log.debug
                msg("Config file not found, skipping: %s", config_file_name)
            except:
                # For testing purposes.
                if not suppress_errors:
                    raise
                self.log.warn("Error loading config file: %s" %
                              self.config_file_name,
                              exc_info=True)

    def init_profile_dir(self):
        """initialize the profile dir"""
        self._in_init_profile_dir = True
        if self.profile_dir is not None:
            # already ran
            return
        if 'ProfileDir.location' not in self.config:
            # location not specified, find by profile name
            try:
                p = ProfileDir.find_profile_dir_by_name(
                    self.ipython_dir, self.profile, self.config)
            except ProfileDirError:
                # not found, maybe create it (always create default profile)
                if self.auto_create or self.profile == 'default':
                    try:
                        p = ProfileDir.create_profile_dir_by_name(
                            self.ipython_dir, self.profile, self.config)
                    except ProfileDirError:
                        self.log.fatal("Could not create profile: %r" %
                                       self.profile)
                        self.exit(1)
                    else:
                        self.log.info("Created profile dir: %r" % p.location)
                else:
                    self.log.fatal("Profile %r not found." % self.profile)
                    self.exit(1)
            else:
                self.log.info("Using existing profile dir: %r" % p.location)
        else:
            location = self.config.ProfileDir.location
            # location is fully specified
            try:
                p = ProfileDir.find_profile_dir(location, self.config)
            except ProfileDirError:
                # not found, maybe create it
                if self.auto_create:
                    try:
                        p = ProfileDir.create_profile_dir(
                            location, self.config)
                    except ProfileDirError:
                        self.log.fatal(
                            "Could not create profile directory: %r" %
                            location)
                        self.exit(1)
                    else:
                        self.log.info("Creating new profile dir: %r" %
                                      location)
                else:
                    self.log.fatal("Profile directory %r not found." %
                                   location)
                    self.exit(1)
            else:
                self.log.info("Using existing profile dir: %r" % location)

        self.profile_dir = p
        self.config_file_paths.append(p.location)
        self._in_init_profile_dir = False

    def init_config_files(self):
        """[optionally] copy default config files into profile dir."""
        # copy config files
        path = self.builtin_profile_dir
        if self.copy_config_files:
            src = self.profile

            cfg = self.config_file_name
            if path and os.path.exists(os.path.join(path, cfg)):
                self.log.warn(
                    "Staging %r from %s into %r [overwrite=%s]" %
                    (cfg, src, self.profile_dir.location, self.overwrite))
                self.profile_dir.copy_config_file(cfg,
                                                  path=path,
                                                  overwrite=self.overwrite)
            else:
                self.stage_default_config_file()
        else:
            # Still stage *bundled* config files, but not generated ones
            # This is necessary for `ipython profile=sympy` to load the profile
            # on the first go
            files = glob.glob(os.path.join(path, '*.py'))
            for fullpath in files:
                cfg = os.path.basename(fullpath)
                if self.profile_dir.copy_config_file(cfg,
                                                     path=path,
                                                     overwrite=False):
                    # file was copied
                    self.log.warn(
                        "Staging bundled %s from %s into %r" %
                        (cfg, self.profile, self.profile_dir.location))

    def stage_default_config_file(self):
        """auto generate default config file, and stage it into the profile."""
        s = self.generate_config_file()
        fname = os.path.join(self.profile_dir.location, self.config_file_name)
        if self.overwrite or not os.path.exists(fname):
            self.log.warn("Generating default config file: %r" % (fname))
            with open(fname, 'w') as f:
                f.write(s)

    @catch_config_error
    def initialize(self, argv=None):
        # don't hook up crash handler before parsing command-line
        self.parse_command_line(argv)
        self.init_crash_handler()
        if self.subapp is not None:
            # stop here if subapp is taking over
            return
        cl_config = self.config
        self.init_profile_dir()
        self.init_config_files()
        self.load_config_file()
        # enforce cl-opts override configfile opts:
        self.update_config(cl_config)
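
A hedged launch sketch for the class above: `MyApp` and its trait values are illustrative, but initialize() runs exactly the sequence defined here (parse_command_line, init_crash_handler, init_profile_dir, init_config_files, load_config_file, then command-line options re-applied on top).

from IPython.core.application import BaseIPythonApplication
from IPython.utils.traitlets import Unicode

class MyApp(BaseIPythonApplication):
    name = Unicode(u'myapp')  # so the config file becomes myapp_config.py

if __name__ == '__main__':
    app = MyApp()
    app.initialize(['--profile=default'])
    print(app.profile_dir.location)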
Example #21
class Kernel(Configurable):

    #---------------------------------------------------------------------------
    # Kernel interface
    #---------------------------------------------------------------------------

    # attribute to override with a GUI
    eventloop = Any(None)
    def _eventloop_changed(self, name, old, new):
        """schedule call to eventloop from IOLoop"""
        loop = ioloop.IOLoop.instance()
        loop.add_callback(self.enter_eventloop)

    session = Instance(Session)
    profile_dir = Instance('IPython.core.profiledir.ProfileDir')
    shell_streams = List()
    control_stream = Instance(ZMQStream)
    iopub_socket = Instance(zmq.Socket)
    stdin_socket = Instance(zmq.Socket)
    log = Instance(logging.Logger)

    # identities:
    int_id = Integer(-1)
    ident = Unicode()

    def _ident_default(self):
        return unicode_type(uuid.uuid4())

    # Private interface
    
    _darwin_app_nap = Bool(True, config=True,
        help="""Whether to use appnope for compatiblity with OS X App Nap.
        
        Only affects OS X >= 10.9.
        """
    )

    # track associations with current request
    _allow_stdin = Bool(False)
    _parent_header = Dict()
    _parent_ident = Any(b'')
    # Time to sleep after flushing the stdout/err buffers in each execute
    # cycle.  While this introduces a hard limit on the minimal latency of the
    # execute cycle, it helps prevent output synchronization problems for
    # clients.
    # Units are in seconds.  The minimum zmq latency on local host is probably
    # ~150 microseconds, set this to 500us for now.  We may need to increase it
    # a little if it's not enough after more interactive testing.
    _execute_sleep = Float(0.0005, config=True)

    # Frequency of the kernel's event loop.
    # Units are in seconds, kernel subclasses for GUI toolkits may need to
    # adapt to milliseconds.
    _poll_interval = Float(0.05, config=True)

    # If the shutdown was requested over the network, we leave here the
    # necessary reply message so it can be sent by our registered atexit
    # handler.  This ensures that the reply is only sent to clients truly at
    # the end of our shutdown process (which happens after the underlying
    # IPython shell's own shutdown).
    _shutdown_message = None

    # This is a dict of port number that the kernel is listening on. It is set
    # by record_ports and used by connect_request.
    _recorded_ports = Dict()

    # set of aborted msg_ids
    aborted = Set()

    # Track execution count here. For IPython, we override this to use the
    # execution count we store in the shell.
    execution_count = 0


    def __init__(self, **kwargs):
        super(Kernel, self).__init__(**kwargs)

        # Build dict of handlers for message types
        msg_types = [ 'execute_request', 'complete_request',
                      'inspect_request', 'history_request',
                      'kernel_info_request',
                      'connect_request', 'shutdown_request',
                      'apply_request',
                    ]
        self.shell_handlers = {}
        for msg_type in msg_types:
            self.shell_handlers[msg_type] = getattr(self, msg_type)
        
        control_msg_types = msg_types + [ 'clear_request', 'abort_request' ]
        self.control_handlers = {}
        for msg_type in control_msg_types:
            self.control_handlers[msg_type] = getattr(self, msg_type)


    def dispatch_control(self, msg):
        """dispatch control requests"""
        idents,msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.unserialize(msg, content=True, copy=False)
        except:
            self.log.error("Invalid Control Message", exc_info=True)
            return

        self.log.debug("Control received: %s", msg)

        header = msg['header']
        msg_type = header['msg_type']

        handler = self.control_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r", msg_type)
        else:
            try:
                handler(self.control_stream, idents, msg)
            except Exception:
                self.log.error("Exception in control handler:", exc_info=True)
    
    def dispatch_shell(self, stream, msg):
        """dispatch shell requests"""
        # flush control requests first
        if self.control_stream:
            self.control_stream.flush()
        
        idents,msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.unserialize(msg, content=True, copy=False)
        except:
            self.log.error("Invalid Message", exc_info=True)
            return

        header = msg['header']
        msg_id = header['msg_id']
        msg_type = msg['header']['msg_type']
        
        # Print some info about this message and leave a '--->' marker, so it's
        # easier to trace visually the message chain when debugging.  Each
        # handler prints its message at the end.
        self.log.debug('\n*** MESSAGE TYPE:%s***', msg_type)
        self.log.debug('   Content: %s\n   --->\n   ', msg['content'])

        if msg_id in self.aborted:
            self.aborted.remove(msg_id)
            # is it safe to assume a msg_id will not be resubmitted?
            reply_type = msg_type.split('_')[0] + '_reply'
            status = {'status' : 'aborted'}
            md = {'engine' : self.ident}
            md.update(status)
            self.session.send(stream, reply_type, metadata=md,
                        content=status, parent=msg, ident=idents)
            return
        
        handler = self.shell_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("UNKNOWN MESSAGE TYPE: %r", msg_type)
        else:
            # ensure default_int_handler during handler call
            sig = signal(SIGINT, default_int_handler)
            self.log.debug("%s: %s", msg_type, msg)
            try:
                handler(stream, idents, msg)
            except Exception:
                self.log.error("Exception in message handler:", exc_info=True)
            finally:
                signal(SIGINT, sig)
    
    def enter_eventloop(self):
        """enter eventloop"""
        self.log.info("entering eventloop %s", self.eventloop)
        for stream in self.shell_streams:
            # flush any pending replies,
            # which may be skipped by entering the eventloop
            stream.flush(zmq.POLLOUT)
        # restore default_int_handler
        signal(SIGINT, default_int_handler)
        while self.eventloop is not None:
            try:
                self.eventloop(self)
            except KeyboardInterrupt:
                # Ctrl-C shouldn't crash the kernel
                self.log.error("KeyboardInterrupt caught in kernel")
                continue
            else:
                # eventloop exited cleanly, this means we should stop (right?)
                self.eventloop = None
                break
        self.log.info("exiting eventloop")

    def start(self):
        """register dispatchers for streams"""
        if self.control_stream:
            self.control_stream.on_recv(self.dispatch_control, copy=False)

        def make_dispatcher(stream):
            def dispatcher(msg):
                return self.dispatch_shell(stream, msg)
            return dispatcher

        for s in self.shell_streams:
            s.on_recv(make_dispatcher(s), copy=False)

        # publish starting status
        self._publish_status('starting')
    
    def do_one_iteration(self):
        """step eventloop just once"""
        if self.control_stream:
            self.control_stream.flush()
        for stream in self.shell_streams:
            # handle at most one request per iteration
            stream.flush(zmq.POLLIN, 1)
            stream.flush(zmq.POLLOUT)


    def record_ports(self, ports):
        """Record the ports that this kernel is using.

        The creator of the Kernel instance must call this method if they
        want the :meth:`connect_request` method to return the port numbers.
        """
        self._recorded_ports = ports
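        # Hedged example (port numbers hypothetical); the keys mirror what
        # connect_request sends back in connect_reply:
        #     kernel.record_ports({'shell_port': 53001, 'iopub_port': 53002,
        #                          'stdin_port': 53003, 'hb_port': 53004})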

    #---------------------------------------------------------------------------
    # Kernel request handlers
    #---------------------------------------------------------------------------
    
    def _make_metadata(self, other=None):
        """init metadata dict, for execute/apply_reply"""
        new_md = {
            'dependencies_met' : True,
            'engine' : self.ident,
            'started': datetime.now(),
        }
        if other:
            new_md.update(other)
        return new_md
    
    def _publish_execute_input(self, code, parent, execution_count):
        """Publish the code request on the iopub stream."""

        self.session.send(self.iopub_socket, u'execute_input',
                            {u'code':code, u'execution_count': execution_count},
                            parent=parent, ident=self._topic('execute_input')
        )
    
    def _publish_status(self, status, parent=None):
        """send status (busy/idle) on IOPub"""
        self.session.send(self.iopub_socket,
                          u'status',
                          {u'execution_state': status},
                          parent=parent,
                          ident=self._topic('status'),
                          )
    
    def set_parent(self, ident, parent):
        """Set the current parent_header
        
        Side effects (IOPub messages) and replies are associated with
        the request that caused them via the parent_header.
        
        The parent identity is used to route input_request messages
        on the stdin channel.
        """
        self._parent_ident = ident
        self._parent_header = parent
    
    def send_response(self, stream, msg_or_type, content=None, ident=None,
             buffers=None, track=False, header=None, metadata=None):
        """Send a response to the message we're currently processing.
        
        This accepts all the parameters of :meth:`IPython.kernel.zmq.session.Session.send`
        except ``parent``.
        
        This relies on :meth:`set_parent` having been called for the current
        message.
        """
        return self.session.send(stream, msg_or_type, content, self._parent_header,
                                 ident, buffers, track, header, metadata)
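    # Hedged usage example: from inside a handler, once set_parent() has run
    # for the current request, a subclass can publish output tied to that
    # request with:
    #
    #     self.send_response(self.iopub_socket, u'stream',
    #                        {u'name': u'stdout', u'text': u'hello'})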
    
    def execute_request(self, stream, ident, parent):
        """handle an execute_request"""
        
        self._publish_status(u'busy', parent)
        
        try:
            content = parent[u'content']
            code = py3compat.cast_unicode_py2(content[u'code'])
            silent = content[u'silent']
            store_history = content.get(u'store_history', not silent)
            user_expressions = content.get('user_expressions', {})
            allow_stdin = content.get('allow_stdin', False)
        except:
            self.log.error("Got bad msg: ")
            self.log.error("%s", parent)
            return
        
        md = self._make_metadata(parent['metadata'])
        
        # Set the parent message of the display hook and out streams.
        self.set_parent(ident, parent)
        
        # Re-broadcast our input for the benefit of listening clients, and
        # start computing output
        if not silent:
            self.execution_count += 1
            self._publish_execute_input(code, parent, self.execution_count)
        
        reply_content = self.do_execute(code, silent, store_history,
                                        user_expressions, allow_stdin)

        # Flush output before sending the reply.
        sys.stdout.flush()
        sys.stderr.flush()
        # FIXME: on rare occasions, the flush doesn't seem to make it to the
        # clients... This seems to mitigate the problem, but we definitely need
        # to better understand what's going on.
        if self._execute_sleep:
            time.sleep(self._execute_sleep)

        # Send the reply.
        reply_content = json_clean(reply_content)
        
        md['status'] = reply_content['status']
        if reply_content['status'] == 'error' and \
                reply_content['ename'] == 'UnmetDependency':
            md['dependencies_met'] = False

        reply_msg = self.session.send(stream, u'execute_reply',
                                      reply_content, parent, metadata=md,
                                      ident=ident)
        
        self.log.debug("%s", reply_msg)

        if not silent and reply_msg['content']['status'] == u'error':
            self._abort_queues()

        self._publish_status(u'idle', parent)

    def do_execute(self, code, silent, store_history=True,
                   user_expressions=None, allow_stdin=False):
        """Execute user code. Must be overridden by subclasses.
        """
        raise NotImplementedError

    def complete_request(self, stream, ident, parent):
        content = parent['content']
        code = content['code']
        cursor_pos = content['cursor_pos']
        
        matches = self.do_complete(code, cursor_pos)
        matches = json_clean(matches)
        completion_msg = self.session.send(stream, 'complete_reply',
                                           matches, parent, ident)
        self.log.debug("%s", completion_msg)

    def do_complete(self, code, cursor_pos):
        """Override in subclasses to find completions.
        """
        return {'matches' : [],
                'cursor_end' : cursor_pos,
                'cursor_start' : cursor_pos,
                'metadata' : {},
                'status' : 'ok'}
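    # Hedged sketch of an override that completes against a fixed word list
    # (the word list is illustrative):
    #
    #     def do_complete(self, code, cursor_pos):
    #         text = code[:cursor_pos]
    #         token = text.split()[-1] if text.split() else ''
    #         words = ['print', 'property', 'pow']
    #         return {'matches': [w for w in words if w.startswith(token)],
    #                 'cursor_start': cursor_pos - len(token),
    #                 'cursor_end': cursor_pos,
    #                 'metadata': {}, 'status': 'ok'}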

    def inspect_request(self, stream, ident, parent):
        content = parent['content']
        
        reply_content = self.do_inspect(content['code'], content['cursor_pos'],
                                        content.get('detail_level', 0))
        # Before we send this object over, we scrub it for JSON usage
        reply_content = json_clean(reply_content)
        msg = self.session.send(stream, 'inspect_reply',
                                reply_content, parent, ident)
        self.log.debug("%s", msg)

    def do_inspect(self, code, cursor_pos, detail_level=0):
        """Override in subclasses to allow introspection.
        """
        return {'status': 'ok', 'data':{}, 'metadata':{}, 'found':False}

    def history_request(self, stream, ident, parent):
        content = parent['content']

        reply_content = self.do_history(**content)

        reply_content = json_clean(reply_content)
        msg = self.session.send(stream, 'history_reply',
                                reply_content, parent, ident)
        self.log.debug("%s", msg)

    def do_history(self, hist_access_type, output, raw, session=None, start=None,
                   stop=None, n=None, pattern=None, unique=False):
        """Override in subclasses to access history.
        """
        return {'history': []}

    def connect_request(self, stream, ident, parent):
        if self._recorded_ports is not None:
            content = self._recorded_ports.copy()
        else:
            content = {}
        msg = self.session.send(stream, 'connect_reply',
                                content, parent, ident)
        self.log.debug("%s", msg)

    @property
    def kernel_info(self):
        return {
            'protocol_version': release.kernel_protocol_version,
            'implementation': self.implementation,
            'implementation_version': self.implementation_version,
            'language': self.language,
            'language_version': self.language_version,
            'banner': self.banner,
        }

    def kernel_info_request(self, stream, ident, parent):
        msg = self.session.send(stream, 'kernel_info_reply',
                                self.kernel_info, parent, ident)
        self.log.debug("%s", msg)

    def shutdown_request(self, stream, ident, parent):
        content = self.do_shutdown(parent['content']['restart'])
        self.session.send(stream, u'shutdown_reply', content, parent, ident=ident)
        # same content, but different msg_id for broadcasting on IOPub
        self._shutdown_message = self.session.msg(u'shutdown_reply',
                                                  content, parent
        )

        self._at_shutdown()
        # call sys.exit after a short delay
        loop = ioloop.IOLoop.instance()
        loop.add_timeout(time.time()+0.1, loop.stop)

    def do_shutdown(self, restart):
        """Override in subclasses to do things when the frontend shuts down the
        kernel.
        """
        return {'status': 'ok', 'restart': restart}

    #---------------------------------------------------------------------------
    # Engine methods
    #---------------------------------------------------------------------------

    def apply_request(self, stream, ident, parent):
        try:
            content = parent[u'content']
            bufs = parent[u'buffers']
            msg_id = parent['header']['msg_id']
        except:
            self.log.error("Got bad msg: %s", parent, exc_info=True)
            return

        self._publish_status(u'busy', parent)

        # Set the parent message of the display hook and out streams.
        self.set_parent(ident, parent)

        md = self._make_metadata(parent['metadata'])

        reply_content, result_buf = self.do_apply(content, bufs, msg_id, md)

        # put 'ok'/'error' status in header, for scheduler introspection:
        md['status'] = reply_content['status']

        # flush i/o
        sys.stdout.flush()
        sys.stderr.flush()
        
        self.session.send(stream, u'apply_reply', reply_content,
                          parent=parent, ident=ident, buffers=result_buf,
                          metadata=md)

        self._publish_status(u'idle', parent)

    def do_apply(self, content, bufs, msg_id, reply_metadata):
        """Override in subclasses to support the IPython parallel framework.
        """
        raise NotImplementedError

    #---------------------------------------------------------------------------
    # Control messages
    #---------------------------------------------------------------------------

    def abort_request(self, stream, ident, parent):
        """abort a specifig msg by id"""
        msg_ids = parent['content'].get('msg_ids', None)
        if isinstance(msg_ids, string_types):
            msg_ids = [msg_ids]
        if not msg_ids:
            self.abort_queues()
        for mid in msg_ids:
            self.aborted.add(str(mid))

        content = dict(status='ok')
        reply_msg = self.session.send(stream, 'abort_reply', content=content,
                parent=parent, ident=ident)
        self.log.debug("%s", reply_msg)

    def clear_request(self, stream, idents, parent):
        """Clear our namespace."""
        content = self.do_clear()
        self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
                content = content)

    def do_clear(self):
        """Override in subclasses to clear the namespace
        
        This is only required for IPython.parallel.
        """
        raise NotImplementedError

    #---------------------------------------------------------------------------
    # Protected interface
    #---------------------------------------------------------------------------

    def _topic(self, topic):
        """prefixed topic for IOPub messages"""
        if self.int_id >= 0:
            base = "engine.%i" % self.int_id
        else:
            base = "kernel.%s" % self.ident
        
        return py3compat.cast_bytes("%s.%s" % (base, topic))
    
    def _abort_queues(self):
        for stream in self.shell_streams:
            if stream:
                self._abort_queue(stream)

    def _abort_queue(self, stream):
        poller = zmq.Poller()
        poller.register(stream.socket, zmq.POLLIN)
        while True:
            idents,msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
            if msg is None:
                return

            self.log.info("Aborting:")
            self.log.info("%s", msg)
            msg_type = msg['header']['msg_type']
            reply_type = msg_type.split('_')[0] + '_reply'

            status = {'status' : 'aborted'}
            md = {'engine' : self.ident}
            md.update(status)
            reply_msg = self.session.send(stream, reply_type, metadata=md,
                        content=status, parent=msg, ident=idents)
            self.log.debug("%s", reply_msg)
            # We need to wait a bit for requests to come in. This can probably
            # be set shorter for true asynchronous clients.
            poller.poll(50)


    def _no_raw_input(self):
        """Raise StdinNotImplentedError if active frontend doesn't support
        stdin."""
        raise StdinNotImplementedError("raw_input was called, but this "
                                       "frontend does not support stdin.") 
    
    def getpass(self, prompt=''):
        """Forward getpass to frontends
        
        Raises
        ------
        StdinNotImplementedError if active frontend doesn't support stdin.
        """
        if not self._allow_stdin:
            raise StdinNotImplementedError(
                "getpass was called, but this frontend does not support input requests."
            )
        return self._input_request(prompt,
            self._parent_ident,
            self._parent_header,
            password=True,
        )
    
    def raw_input(self, prompt=''):
        """Forward raw_input to frontends
        
        Raises
        ------
        StdinNotImplementedError if active frontend doesn't support stdin.
        """
        if not self._allow_stdin:
            raise StdinNotImplementedError(
                "raw_input was called, but this frontend does not support input requests."
            )
        return self._input_request(prompt,
            self._parent_ident,
            self._parent_header,
            password=False,
        )
    
    def _input_request(self, prompt, ident, parent, password=False):
        # Flush output before making the request.
        sys.stderr.flush()
        sys.stdout.flush()
        # flush the stdin socket, to purge stale replies
        while True:
            try:
                self.stdin_socket.recv_multipart(zmq.NOBLOCK)
            except zmq.ZMQError as e:
                if e.errno == zmq.EAGAIN:
                    break
                else:
                    raise
        
        # Send the input request.
        content = json_clean(dict(prompt=prompt, password=password))
        self.session.send(self.stdin_socket, u'input_request', content, parent,
                          ident=ident)

        # Await a response.
        while True:
            try:
                ident, reply = self.session.recv(self.stdin_socket, 0)
            except Exception:
                self.log.warn("Invalid Message:", exc_info=True)
            except KeyboardInterrupt:
                # re-raise KeyboardInterrupt, to truncate traceback
                raise KeyboardInterrupt
            else:
                break
        try:
            value = py3compat.unicode_to_str(reply['content']['value'])
        except:
            self.log.error("Bad input_reply: %s", parent)
            value = ''
        if value == '\x04':
            # EOF
            raise EOFError
        return value

    def _at_shutdown(self):
        """Actions taken at shutdown by the kernel, called by python's atexit.
        """
        # io.rprint("Kernel at_shutdown") # dbg
        if self._shutdown_message is not None:
            self.session.send(self.iopub_socket, self._shutdown_message, ident=self._topic('shutdown'))
            self.log.debug("%s", self._shutdown_message)
        for s in self.shell_streams:
            s.flush(zmq.POLLOUT)
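
This Kernel base class is designed for subclassing: a wrapper kernel fills in the identity attributes read by the kernel_info property and overrides do_execute. A minimal hedged sketch (the 'echo' identity values are illustrative):

class EchoKernel(Kernel):
    implementation = 'echo'
    implementation_version = '1.0'
    language = 'text'
    language_version = '0.1'
    banner = 'Echo kernel - repeats its input'

    def do_execute(self, code, silent, store_history=True,
                   user_expressions=None, allow_stdin=False):
        # echo the input back on IOPub, tied to the current request
        if not silent:
            self.send_response(self.iopub_socket, u'stream',
                               {u'name': u'stdout', u'text': code})
        return {'status': 'ok',
                'execution_count': self.execution_count,
                'payload': [],
                'user_expressions': {}}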
Example #22
class HeartMonitor(LoggingConfigurable):
    """A basic HeartMonitor class
    pingstream: a PUB stream
    pongstream: a ROUTER stream
    period: the period of the heartbeat in milliseconds"""

    period = Integer(
        3000,
        config=True,
        help='The frequency at which the Hub pings the engines for heartbeats '
        '(in ms)',
    )
    max_heartmonitor_misses = Integer(
        10,
        config=True,
        help=
        'Allowed consecutive missed pings from controller Hub to engine before unregistering.',
    )

    pingstream = Instance('zmq.eventloop.zmqstream.ZMQStream')
    pongstream = Instance('zmq.eventloop.zmqstream.ZMQStream')
    loop = Instance('zmq.eventloop.ioloop.IOLoop')

    def _loop_default(self):
        return ioloop.IOLoop.instance()

    # not settable:
    hearts = Set()
    responses = Set()
    on_probation = Dict()
    last_ping = CFloat(0)
    _new_handlers = Set()
    _failure_handlers = Set()
    lifetime = CFloat(0)
    tic = CFloat(0)

    def __init__(self, **kwargs):
        super(HeartMonitor, self).__init__(**kwargs)

        self.pongstream.on_recv(self.handle_pong)

    def start(self):
        self.tic = time.time()
        self.caller = ioloop.PeriodicCallback(self.beat, self.period,
                                              self.loop)
        self.caller.start()

    def add_new_heart_handler(self, handler):
        """add a new handler for new hearts"""
        self.log.debug("heartbeat::new_heart_handler: %s", handler)
        self._new_handlers.add(handler)

    def add_heart_failure_handler(self, handler):
        """add a new handler for heart failure"""
        self.log.debug("heartbeat::new heart failure handler: %s", handler)
        self._failure_handlers.add(handler)

    def beat(self):
        self.pongstream.flush()
        self.last_ping = self.lifetime

        toc = time.time()
        self.lifetime += toc - self.tic
        self.tic = toc
        self.log.debug("heartbeat::sending %s", self.lifetime)
        goodhearts = self.hearts.intersection(self.responses)
        missed_beats = self.hearts.difference(goodhearts)
        newhearts = self.responses.difference(goodhearts)
        for heart in newhearts:
            self.handle_new_heart(heart)
        heartfailures, on_probation = self._check_missed(
            missed_beats, self.on_probation, self.hearts)
        for failure in heartfailures:
            self.handle_heart_failure(failure)
        self.on_probation = on_probation
        self.responses = set()
        #print self.on_probation, self.hearts
        # self.log.debug("heartbeat::beat %.3f, %i beating hearts", self.lifetime, len(self.hearts))
        self.pingstream.send(str_to_bytes(str(self.lifetime)))
        # flush stream to force immediate socket send
        self.pingstream.flush()

    def _check_missed(self, missed_beats, on_probation, hearts):
        """Update heartbeats on probation, identifying any that have too many misses.
        """
        failures = []
        new_probation = {}
        for cur_heart in (b for b in missed_beats if b in hearts):
            miss_count = on_probation.get(cur_heart, 0) + 1
            self.log.info("heartbeat::missed %s : %s" %
                          (cur_heart, miss_count))
            if miss_count > self.max_heartmonitor_misses:
                failures.append(cur_heart)
            else:
                new_probation[cur_heart] = miss_count
        return failures, new_probation

    def handle_new_heart(self, heart):
        if self._new_handlers:
            for handler in self._new_handlers:
                handler(heart)
        else:
            self.log.info("heartbeat::yay, got new heart %s!", heart)
        self.hearts.add(heart)

    def handle_heart_failure(self, heart):
        if self._failure_handlers:
            for handler in self._failure_handlers:
                try:
                    handler(heart)
                except Exception:
                    self.log.error("heartbeat::Bad Handler! %s",
                                   handler,
                                   exc_info=True)
        else:
            self.log.info("heartbeat::Heart %s failed :(", heart)
        self.hearts.remove(heart)

    @log_errors
    def handle_pong(self, msg):
        "a heart just beat"
        current = str_to_bytes(str(self.lifetime))
        last = str_to_bytes(str(self.last_ping))
        if msg[1] == current:
            delta = time.time() - self.tic
            # self.log.debug("heartbeat::heart %r took %.2f ms to respond"%(msg[0], 1000*delta))
            self.responses.add(msg[0])
        elif msg[1] == last:
            delta = time.time() - self.tic + (self.lifetime - self.last_ping)
            self.log.warn(
                "heartbeat::heart %r missed a beat, and took %.2f ms to respond",
                msg[0], 1000 * delta)
            self.responses.add(msg[0])
        else:
            self.log.warn(
                "heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)",
                msg[1], self.lifetime)
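
A hedged wiring sketch for the monitor above: in IPython.parallel the Hub builds the ping/pong streams from its connection info, so the sockets and addresses below are purely illustrative.

import zmq
from zmq.eventloop import ioloop, zmqstream

loop = ioloop.IOLoop.instance()
ctx = zmq.Context.instance()

ping = ctx.socket(zmq.PUB)      # engines SUB-scribe to this
ping.bind('tcp://127.0.0.1:5555')
pong = ctx.socket(zmq.ROUTER)   # engines answer here
pong.bind('tcp://127.0.0.1:5556')

def on_new_heart(heart):
    print('new heart: %r' % heart)

def on_heart_failure(heart):
    print('lost heart: %r' % heart)

monitor = HeartMonitor(pingstream=zmqstream.ZMQStream(ping, loop),
                       pongstream=zmqstream.ZMQStream(pong, loop),
                       period=3000, loop=loop)
monitor.add_new_heart_handler(on_new_heart)
monitor.add_heart_failure_handler(on_heart_failure)
monitor.start()
loop.start()  # beat() fires every `period` ms from here on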
Example #23
class KernelSpecManager(Configurable):
    ipython_dir = Unicode()

    def _ipython_dir_default(self):
        return get_ipython_dir()

    user_kernel_dir = Unicode()

    def _user_kernel_dir_default(self):
        return pjoin(self.ipython_dir, 'kernels')

    @property
    def env_kernel_dir(self):
        return pjoin(sys.prefix, 'share', 'jupyter', 'kernels')

    whitelist = Set(config=True,
                    help="""Whitelist of allowed kernel names.

        By default, all installed kernels are allowed.
        """)
    kernel_dirs = List(
        help="List of kernel directories to search. Later entries take "
             "priority over earlier ones."
    )

    def _kernel_dirs_default(self):
        dirs = SYSTEM_KERNEL_DIRS[:]
        if self.env_kernel_dir not in dirs:
            dirs.append(self.env_kernel_dir)
        dirs.append(self.user_kernel_dir)
        return dirs

    def find_kernel_specs(self):
        """Returns a dict mapping kernel names to resource directories."""
        d = {}
        for kernel_dir in self.kernel_dirs:
            d.update(_list_kernels_in(kernel_dir))

        if self.whitelist:
            # filter if there's a whitelist
            d = {
                name: spec
                for name, spec in d.items() if name in self.whitelist
            }
        return d
        # TODO: Caching?

    def get_kernel_spec(self, kernel_name):
        """Returns a :class:`KernelSpec` instance for the given kernel_name.

        Raises :exc:`NoSuchKernel` if the given kernel name is not found.
        """
        d = self.find_kernel_specs()
        try:
            resource_dir = d[kernel_name.lower()]
        except KeyError:
            raise NoSuchKernel(kernel_name)
        return KernelSpec.from_resource_dir(resource_dir)

    def _get_destination_dir(self, kernel_name, user=False):
        if user:
            return os.path.join(self.user_kernel_dir, kernel_name)
        else:
            if SYSTEM_KERNEL_DIRS:
                return os.path.join(SYSTEM_KERNEL_DIRS[-1], kernel_name)
            else:
                raise EnvironmentError(
                    "No system kernel directory is available")

    def install_kernel_spec(self,
                            source_dir,
                            kernel_name=None,
                            user=False,
                            replace=False):
        """Install a kernel spec by copying its directory.

        If ``kernel_name`` is not given, the basename of ``source_dir`` will
        be used.

        If ``user`` is False, it will attempt to install into the systemwide
        kernel registry. If the process does not have appropriate permissions,
        an :exc:`OSError` will be raised.

        If ``replace`` is True, this will replace an existing kernel of the same
        name. Otherwise, if the destination already exists, an :exc:`OSError`
        will be raised.
        """
        if not kernel_name:
            kernel_name = os.path.basename(source_dir)
        kernel_name = kernel_name.lower()

        destination = self._get_destination_dir(kernel_name, user=user)

        if replace and os.path.isdir(destination):
            shutil.rmtree(destination)

        shutil.copytree(source_dir, destination)

    def install_native_kernel_spec(self, user=False):
        """DEPRECATED: Use ipython_kernel.kenelspec.install"""
        warnings.warn("install_native_kernel_spec is deprecated."
                      " Use ipython_kernel.kernelspec import install.")
        from ipython_kernel.kernelspec import install
        install(self, user=user)
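
A hedged usage sketch of the manager above (kernel names and paths are illustrative):

ksm = KernelSpecManager()
print(ksm.find_kernel_specs())         # e.g. {'python2': '/usr/.../kernels/python2'}
spec = ksm.get_kernel_spec('python2')  # raises NoSuchKernel if absent
print(spec.display_name)
# Install a spec directory per-user (user=False for system-wide):
# ksm.install_kernel_spec('/path/to/mykernel', kernel_name='mykernel', user=True)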
Example #24
class InlineBackend(InlineBackendConfig):
    """An object to store configuration of the inline backend."""
    def _config_changed(self, name, old, new):
        # warn on change of renamed config section
        if new.InlineBackendConfig != old.InlineBackendConfig:
            warn("InlineBackendConfig has been renamed to InlineBackend")
        super(InlineBackend, self)._config_changed(name, old, new)

    # The typical default figure size is too large for inline use,
    # so we shrink the figure size to 6x4, and tweak fonts to
    # make that fit.
    rc = Dict(
        {
            'figure.figsize': (6.0, 4.0),
            # play nicely with white background in the Qt and notebook frontend
            'figure.facecolor': (1, 1, 1, 0),
            'figure.edgecolor': (1, 1, 1, 0),
            # 12pt labels get cutoff on 6x4 logplots, so use 10pt.
            'font.size': 10,
            # 72 dpi matches SVG/qtconsole
            # this only affects PNG export, as SVG has no dpi setting
            'savefig.dpi': 72,
            # 10pt still needs a little more room on the xlabel:
            'figure.subplot.bottom': .125
        },
        config=True,
        help="""Subset of matplotlib rcParams that should be different for the
        inline backend.""")

    figure_formats = Set({'png'},
                         config=True,
                         help="""A set of figure formats to enable: 'png', 
                          'retina', 'jpeg', 'svg', 'pdf'.""")

    def _figure_formats_changed(self, name, old, new):
        from IPython.core.pylabtools import select_figure_formats
        if 'jpg' in new or 'jpeg' in new:
            if not pil_available():
                raise TraitError("Requires PIL/Pillow for JPG figures")
        if self.shell is None:
            return
        else:
            select_figure_formats(self.shell, new)

    figure_format = Unicode(config=True,
                            help="""The figure format to enable (deprecated;
                            use `figure_formats` instead)""")

    def _figure_format_changed(self, name, old, new):
        if new:
            self.figure_formats = {new}

    quality = Int(
        default_value=90,
        config=True,
        help="Quality of compression [10-100], currently for lossy JPEG only.")

    def _quality_changed(self, name, old, new):
        if new < 10 or new > 100:
            raise TraitError("figure JPEG quality must be in [10-100] range.")

    close_figures = Bool(True,
                         config=True,
                         help="""Close all figures at the end of each cell.
        
        When True, ensures that each cell starts with no active figures, but it
        also means that one must keep track of references in order to edit or
        redraw figures in subsequent cells. This mode is ideal for the notebook,
        where residual plots from other cells might be surprising.
        
        When False, one must call figure() to create new figures. This means
        that gcf() and getfigs() can reference figures created in other cells,
        and the active figure can continue to be edited with pylab/pyplot
        methods that reference the current active figure. This mode facilitates
        iterative editing of figures, and behaves most consistently with
        other matplotlib backends, but figure barriers between cells must
        be explicit.
        """)

    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
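
A hedged sketch of driving these traits from an IPython config file (get_config() is available inside config files; the values are illustrative):

c = get_config()
c.InlineBackend.figure_formats = {'png', 'svg'}
c.InlineBackend.rc = {'figure.figsize': (8.0, 5.0)}
c.InlineBackend.close_figures = True
c.InlineBackend.quality = 85  # only affects lossy JPEG output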