Example 1
class Hub(LoggingFactory):
    """The IPython Controller Hub with 0MQ connections
    
    Parameters
    ==========
    loop: zmq IOLoop instance
    session: StreamSession object
    queue: ZMQStream for monitoring the command queue (SUB)
    query: ZMQStream for engine registration and client queries requests (XREP)
    heartbeat: HeartMonitor object checking the pulse of the engines
    notifier: ZMQStream for broadcasting engine registration changes (PUB)
    db: connection to db for out of memory logging of commands
                NotImplemented
    engine_info: dict of zmq connection information for engines to connect
                to the queues.
    client_info: dict of zmq connection information for clients to connect
                to the queues.
    """
    # internal data structures:
    ids = Set()  # engine IDs
    keytable = Dict()
    by_ident = Dict()
    engines = Dict()
    clients = Dict()
    hearts = Dict()
    pending = Set()
    queues = Dict()  # pending msg_ids keyed by engine_id
    tasks = Dict()  # pending msg_ids submitted as tasks, keyed by client_id
    completed = Dict()  # completed msg_ids keyed by engine_id
    all_completed = Set()  # set of all completed msg_ids
    dead_engines = Set()  # set of uuids of engines that have unregistered
    unassigned = Set()  # set of task msg_ids not yet assigned a destination
    incoming_registrations = Dict()
    registration_timeout = Int()
    _idcounter = Int(0)

    # objects from constructor:
    loop = Instance(ioloop.IOLoop)
    query = Instance(ZMQStream)
    monitor = Instance(ZMQStream)
    heartmonitor = Instance(HeartMonitor)
    notifier = Instance(ZMQStream)
    db = Instance(object)
    client_info = Dict()
    engine_info = Dict()

    def __init__(self, **kwargs):
        """
        # universal:
        loop: IOLoop for creating future connections
        session: streamsession for sending serialized data
        # engine:
        queue: ZMQStream for monitoring queue messages
        query: ZMQStream for engine+client registration and client requests
        heartbeat: HeartMonitor object for tracking engines
        # extra:
        db: ZMQStream for db connection (NotImplemented)
        engine_info: zmq address/protocol dict for engine connections
        client_info: zmq address/protocol dict for client connections
        """

        super(Hub, self).__init__(**kwargs)
        self.registration_timeout = max(5000, 2 * self.heartmonitor.period)

        # validate connection dicts:
        for k, v in self.client_info.iteritems():
            if k == 'task':
                validate_url_container(v[1])
            else:
                validate_url_container(v)
        # validate_url_container(self.client_info)
        validate_url_container(self.engine_info)

        # register our callbacks
        self.query.on_recv(self.dispatch_query)
        self.monitor.on_recv(self.dispatch_monitor_traffic)

        self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
        self.heartmonitor.add_new_heart_handler(self.handle_new_heart)

        self.monitor_handlers = {
            'in': self.save_queue_request,
            'out': self.save_queue_result,
            'intask': self.save_task_request,
            'outtask': self.save_task_result,
            'tracktask': self.save_task_destination,
            'incontrol': _passer,
            'outcontrol': _passer,
            'iopub': self.save_iopub_message,
        }

        self.query_handlers = {
            'queue_request': self.queue_status,
            'result_request': self.get_results,
            'purge_request': self.purge_results,
            'load_request': self.check_load,
            'resubmit_request': self.resubmit_task,
            'shutdown_request': self.shutdown_request,
            'registration_request': self.register_engine,
            'unregistration_request': self.unregister_engine,
            'connection_request': self.connection_request,
        }

        self.log.info("hub::created hub")

    @property
    def _next_id(self):
        """gemerate a new ID.
        
        No longer reuse old ids, just count from 0."""
        newid = self._idcounter
        self._idcounter += 1
        return newid

    #-----------------------------------------------------------------------------
    # message validation
    #-----------------------------------------------------------------------------

    def _validate_targets(self, targets):
        """turn any valid targets argument into a list of integer ids"""
        if targets is None:
            # default to all
            targets = self.ids

        if isinstance(targets, (int, str, unicode)):
            # only one target specified
            targets = [targets]
        _targets = []
        for t in targets:
            # map raw identities to ids
            if isinstance(t, (str, unicode)):
                t = self.by_ident.get(t, t)
            _targets.append(t)
        targets = _targets
        bad_targets = [t for t in targets if t not in self.ids]
        if bad_targets:
            raise IndexError("No Such Engine: %r" % bad_targets)
        if not targets:
            raise IndexError("No Engines Registered")
        return targets

    #-----------------------------------------------------------------------------
    # dispatch methods (1 per stream)
    #-----------------------------------------------------------------------------

    def dispatch_monitor_traffic(self, msg):
        """all ME and Task queue messages come through here, as well as
        IOPub traffic."""
        self.log.debug("monitor traffic: %s" % msg[:2])
        switch = msg[0]
        idents, msg = self.session.feed_identities(msg[1:])
        if not idents:
            self.log.error("Bad Monitor Message: %s" % msg)
            return
        handler = self.monitor_handlers.get(switch, None)
        if handler is not None:
            handler(idents, msg)
        else:
            self.log.error("Invalid monitor topic: %s" % switch)

    def dispatch_query(self, msg):
        """Route registration requests and queries from clients."""
        idents, msg = self.session.feed_identities(msg)
        if not idents:
            self.log.error("Bad Query Message: %s" % msg)
            return
        client_id = idents[0]
        try:
            msg = self.session.unpack_message(msg, content=True)
        except:
            content = error.wrap_exception()
            self.log.error("Bad Query Message: %s" % msg, exc_info=True)
            self.session.send(self.query,
                              "hub_error",
                              ident=client_id,
                              content=content)
            return

        # switch on message type:
        msg_type = msg['msg_type']
        self.log.info("client::client %s requested %s" % (client_id, msg_type))
        handler = self.query_handlers.get(msg_type, None)
        try:
            assert handler is not None, "Bad Message Type: %s" % msg_type
        except:
            content = error.wrap_exception()
            self.log.error("Bad Message Type: %s" % msg_type, exc_info=True)
            self.session.send(self.query,
                              "hub_error",
                              ident=client_id,
                              content=content)
            return
        else:
            handler(idents, msg)

    def dispatch_db(self, msg):
        """"""
        raise NotImplementedError

    #---------------------------------------------------------------------------
    # handler methods (1 per event)
    #---------------------------------------------------------------------------

    #----------------------- Heartbeat --------------------------------------

    def handle_new_heart(self, heart):
        """handler to attach to heartbeater.
        Called when a new heart starts to beat.
        Triggers completion of registration."""
        self.log.debug("heartbeat::handle_new_heart(%r)" % heart)
        if heart not in self.incoming_registrations:
            self.log.info("heartbeat::ignoring new heart: %r" % heart)
        else:
            self.finish_registration(heart)

    def handle_heart_failure(self, heart):
        """handler to attach to heartbeater.
        called when a previously registered heart fails to respond to beat request.
        triggers unregistration"""
        self.log.debug("heartbeat::handle_heart_failure(%r)" % heart)
        eid = self.hearts.get(heart, None)
        if eid is None:
            self.log.info("heartbeat::ignoring heart failure %r" % heart)
        else:
            queue = self.engines[eid].queue
            self.unregister_engine(heart,
                                   dict(content=dict(id=eid, queue=queue)))

    #----------------------- MUX Queue Traffic ------------------------------

    def save_queue_request(self, idents, msg):
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %s" % idents)
            return
        queue_id, client_id = idents[:2]
        try:
            msg = self.session.unpack_message(msg, content=False)
        except:
            self.log.error("queue::client %r sent invalid message to %r: %s" %
                           (client_id, queue_id, msg),
                           exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::target %r not registered" % queue_id)
            self.log.debug("queue::    valid are: %s" % (self.by_ident.keys()))
            return

        header = msg['header']
        msg_id = header['msg_id']
        record = init_record(msg)
        record['engine_uuid'] = queue_id
        record['client_uuid'] = client_id
        record['queue'] = 'mux'

        try:
            # it's possible iopub arrived first:
            existing = self.db.get_record(msg_id)
            for key, evalue in existing.iteritems():
                rvalue = record[key]
                if evalue and rvalue and evalue != rvalue:
                    self.log.error(
                        "conflicting initial state for record: %s:%s <> %s" %
                        (msg_id, rvalue, evalue))
                elif evalue and not rvalue:
                    record[key] = evalue
            self.db.update_record(msg_id, record)
        except KeyError:
            self.db.add_record(msg_id, record)

        self.pending.add(msg_id)
        self.queues[eid].append(msg_id)

    def save_queue_result(self, idents, msg):
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %s" % idents)
            return

        client_id, queue_id = idents[:2]
        try:
            msg = self.session.unpack_message(msg, content=False)
        except:
            self.log.error("queue::engine %r sent invalid message to %r: %s" %
                           (queue_id, client_id, msg),
                           exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::unknown engine %r is sending a reply: " %
                           queue_id)
            # self.log.debug("queue::       %s"%msg[2:])
            return

        parent = msg['parent_header']
        if not parent:
            return
        msg_id = parent['msg_id']
        if msg_id in self.pending:
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            self.queues[eid].remove(msg_id)
            self.completed[eid].append(msg_id)
        elif msg_id not in self.all_completed:
            # it could be a result from a dead engine that died before delivering the
            # result
            self.log.warn("queue:: unknown msg finished %s" % msg_id)
            return
        # update record anyway, because the unregistration could have been premature
        rheader = msg['header']
        completed = datetime.strptime(rheader['date'], ISO8601)
        started = rheader.get('started', None)
        if started is not None:
            started = datetime.strptime(started, ISO8601)
        result = {
            'result_header': rheader,
            'result_content': msg['content'],
            'started': started,
            'completed': completed
        }

        result['result_buffers'] = msg['buffers']
        self.db.update_record(msg_id, result)

    #--------------------- Task Queue Traffic ------------------------------

    def save_task_request(self, idents, msg):
        """Save the submission of a task."""
        client_id = idents[0]

        try:
            msg = self.session.unpack_message(msg, content=False)
        except:
            self.log.error("task::client %r sent invalid task message: %s" %
                           (client_id, msg),
                           exc_info=True)
            return
        record = init_record(msg)

        record['client_uuid'] = client_id
        record['queue'] = 'task'
        header = msg['header']
        msg_id = header['msg_id']
        self.pending.add(msg_id)
        self.unassigned.add(msg_id)
        try:
            # it's possible iopub arrived first:
            existing = self.db.get_record(msg_id)
            for key, evalue in existing.iteritems():
                rvalue = record[key]
                if evalue and rvalue and evalue != rvalue:
                    self.log.error(
                        "conflicting initial state for record: %s:%s <> %s" %
                        (msg_id, rvalue, evalue))
                elif evalue and not rvalue:
                    record[key] = evalue
            self.db.update_record(msg_id, record)
        except KeyError:
            self.db.add_record(msg_id, record)

    def save_task_result(self, idents, msg):
        """save the result of a completed task."""
        client_id = idents[0]
        try:
            msg = self.session.unpack_message(msg, content=False)
        except:
            self.log.error("task::invalid task result message send to %r: %s" %
                           (client_id, msg),
                           exc_info=True)
            raise
            return

        parent = msg['parent_header']
        if not parent:
            # print msg
            self.log.warn("Task %r had no parent!" % msg)
            return
        msg_id = parent['msg_id']
        if msg_id in self.unassigned:
            self.unassigned.remove(msg_id)

        header = msg['header']
        engine_uuid = header.get('engine', None)
        eid = self.by_ident.get(engine_uuid, None)

        if msg_id in self.pending:
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            if eid is not None:
                self.completed[eid].append(msg_id)
                if msg_id in self.tasks[eid]:
                    self.tasks[eid].remove(msg_id)
            completed = datetime.strptime(header['date'], ISO8601)
            started = header.get('started', None)
            if started is not None:
                started = datetime.strptime(started, ISO8601)
            result = {
                'result_header': header,
                'result_content': msg['content'],
                'started': started,
                'completed': completed,
                'engine_uuid': engine_uuid
            }

            result['result_buffers'] = msg['buffers']
            self.db.update_record(msg_id, result)

        else:
            self.log.debug("task::unknown task %s finished" % msg_id)

    def save_task_destination(self, idents, msg):
        try:
            msg = self.session.unpack_message(msg, content=True)
        except:
            self.log.error("task::invalid task tracking message",
                           exc_info=True)
            return
        content = msg['content']
        # print (content)
        msg_id = content['msg_id']
        engine_uuid = content['engine_id']
        eid = self.by_ident[engine_uuid]

        self.log.info("task::task %s arrived on %s" % (msg_id, eid))
        if msg_id in self.unassigned:
            self.unassigned.remove(msg_id)
        # else:
        #     self.log.debug("task::task %s not listed as MIA?!"%(msg_id))

        self.tasks[eid].append(msg_id)
        # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
        self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))

    def mia_task_request(self, idents, msg):
        raise NotImplementedError
        # client_id = idents[0]
        # content = dict(mia=self.mia,status='ok')
        # self.session.send('mia_reply', content=content, idents=client_id)

    #--------------------- IOPub Traffic ------------------------------

    def save_iopub_message(self, topics, msg):
        """save an iopub message into the db"""
        # print (topics)
        try:
            msg = self.session.unpack_message(msg, content=True)
        except:
            self.log.error("iopub::invalid IOPub message", exc_info=True)
            return

        parent = msg['parent_header']
        if not parent:
            self.log.error("iopub::invalid IOPub message: %s" % msg)
            return
        msg_id = parent['msg_id']
        msg_type = msg['msg_type']
        content = msg['content']

        # ensure msg_id is in db
        try:
            rec = self.db.get_record(msg_id)
        except KeyError:
            rec = empty_record()
            rec['msg_id'] = msg_id
            self.db.add_record(msg_id, rec)
        # stream
        d = {}
        if msg_type == 'stream':
            name = content['name']
            s = rec[name] or ''
            d[name] = s + content['data']

        elif msg_type == 'pyerr':
            d['pyerr'] = content
        elif msg_type == 'pyin':
            d['pyin'] = content['code']
        else:
            d[msg_type] = content.get('data', '')

        self.db.update_record(msg_id, d)

    #-------------------------------------------------------------------------
    # Registration requests
    #-------------------------------------------------------------------------

    def connection_request(self, client_id, msg):
        """Reply with connection addresses for clients."""
        self.log.info("client::client %s connected" % client_id)
        content = dict(status='ok')
        content.update(self.client_info)
        jsonable = {}
        for k, v in self.keytable.iteritems():
            if v not in self.dead_engines:
                jsonable[str(k)] = v
        content['engines'] = jsonable
        self.session.send(self.query,
                          'connection_reply',
                          content,
                          parent=msg,
                          ident=client_id)

    def register_engine(self, reg, msg):
        """Register a new engine."""
        content = msg['content']
        try:
            queue = content['queue']
        except KeyError:
            self.log.error("registration::queue not specified", exc_info=True)
            return
        heart = content.get('heartbeat', None)
        """register a new engine, and create the socket(s) necessary"""
        eid = self._next_id
        # print (eid, queue, reg, heart)

        self.log.debug("registration::register_engine(%i, %r, %r, %r)" %
                       (eid, queue, reg, heart))

        content = dict(id=eid, status='ok')
        content.update(self.engine_info)
        # check if requesting available IDs:
        if queue in self.by_ident:
            try:
                raise KeyError("queue_id %r in use" % queue)
            except:
                content = error.wrap_exception()
                self.log.error("queue_id %r in use" % queue, exc_info=True)
        elif heart in self.hearts:  # need to check unique hearts?
            try:
                raise KeyError("heart_id %r in use" % heart)
            except:
                self.log.error("heart_id %r in use" % heart, exc_info=True)
                content = error.wrap_exception()
        else:
            for h, pack in self.incoming_registrations.iteritems():
                if heart == h:
                    try:
                        raise KeyError("heart_id %r in use" % heart)
                    except:
                        self.log.error("heart_id %r in use" % heart,
                                       exc_info=True)
                        content = error.wrap_exception()
                    break
                elif queue == pack[1]:
                    try:
                        raise KeyError("queue_id %r in use" % queue)
                    except:
                        self.log.error("queue_id %r in use" % queue,
                                       exc_info=True)
                        content = error.wrap_exception()
                    break

        msg = self.session.send(self.query,
                                "registration_reply",
                                content=content,
                                ident=reg)

        if content['status'] == 'ok':
            if heart in self.heartmonitor.hearts:
                # already beating
                self.incoming_registrations[heart] = (eid, queue, reg[0], None)
                self.finish_registration(heart)
            else:
                purge = lambda: self._purge_stalled_registration(heart)
                dc = ioloop.DelayedCallback(purge, self.registration_timeout,
                                            self.loop)
                dc.start()
                self.incoming_registrations[heart] = (eid, queue, reg[0], dc)
        else:
            self.log.error("registration::registration %i failed: %s" %
                           (eid, content['evalue']))
        return eid

    def unregister_engine(self, ident, msg):
        """Unregister an engine that explicitly requested to leave."""
        try:
            eid = msg['content']['id']
        except:
            self.log.error(
                "registration::bad engine id for unregistration: %s" % ident,
                exc_info=True)
            return
        self.log.info("registration::unregister_engine(%s)" % eid)
        # print (eid)
        uuid = self.keytable[eid]
        content = dict(id=eid, queue=uuid)
        self.dead_engines.add(uuid)
        # self.ids.remove(eid)
        # uuid = self.keytable.pop(eid)
        #
        # ec = self.engines.pop(eid)
        # self.hearts.pop(ec.heartbeat)
        # self.by_ident.pop(ec.queue)
        # self.completed.pop(eid)
        handleit = lambda: self._handle_stranded_msgs(eid, uuid)
        dc = ioloop.DelayedCallback(handleit, self.registration_timeout,
                                    self.loop)
        dc.start()
        ############## TODO: HANDLE IT ################

        if self.notifier:
            self.session.send(self.notifier,
                              "unregistration_notification",
                              content=content)

    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.
        
        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        that the result failed and later receive the actual result.
        """

        outstanding = self.queues[eid]

        for msg_id in outstanding:
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            try:
                raise error.EngineError(
                    "Engine %r died while running task %r" % (eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake header:
            header = {}
            header['engine'] = uuid
            header['date'] = datetime.now().strftime(ISO8601)
            rec = dict(result_content=content,
                       result_header=header,
                       result_buffers=[])
            rec['completed'] = header['date']
            rec['engine_uuid'] = uuid
            self.db.update_record(msg_id, rec)

    def finish_registration(self, heart):
        """Second half of engine registration, called after our HeartMonitor
        has received a beat from the Engine's Heart."""
        try:
            (eid, queue, reg, purge) = self.incoming_registrations.pop(heart)
        except KeyError:
            self.log.error(
                "registration::tried to finish nonexistant registration",
                exc_info=True)
            return
        self.log.info("registration::finished registering engine %i:%r" %
                      (eid, queue))
        if purge is not None:
            purge.stop()
        control = queue
        self.ids.add(eid)
        self.keytable[eid] = queue
        self.engines[eid] = EngineConnector(id=eid,
                                            queue=queue,
                                            registration=reg,
                                            control=control,
                                            heartbeat=heart)
        self.by_ident[queue] = eid
        self.queues[eid] = list()
        self.tasks[eid] = list()
        self.completed[eid] = list()
        self.hearts[heart] = eid
        content = dict(id=eid, queue=self.engines[eid].queue)
        if self.notifier:
            self.session.send(self.notifier,
                              "registration_notification",
                              content=content)
        self.log.info("engine::Engine Connected: %i" % eid)

    def _purge_stalled_registration(self, heart):
        if heart in self.incoming_registrations:
            eid = self.incoming_registrations.pop(heart)[0]
            self.log.info("registration::purging stalled registration: %i" %
                          eid)

    #-------------------------------------------------------------------------
    # Client Requests
    #-------------------------------------------------------------------------

    def shutdown_request(self, client_id, msg):
        """handle shutdown request."""
        self.session.send(self.query,
                          'shutdown_reply',
                          content={'status': 'ok'},
                          ident=client_id)
        # also notify other clients of shutdown
        self.session.send(self.notifier,
                          'shutdown_notice',
                          content={'status': 'ok'})
        dc = ioloop.DelayedCallback(lambda: self._shutdown(), 1000, self.loop)
        dc.start()

    def _shutdown(self):
        self.log.info("hub::hub shutting down.")
        time.sleep(0.1)
        sys.exit(0)

    def check_load(self, client_id, msg):
        content = msg['content']
        try:
            targets = content['targets']
            targets = self._validate_targets(targets)
        except:
            content = error.wrap_exception()
            self.session.send(self.query,
                              "hub_error",
                              content=content,
                              ident=client_id)
            return

        content = dict(status='ok')
        # loads = {}
        for t in targets:
            content[bytes(t)] = len(self.queues[t]) + len(self.tasks[t])
        self.session.send(self.query,
                          "load_reply",
                          content=content,
                          ident=client_id)

    def queue_status(self, client_id, msg):
        """Return the Queue status of one or more targets.
        if verbose: return the msg_ids
        else: return len of each type.
        keys: queue (pending MUX jobs)
            tasks (pending Task jobs)
            completed (finished jobs from both queues)"""
        content = msg['content']
        targets = content['targets']
        try:
            targets = self._validate_targets(targets)
        except:
            content = error.wrap_exception()
            self.session.send(self.query,
                              "hub_error",
                              content=content,
                              ident=client_id)
            return
        verbose = content.get('verbose', False)
        content = dict(status='ok')
        for t in targets:
            queue = self.queues[t]
            completed = self.completed[t]
            tasks = self.tasks[t]
            if not verbose:
                queue = len(queue)
                completed = len(completed)
                tasks = len(tasks)
            content[bytes(t)] = {
                'queue': queue,
                'completed': completed,
                'tasks': tasks
            }
        content['unassigned'] = list(self.unassigned) if verbose else len(
            self.unassigned)

        self.session.send(self.query,
                          "queue_reply",
                          content=content,
                          ident=client_id)

    def purge_results(self, client_id, msg):
        """Purge results from memory. This method is more valuable before we move
        to a DB based message storage mechanism."""
        content = msg['content']
        msg_ids = content.get('msg_ids', [])
        reply = dict(status='ok')
        if msg_ids == 'all':
            self.db.drop_matching_records(dict(completed={'$ne': None}))
        else:
            for msg_id in msg_ids:
                if msg_id in self.all_completed:
                    self.db.drop_record(msg_id)
                else:
                    if msg_id in self.pending:
                        try:
                            raise IndexError("msg pending: %r" % msg_id)
                        except:
                            reply = error.wrap_exception()
                    else:
                        try:
                            raise IndexError("No such msg: %r" % msg_id)
                        except:
                            reply = error.wrap_exception()
                    break
            eids = content.get('engine_ids', [])
            for eid in eids:
                if eid not in self.engines:
                    try:
                        raise IndexError("No such engine: %i" % eid)
                    except:
                        reply = error.wrap_exception()
                    break
                msg_ids = self.completed.pop(eid)
                uid = self.engines[eid].queue
                self.db.drop_matching_records(
                    dict(engine_uuid=uid, completed={'$ne': None}))

        self.session.send(self.query,
                          'purge_reply',
                          content=reply,
                          ident=client_id)

    def resubmit_task(self, client_id, msg, buffers):
        """Resubmit a task."""
        raise NotImplementedError

    def get_results(self, client_id, msg):
        """Get the result of 1 or more messages."""
        content = msg['content']
        msg_ids = sorted(set(content['msg_ids']))
        statusonly = content.get('status_only', False)
        pending = []
        completed = []
        content = dict(status='ok')
        content['pending'] = pending
        content['completed'] = completed
        buffers = []
        if not statusonly:
            content['results'] = {}
            records = self.db.find_records(dict(msg_id={'$in': msg_ids}))
        for msg_id in msg_ids:
            if msg_id in self.pending:
                pending.append(msg_id)
            elif msg_id in self.all_completed:
                completed.append(msg_id)
                if not statusonly:
                    rec = records[msg_id]
                    io_dict = {}
                    for key in 'pyin pyout pyerr stdout stderr'.split():
                        io_dict[key] = rec[key]
                    content[msg_id] = {
                        'result_content': rec['result_content'],
                        'header': rec['header'],
                        'result_header': rec['result_header'],
                        'io': io_dict,
                    }
                    if rec['result_buffers']:
                        buffers.extend(map(str, rec['result_buffers']))
            else:
                try:
                    raise KeyError('No such message: ' + msg_id)
                except:
                    content = error.wrap_exception()
                break
        self.session.send(self.query,
                          "result_reply",
                          content=content,
                          parent=msg,
                          ident=client_id,
                          buffers=buffers)
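
The `_validate_targets` helper above accepts `None` (meaning all engines), a single integer id or raw identity, or a mixed list, and normalizes everything to a list of registered integer ids. A standalone sketch of the same rules, using hypothetical `ids`/`by_ident` data instead of a live Hub:

# hypothetical registry state, mirroring Hub.ids and Hub.by_ident
ids = {0, 1, 2}
by_ident = {'engine-a': 0, 'engine-b': 1}

def validate_targets(targets):
    if targets is None:
        targets = list(ids)  # default to all registered engines
    if isinstance(targets, (int, str)):
        targets = [targets]  # a single target becomes a one-element list
    # map raw identities (queue names) to integer ids
    targets = [by_ident.get(t, t) if isinstance(t, str) else t for t in targets]
    bad = [t for t in targets if t not in ids]
    if bad:
        raise IndexError("No Such Engine: %r" % bad)
    if not targets:
        raise IndexError("No Engines Registered")
    return targets

assert validate_targets('engine-a') == [0]
assert sorted(validate_targets(None)) == [0, 1, 2]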
Example 2
class IntTrait(HasTraits):

    value = Int(99)
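
A minimal usage sketch of the declared default (assuming `Int`, `HasTraits`, and `TraitError` come from `traitlets`, or `IPython.utils.traitlets` on older IPython):

t = IntTrait()
assert t.value == 99  # the declared default applies until assignment

t.value = 10  # ints pass validation
try:
    t.value = 'ten'  # anything else raises a TraitError
except TraitError:
    pass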
Example 3
class A(HasTraits):
    value = Int()
Example 4
class C(HasTraits):
    c = Int(30)
Example 5
class A(HasTraits):
    i = Int(config_key='VALUE1', other_thing='VALUE2')
    f = Float(config_key='VALUE3', other_thing='VALUE2')
    j = Int(0)
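
Extra keyword arguments such as `config_key` and `other_thing` are stored as trait metadata rather than affecting the default value; a sketch, assuming the `trait_metadata` method of `HasTraits`:

a = A()
assert a.trait_metadata('i', 'config_key') == 'VALUE1'
assert a.trait_metadata('f', 'other_thing') == 'VALUE2'
assert a.i == 0 and a.j == 0  # Int defaults to 0 unless a default is given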
Example 6
class VispyWidget(DOMWidget):
    _view_name = Unicode("VispyView", sync=True)
    _view_module = Unicode('/nbextensions/vispy/webgl-backend.js', sync=True)

    #height/width of the widget is managed by IPython.
    #it's a string and can be anything valid in CSS.
    #here we only manage the size of the viewport.
    width = Int(sync=True)
    height = Int(sync=True)
    resizable = Bool(value=True, sync=True)

    def __init__(self, **kwargs):
        super(VispyWidget, self).__init__(**kwargs)
        self.on_msg(self.events_received)
        self.canvas = None
        self.canvas_backend = None
        self.gen_event = None

    def set_canvas(self, canvas):
        self.width, self.height = canvas._backend._default_size
        self.canvas = canvas
        self.canvas_backend = self.canvas._backend
        self.canvas_backend.set_widget(self)
        self.gen_event = self.canvas_backend._gen_event
        # then set up the backend widget.

    # In IPython < 4, these callbacks are given two arguments; in
    # IPython/jupyter 4, they take 3. events_received is variadic to
    # accommodate both cases.
    def events_received(self, _, msg, *args):
        if msg['msg_type'] == 'init':
            self.canvas_backend._reinit_widget()
        elif msg['msg_type'] == 'events':
            events = msg['contents']
            for ev in events:
                self.gen_event(ev)
        elif msg['msg_type'] == 'status':
            if msg['contents'] == 'removed':
                # Stop all timers associated to the widget.
                _stop_timers(self.canvas_backend._vispy_canvas)

    def send_glir_commands(self, commands):
        # TODO: check whether binary websocket is available (ipython >= 3)
        # Until IPython 3.0 is released, use base64.
        array_serialization = 'base64'
        # array_serialization = 'binary'
        if array_serialization == 'base64':
            msg = create_glir_message(commands, 'base64')
            msg['array_serialization'] = 'base64'
            self.send(msg)
        elif array_serialization == 'binary':
            msg = create_glir_message(commands, 'binary')
            msg['array_serialization'] = 'binary'
            # Remove the buffers from the JSON message: they will be sent
            # independently via binary WebSocket.
            buffers = msg.pop('buffers')
            self.comm.send({
                "method": "custom",
                "content": msg
            },
                           buffers=buffers)
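
A sketch of how the widget is typically wired up in a notebook; `canvas` is a hypothetical vispy canvas created with the notebook/webgl backend, and `display` comes from `IPython.display`:

from IPython.display import display

w = VispyWidget()
w.set_canvas(canvas)  # copies the canvas size into width/height and links the backend
display(w)            # shows the widget; GLIR commands then flow via send_glir_commands()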
Example 7
class B(A):
    x = Int(20)
Example 8
class SqlMagic(Magics, Configurable):
    """Runs SQL statement on a database, specified by SQLAlchemy connect string.

    Provides the %%sql magic."""

    autolimit = Int(
        0,
        config=True,
        help="Automatically limit the size of the returned result sets")
    style = Unicode(
        'DEFAULT',
        config=True,
        help=
        "Set the table printing style to any of prettytable's defined styles (currently DEFAULT, MSWORD_FRIENDLY, PLAIN_COLUMNS, RANDOM)"
    )
    short_errors = Bool(
        True,
        config=True,
        help="Don't display the full traceback on SQL Programming Error")
    displaylimit = Int(
        0,
        config=True,
        help=
        "Automatic,ally limit the number of rows displayed (full result set is still stored)"
    )
    autopandas = Bool(
        False,
        config=True,
        help="Return Pandas DataFrames instead of regular result sets")
    feedback = Bool(True,
                    config=True,
                    help="Print number of rows affected by DML")

    def __init__(self, shell):
        Configurable.__init__(self, config=shell.config)
        Magics.__init__(self, shell=shell)

        # Add ourselves to the list of objects configurable via %config
        self.shell.configurables.append(self)

    @needs_local_scope
    @line_magic('sql')
    @cell_magic('sql')
    def execute(self, line, cell='', local_ns={}):
        """Runs SQL statement against a database, specified by SQLAlchemy connect string.

        If no database connection has been established, first word
        should be a SQLAlchemy connection string, or the user@db name
        of an established connection.

        Examples::

          %%sql postgresql://me:mypw@localhost/mydb
          SELECT * FROM mytable

          %%sql me@mydb
          DELETE FROM mytable

          %%sql
          DROP TABLE mytable

        SQLAlchemy connect string syntax examples:

          postgresql://me:mypw@localhost/mydb
          sqlite://
          mysql+pymysql://me:mypw@localhost/mydb

        """
        # save globals and locals so they can be referenced in bind vars
        user_ns = self.shell.user_ns
        user_ns.update(local_ns)

        parsed = sql.parse.parse('%s\n%s' % (line, cell))
        conn = sql.connection.Connection.get(parsed['connection'])
        try:
            result = sql.run.run(conn, parsed['sql'], self, user_ns)
            return result
        except (ProgrammingError, OperationalError) as e:
            # Sqlite apparently returns all errors as OperationalError :/
            if self.short_errors:
                print(e)
            else:
                raise
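
For the `%%sql` magic to become available, the class has to be registered with the running shell; ipython-sql does this from its extension hook, roughly:

def load_ipython_extension(ip):
    """Called by %load_ext sql; instantiates and registers the magic."""
    ip.register_magics(SqlMagic)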
Example 9
class HistoryManager(Configurable):
    """A class to organize all history-related functionality in one place.
    """
    # Public interface

    # An instance of the IPython shell we are attached to
    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
    # Lists to hold processed and raw history. These start with a blank entry
    # so that we can index them starting from 1
    input_hist_parsed = List([""])
    input_hist_raw = List([""])
    # A list of directories visited during session
    dir_hist = List()

    def _dir_hist_default(self):
        try:
            return [os.getcwd()]
        except OSError:
            return []

    # A dict of output history, keyed with ints from the shell's
    # execution count. If there are several outputs from one command,
    # only the last one is stored.
    output_hist = Dict()
    # Contains all outputs, in lists of reprs.
    output_hist_reprs = Instance(defaultdict, args=(list, ))

    # String holding the path to the history file
    hist_file = Unicode(config=True)

    # The SQLite database
    db = Instance(sqlite3.Connection)
    # The number of the current session in the history database
    session_number = Int()
    # Should we log output to the database? (default no)
    db_log_output = Bool(False, config=True)
    # Write to database every x commands (higher values save disk access & power)
    #  Values of 1 or less effectively disable caching.
    db_cache_size = Int(0, config=True)
    # The input and output caches
    db_input_cache = List()
    db_output_cache = List()

    # Private interface
    # Variables used to store the three last inputs from the user.  On each new
    # history update, we populate the user's namespace with these, shifted as
    # necessary.
    _i00 = Unicode(u'')
    _i = Unicode(u'')
    _ii = Unicode(u'')
    _iii = Unicode(u'')

    # A set with all forms of the exit command, so that we don't store them in
    # the history (it's annoying to rewind the first entry and land on an exit
    # call).
    _exit_commands = Instance(set,
                              args=([
                                  'Quit', 'quit', 'Exit', 'exit', '%Quit',
                                  '%quit', '%Exit', '%exit'
                              ], ))

    def __init__(self, shell, config=None, **traits):
        """Create a new history manager associated with a shell instance.
        """
        # We need a pointer back to the shell for various tasks.
        super(HistoryManager, self).__init__(shell=shell,
                                             config=config,
                                             **traits)

        if self.hist_file == u'':
            # No one has set the hist_file, yet.
            if shell.profile:
                histfname = 'history-%s' % shell.profile
            else:
                histfname = 'history'
            self.hist_file = os.path.join(shell.ipython_dir,
                                          histfname + '.sqlite')

        try:
            self.init_db()
        except sqlite3.DatabaseError:
            if os.path.isfile(self.hist_file):
                # Try to move the file out of the way.
                newpath = os.path.join(self.shell.ipython_dir,
                                       "hist-corrupt.sqlite")
                os.rename(self.hist_file, newpath)
                print("ERROR! History file wasn't a valid SQLite database.",
                      "It was moved to %s" % newpath,
                      "and a new file created.")
                self.init_db()
            else:
                # The hist_file is probably :memory: or something else.
                raise

        self.new_session()

    def init_db(self):
        """Connect to the database, and create tables if necessary."""
        self.db = sqlite3.connect(self.hist_file)
        self.db.execute("""CREATE TABLE IF NOT EXISTS sessions (session integer
                        primary key autoincrement, start timestamp,
                        end timestamp, num_cmds integer, remark text)""")
        self.db.execute("""CREATE TABLE IF NOT EXISTS history 
                (session integer, line integer, source text, source_raw text,
                PRIMARY KEY (session, line))""")
        # Output history is optional, but ensure the table's there so it can be
        # enabled later.
        self.db.execute("""CREATE TABLE IF NOT EXISTS output_history
                        (session integer, line integer, output text,
                        PRIMARY KEY (session, line))""")
        self.db.commit()

    def new_session(self):
        """Get a new session number."""
        with self.db:
            cur = self.db.execute(
                """INSERT INTO sessions VALUES (NULL, ?, NULL,
                            NULL, "") """, (datetime.datetime.now(), ))
            self.session_number = cur.lastrowid

    def end_session(self):
        """Close the database session, filling in the end time and line count."""
        self.writeout_cache()
        with self.db:
            self.db.execute(
                """UPDATE sessions SET end=?, num_cmds=? WHERE
                            session==?""",
                (datetime.datetime.now(), len(self.input_hist_parsed) - 1,
                 self.session_number))
        self.session_number = 0

    def name_session(self, name):
        """Give the current session a name in the history database."""
        with self.db:
            self.db.execute("UPDATE sessions SET remark=? WHERE session==?",
                            (name, self.session_number))

    def reset(self, new_session=True):
        """Clear the session history, releasing all object references, and
        optionally open a new session."""
        if self.session_number:
            self.end_session()
        self.input_hist_parsed[:] = [""]
        self.input_hist_raw[:] = [""]
        self.output_hist.clear()
        # The directory history can't be completely empty
        self.dir_hist[:] = [os.getcwd()]

        if new_session:
            self.new_session()

    ## -------------------------------
    ## Methods for retrieving history:
    ## -------------------------------
    def _run_sql(self, sql, params, raw=True, output=False):
        """Prepares and runs an SQL query for the history database.
        
        Parameters
        ----------
        sql : str
          Any filtering expressions to go after SELECT ... FROM ...
        params : tuple
          Parameters passed to the SQL query (to replace "?")
        raw, output : bool
          See :meth:`get_range`
        
        Returns
        -------
        Tuples as :meth:`get_range`
        """
        toget = 'source_raw' if raw else 'source'
        sqlfrom = "history"
        if output:
            sqlfrom = "history LEFT JOIN output_history USING (session, line)"
            toget = "history.%s, output_history.output" % toget
        cur = self.db.execute("SELECT session, line, %s FROM %s " %\
                                (toget, sqlfrom) + sql, params)
        if output:  # Regroup into 3-tuples, and parse JSON
            loads = lambda out: json.loads(out) if out else None
            return ((ses, lin, (inp, loads(out))) \
                                        for ses, lin, inp, out in cur)
        return cur

    def get_tail(self, n=10, raw=True, output=False, include_latest=False):
        """Get the last n lines from the history database.
        
        Parameters
        ----------
        n : int
          The number of lines to get
        raw, output : bool
          See :meth:`get_range`
        include_latest : bool
          If False (default), n+1 lines are fetched, and the latest one
          is discarded. This is intended for use by user commands, so that
          the command which triggered the call is not itself returned.
        
        Returns
        -------
        Tuples as :meth:`get_range`
        """
        self.writeout_cache()
        if not include_latest:
            n += 1
        cur = self._run_sql("ORDER BY session DESC, line DESC LIMIT ?", (n, ),
                            raw=raw,
                            output=output)
        if not include_latest:
            return reversed(list(cur)[1:])
        return reversed(list(cur))

    def search(self, pattern="*", raw=True, search_raw=True, output=False):
        """Search the database using unix glob-style matching (wildcards
        * and ?).
        
        Parameters
        ----------
        pattern : str
          The wildcarded pattern to match when searching
        search_raw : bool
          If True, search the raw input, otherwise, the parsed input
        raw, output : bool
          See :meth:`get_range`
        
        Returns
        -------
        Tuples as :meth:`get_range`
        """
        tosearch = "source_raw" if search_raw else "source"
        if output:
            tosearch = "history." + tosearch
        self.writeout_cache()
        return self._run_sql("WHERE %s GLOB ?" % tosearch, (pattern, ),
                             raw=raw,
                             output=output)

    def _get_range_session(self, start=1, stop=None, raw=True, output=False):
        """Get input and output history from the current session. Called by
        get_range, and takes similar parameters."""
        input_hist = self.input_hist_raw if raw else self.input_hist_parsed

        n = len(input_hist)
        if start < 0:
            start += n
        if not stop:
            stop = n
        elif stop < 0:
            stop += n

        for i in range(start, stop):
            if output:
                line = (input_hist[i], self.output_hist_reprs.get(i))
            else:
                line = input_hist[i]
            yield (0, i, line)

    def get_range(self, session=0, start=1, stop=None, raw=True, output=False):
        """Retrieve input by session.
        
        Parameters
        ----------
        session : int
            Session number to retrieve. The current session is 0, and negative
            numbers count back from current session, so -1 is previous session.
        start : int
            First line to retrieve.
        stop : int
            End of line range (excluded from output itself). If None, retrieve
            to the end of the session.
        raw : bool
            If True, return untranslated input
        output : bool
            If True, attempt to include output. This will be 'real' Python
            objects for the current session, or text reprs from previous
            sessions if db_log_output was enabled at the time. Where no output
            is found, None is used.
            
        Returns
        -------
        An iterator over the desired lines. Each line is a 3-tuple, either
        (session, line, input) if output is False, or
        (session, line, (input, output)) if output is True.
        """
        if session == 0 or session == self.session_number:  # Current session
            return self._get_range_session(start, stop, raw, output)
        if session < 0:
            session += self.session_number

        if stop:
            lineclause = "line >= ? AND line < ?"
            params = (session, start, stop)
        else:
            lineclause = "line>=?"
            params = (session, start)

        return self._run_sql("WHERE session==? AND %s"
                             "" % lineclause,
                             params,
                             raw=raw,
                             output=output)

    def get_range_by_str(self, rangestr, raw=True, output=False):
        """Get lines of history from a string of ranges, as used by magic
        commands %hist, %save, %macro, etc.
        
        Parameters
        ----------
        rangestr : str
          A string specifying ranges, e.g. "5 ~2/1-4". See
          :func:`magic_history` for full details.
        raw, output : bool
          As :meth:`get_range`
          
        Returns
        -------
        Tuples as :meth:`get_range`
        """
        for sess, s, e in extract_hist_ranges(rangestr):
            for line in self.get_range(sess, s, e, raw=raw, output=output):
                yield line

    ## ----------------------------
    ## Methods for storing history:
    ## ----------------------------
    def store_inputs(self, line_num, source, source_raw=None):
        """Store source and raw input in history and create input cache
        variables _i*.
        
        Parameters
        ----------
        line_num : int
          The prompt number of this input.
        
        source : str
          Python input.

        source_raw : str, optional
          If given, this is the raw input without any IPython transformations
          applied to it.  If not given, ``source`` is used.
        """
        if source_raw is None:
            source_raw = source
        source = source.rstrip('\n')
        source_raw = source_raw.rstrip('\n')

        # do not store exit/quit commands
        if source_raw.strip() in self._exit_commands:
            return

        self.input_hist_parsed.append(source)
        self.input_hist_raw.append(source_raw)

        self.db_input_cache.append((line_num, source, source_raw))
        # Trigger to flush cache and write to DB.
        if len(self.db_input_cache) >= self.db_cache_size:
            self.writeout_cache()

        # update the auto _i variables
        self._iii = self._ii
        self._ii = self._i
        self._i = self._i00
        self._i00 = source_raw

        # hackish access to user namespace to create _i1,_i2... dynamically
        new_i = '_i%s' % line_num
        to_main = {
            '_i': self._i,
            '_ii': self._ii,
            '_iii': self._iii,
            new_i: self._i00
        }
        self.shell.user_ns.update(to_main)

    def store_output(self, line_num):
        """If database output logging is enabled, this saves all the
        outputs from the indicated prompt number to the database. It's
        called by run_cell after code has been executed.
        
        Parameters
        ----------
        line_num : int
          The line number from which to save outputs
        """
        if (not self.db_log_output) or not self.output_hist_reprs[line_num]:
            return
        output = json.dumps(self.output_hist_reprs[line_num])

        self.db_output_cache.append((line_num, output))
        if self.db_cache_size <= 1:
            self.writeout_cache()

    def _writeout_input_cache(self):
        for line in self.db_input_cache:
            with self.db:
                self.db.execute("INSERT INTO history VALUES (?, ?, ?, ?)",
                                (self.session_number, ) + line)

    def _writeout_output_cache(self):
        for line in self.db_output_cache:
            with self.db:
                self.db.execute("INSERT INTO output_history VALUES (?, ?, ?)",
                                (self.session_number, ) + line)

    def writeout_cache(self):
        """Write any entries in the cache to the database."""
        try:
            self._writeout_input_cache()
        except sqlite3.IntegrityError:
            self.new_session()
            print("ERROR! Session/line number was not unique in",
                  "database. History logging moved to new session",
                  self.session_number)
            try:  # Try writing to the new session. If this fails, don't recurse
                self.writeout_cache()
            except sqlite3.IntegrityError:
                pass
        finally:
            self.db_input_cache = []

        try:
            self._writeout_output_cache()
        except sqlite3.IntegrityError:
            print("!! Session/line number for output was not unique",
                  "in database. Output will not be stored.")
        finally:
            self.db_output_cache = []
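
A short usage sketch against a live shell (assumes an interactive IPython session, where `get_ipython()` is available):

hm = get_ipython().history_manager
hm.store_inputs(1, "print('hi')")  # record line 1 of the current session
for session, line, source in hm.get_tail(5, raw=True):
    print(session, line, source)   # the last few inputs across sessions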
Example 10
class Modeler(Device):
    name = 'model'
    path = 'msmaccelerator.model.modeler.Modeler'
    short_description = 'Run the modeler, building an MSM on the available data'
    long_description = '''This device will connect to the msmaccelerator server,
        request the currently available data and build an MSM. That MSM will be
        used by the server to drive future rounds of adaptive sampling.
        Currently, you can use either RMSD (built-in) or a custom distance metric
        (provide the pickle file) with K-centers clustering algorithm.'''

    stride = Int(1,
                 config=True,
                 help='''Subsample data by taking only
        every stride-th point''')
    topology_pdb = FilePath(config=True,
                            extension='.pdb',
                            help='''PDB file
        giving the topology of the system''')
    lag_time = Int(1,
                   config=True,
                   help='''Lag time for building the
        model, in units of the stride. Currently, we are not doing the step
        in MSMBuilder that is referred to as "assignment", where you assign
        the remaining data that was not used during clustering to the cluster
        centers that were identified.''')
    rmsd_atom_indices = FilePath(
        'AtomIndices.dat',
        extension='.dat',
        config=True,
        help='''File containing the indices of atoms to use in the RMSD
        computation. Using a PDB as input, this file can be created with
        the MSMBuilder script CreateAtomIndices.py''')
    clustering_distance_cutoff = Float(0.2,
                                       config=True,
                                       help='''Distance cutoff for
        clustering, in nanometers. We will continue to create new clusters
        until each data point is within this cutoff from its cluster center.'''
                                       )
    symmetrize = Enum(
        ['MLE', 'Transpose', None],
        default='MLE',
        config=True,
        help='''Symmetrization method for constructing the reversible counts
        matrix.''')
    ergodic_trimming = Bool(False,
                            config=True,
                            help='''Do ergodic trimming when
        constructing the Markov state model. This is generally a good idea for
        building MSMs in the high-data regime where you wish to prevent transitions
        that appear nonergodic because they've been undersampled from influencing
        your model, but is inappropriate in the sparse-data regime when you're
        using min-counts sampling, because these are precisely the states that
        you're most interested in.''')
    use_custom_metric = Bool(False,
                             config=True,
                             help='''Should we use
         a custom distance metric for clustering instead of RMSD?''')
    custom_metric_path = Unicode('metric.pickl',
                                 config=True,
                                 help='''File
         containing a pickled metric for use in clustering.''')
    clusterer = Enum(['kcenters', 'hybrid', 'ward'],
                     default='kcenters',
                     config=True,
                     help='''The method used for clustering structures in
        the MSM.''')

    aliases = dict(
        stride='Modeler.stride',
        lag_time='Modeler.lag_time',
        rmsd_atom_indices='Modeler.rmsd_atom_indices',
        clustering_distance_cutoff='Modeler.clustering_distance_cutoff',
        topology_pdb='Modeler.topology_pdb',
        symmetrize='Modeler.symmetrize',
        trim='Modeler.ergodic_trimming',
        zmq_url='Device.zmq_url',
        zmq_port='Device.zmq_port')
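    # Illustrative note (not in the original source): these aliases let the
    # corresponding traits be set from the command line, e.g. a hypothetical
    # invocation `... model --stride=2 --lag_time=10 --symmetrize=Transpose`,
    # with each flag routed to the trait named on the right-hand side.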

    def on_startup_message(self, msg):
        """This method is called when the device receives its startup message
        from the server
        """
        assert msg.header.msg_type in ['construct_model'], 'only allowed methods'
        return getattr(self, msg.header.msg_type)(msg.header, msg.content)

    def construct_model(self, header, content):
        """All the model building code. This code is what's called by the
        server after registration."""
        # the message needs to not contain unicode
        assert content.output.protocol == 'localfs', "I'm currently only equipped for localfs output"

        # load up all of the trajectories
        trajs = self.load_trajectories(content.traj_fns)

        # run clustering
        assignments, generator_indices = self.cluster(trajs)

        # build the MSM
        counts, rev_counts, t_matrix, populations, mapping = self.build_msm(
            assignments)

        # save the results to disk
        msm = MarkovStateModel(counts=counts,
                               reversible_counts=rev_counts,
                               transition_matrix=t_matrix,
                               populations=populations,
                               mapping=mapping,
                               generator_indices=generator_indices,
                               traj_filenames=content.traj_fns,
                               assignments_stride=self.stride,
                               lag_time=self.lag_time,
                               assignments=assignments)
        msm.save(content.output.path)

        # tell the server that we're done
        self.send_recv(msg_type='modeler_done',
                       content={
                           'status': 'success',
                           'output': {
                               'protocol': 'localfs',
                               'path': content.output.path
                           },
                       })

    def load_trajectories(self, traj_fns):
        """Load up the trajectories, taking into account both the stride and
        the atom indices"""

        trajs = []
        if os.path.exists(self.rmsd_atom_indices):
            self.log.info('Loading atom indices from %s',
                          self.rmsd_atom_indices)
            atom_indices = np.loadtxt(self.rmsd_atom_indices, dtype=np.int)
        else:
            self.log.info('Skipping loading atom_indices. Using all.')
            atom_indices = None

        for traj_fn in traj_fns:
            # use the mdtraj dcd reader, but then monkey-patch
            # the coordinate array into shim for the msmbuilder clustering
            # code that wants the trajectory to act like a dict with the XYZList
            # key.
            self.log.info('Loading traj %s', traj_fn)
            if not os.path.exists(traj_fn):
                self.log.error(
                    'Traj file reported by server does not exist: %s',
                    traj_fn)
                continue

            t = mdtraj.trajectory.load(traj_fn,
                                       atom_indices=atom_indices,
                                       top=self.topology_pdb)
            t2 = ShimTrajectory(t.xyz[::self.stride, :])

            trajs.append(t2)

        if len(trajs) == 0:
            raise ValueError('No trajectories found!')

        self.log.info('loaded %s trajectories', len(trajs))
        self.log.info('loaded %s total frames...', sum(len(t) for t in trajs))
        self.log.info('loaded %s atoms', t2['XYZList'].shape[1])

        return trajs

    def cluster(self, trajectories):
        """Cluster the trajectories into microstates.

        Returns
        -------
        assignments : np.ndarray, dtype=int, shape=[n_trajs, max_n_frames]
            assignments is a 2d array giving the microstate that each frame
            from the simulation is assigned to. The indexing semantics are
            a little bit nontrivial because of the striding and the lag time.
            They are that assignments[i,j]=k means that in the `ith` trajectory,
            the `j*self.stride`th frame is assigned to microstate `k`.
        generator_indices : np.ndarray, dtype=int, shape=[n_clusters, 2]
            This array gives the indices of the cluster centers, with respect
            to their position in the trajectories on disk. The semantics are
            that generator_indices[i, :]=[k,l] means that the `ith` cluster's center
            is in trajectory `k`, in its `l`th frame. Because of the striding,
            `l` will always be a multiple of `self.stride`.
        """
        if self.use_custom_metric:
            metric_path = self.custom_metric_path
            self.log.info("Loading custom metric: %s" % metric_path)
            with open(metric_path, 'rb') as pickle_file:
                metric = pickle.load(pickle_file)
        else:
            metric = msmbuilder.metrics.RMSD()

        if self.clusterer == 'kcenters':
            # Use k-centers clustering
            clusterer = msmbuilder.clustering.KCenters(
                metric,
                trajectories,
                distance_cutoff=self.clustering_distance_cutoff)
            assignments = clusterer.get_assignments()
        elif self.clusterer == 'ward':
            # Use ward clustering
            clusterer = msmbuilder.clustering.Hierarchical(metric,
                                                           trajectories,
                                                           method='ward')
            assignments = clusterer.get_assignments(
                self.clustering_distance_cutoff)
        elif self.clusterer == 'hybrid':
            # Use hybrid k-medoids clustering
            clusterer = msmbuilder.clustering.HybridKMedoids(
                metric,
                trajectories,
                k=None,
                distance_cutoff=self.clustering_distance_cutoff)
            assignments = clusterer.get_assignments()
        else:
            self.log.error("Please choose an actual clusterer")
            raise ValueError("Unknown clusterer: %r" % self.clusterer)

        # if we get the generators as a trajectory, it will only
        # have the reduced set of atoms.

        # the clusterer contains indices with respect to the concatenated trajectory
        # inside the clusterer object. we need to reindex to get the
        # traj/frame index of each generator
        # print 'generator longindices', clusterer._generator_indices
        # print 'traj lengths         ', clusterer._traj_lengths
        generator_indices = reindex_list(clusterer._generator_indices,
                                         clusterer._traj_lengths)
        # print 'generator indices', generator_indices

        # but these indices are still with respect to the traj/frame
        # after striding, so we need to unstride them
        generator_indices[:, 1] *= self.stride

        # print generator_indices

        return assignments, generator_indices

    def build_msm(self, assignments):
        """Build the MSM from the microstate assigned trajectories"""
        counts = msmbuilder.MSMLib.get_count_matrix_from_assignments(
            assignments, lag_time=self.lag_time)

        result = msmbuilder.MSMLib.build_msm(
            counts,
            symmetrize=self.symmetrize,
            ergodic_trimming=self.ergodic_trimming)
        # unpack the results
        rev_counts, t_matrix, populations, mapping = result
        return counts, rev_counts, t_matrix, populations, mapping
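
reindex_list, used by cluster() above, is not shown in this excerpt. A
minimal sketch of the behavior the surrounding comments describe (the name
and signature are assumptions): it maps indices into the concatenation of
all trajectories back to (trajectory, frame) pairs.

import numpy as np

def reindex_list(longindices, traj_lengths):
    """Map indices into the concatenated trajectory back to
    (traj_index, frame_index) pairs."""
    breaks = np.cumsum(traj_lengths)
    out = np.zeros((len(longindices), 2), dtype=int)
    for i, idx in enumerate(longindices):
        traj = np.searchsorted(breaks, idx, side='right')
        start = 0 if traj == 0 else breaks[traj - 1]
        out[i] = traj, idx - start
    return out

# e.g. with traj_lengths [4, 4, 6], index 5 falls in trajectory 1, frame 1:
# reindex_list([0, 5, 12], [4, 4, 6]) -> [[0, 0], [1, 1], [2, 4]]
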
Example no. 11
class Widget(LoggingConfigurable):
    #-------------------------------------------------------------------------
    # Class attributes
    #-------------------------------------------------------------------------
    _widget_construction_callback = None
    widgets = {}

    @staticmethod
    def on_widget_constructed(callback):
        """Registers a callback to be called when a widget is constructed.

        The callback must have the following signature:
        callback(widget)"""
        Widget._widget_construction_callback = callback

    @staticmethod
    def _call_widget_constructed(widget):
        """Static method, called when a widget is constructed."""
        if Widget._widget_construction_callback is not None and callable(
                Widget._widget_construction_callback):
            Widget._widget_construction_callback(widget)

    #-------------------------------------------------------------------------
    # Traits
    #-------------------------------------------------------------------------
    _model_name = Unicode('WidgetModel',
                          help="""Name of the backbone model 
        registered in the front-end to create and sync this widget with.""")
    _view_name = Unicode('WidgetView',
                         help="""Default view registered in the front-end
        to use to represent the widget.""",
                         sync=True)
    comm = Instance('IPython.kernel.comm.Comm')

    msg_throttle = Int(3,
                       sync=True,
                       help="""Maximum number of msgs the 
        front-end can send before receiving an idle msg from the back-end.""")

    keys = List()

    def _keys_default(self):
        return [name for name in self.traits(sync=True)]

    _property_lock = Tuple((None, None))
    _send_state_lock = Int(0)
    _states_to_send = Set(allow_none=False)
    _display_callbacks = Instance(CallbackDispatcher, ())
    _msg_callbacks = Instance(CallbackDispatcher, ())

    #-------------------------------------------------------------------------
    # (Con/de)structor
    #-------------------------------------------------------------------------
    def __init__(self, **kwargs):
        """Public constructor"""
        self._model_id = kwargs.pop('model_id', None)
        super(Widget, self).__init__(**kwargs)

        self.on_trait_change(self._handle_property_changed, self.keys)
        Widget._call_widget_constructed(self)
        self.open()

    def __del__(self):
        """Object disposal"""
        self.close()

    #-------------------------------------------------------------------------
    # Properties
    #-------------------------------------------------------------------------

    def open(self):
        """Open a comm to the frontend if one isn't already open."""
        if self.comm is None:
            if self._model_id is None:
                self.comm = Comm(target_name=self._model_name)
                self._model_id = self.model_id
            else:
                self.comm = Comm(target_name=self._model_name,
                                 comm_id=self._model_id)
            self.comm.on_msg(self._handle_msg)
            Widget.widgets[self.model_id] = self

            # first update
            self.send_state()

    @property
    def model_id(self):
        """Gets the model id of this widget.

        Requires an open comm; the constructor opens one via open()."""
        return self.comm.comm_id

    #-------------------------------------------------------------------------
    # Methods
    #-------------------------------------------------------------------------

    def close(self):
        """Close method.

        Closes the underlying comm.
        When the comm is closed, all of the widget views are automatically
        removed from the front-end."""
        if self.comm is not None:
            Widget.widgets.pop(self.model_id, None)
            self.comm.close()
            self.comm = None

    def send_state(self, key=None):
        """Sends the widget state, or a piece of it, to the front-end.

        Parameters
        ----------
        key : unicode, or iterable (optional)
            A single property's name or iterable of property names to sync with the front-end.
        """
        self._send({"method": "update", "state": self.get_state(key=key)})

    def get_state(self, key=None):
        """Gets the widget state, or a piece of it.

        Parameters
        ----------
        key : unicode or iterable (optional)
            A single property's name or iterable of property names to get.
        """
        if key is None:
            keys = self.keys
        elif isinstance(key, string_types):
            keys = [key]
        elif isinstance(key, collections.Iterable):
            keys = key
        else:
            raise ValueError(
                "key must be a string, an iterable of keys, or None")
        state = {}
        for k in keys:
            f = self.trait_metadata(k, 'to_json', self._trait_to_json)
            value = getattr(self, k)
            state[k] = f(value)
        return state

    def send(self, content):
        """Sends a custom msg to the widget model in the front-end.

        Parameters
        ----------
        content : dict
            Content of the message to send.
        """
        self._send({"method": "custom", "content": content})

    def on_msg(self, callback, remove=False):
        """(Un)Register a custom msg receive callback.

        Parameters
        ----------
        callback: callable
            callback will be passed two arguments when a message arrives::
            
                callback(widget, content)
            
        remove: bool
            True if the callback should be unregistered."""
        self._msg_callbacks.register_callback(callback, remove=remove)

    def on_displayed(self, callback, remove=False):
        """(Un)Register a widget displayed callback.

        Parameters
        ----------
        callback: method handler
            Must have a signature of::
            
                callback(widget, **kwargs)
            
            kwargs from display are passed through without modification.
        remove: bool
            True if the callback should be unregistered."""
        self._display_callbacks.register_callback(callback, remove=remove)

    #-------------------------------------------------------------------------
    # Support methods
    #-------------------------------------------------------------------------
    @contextmanager
    def _lock_property(self, key, value):
        """Lock a property-value pair.

        The value should be the JSON state of the property.

        NOTE: This, in addition to the single lock for all state changes, is 
        flawed.  In the future we may want to look into buffering state changes 
        back to the front-end."""
        self._property_lock = (key, value)
        try:
            yield
        finally:
            self._property_lock = (None, None)

    @contextmanager
    def hold_sync(self):
        """Hold syncing any state until the context manager is released"""
        # We increment a value so that this can be nested.  Syncing will happen when
        # all levels have been released.
        self._send_state_lock += 1
        try:
            yield
        finally:
            self._send_state_lock -= 1
            if self._send_state_lock == 0:
                self.send_state(self._states_to_send)
                self._states_to_send.clear()
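    # Illustrative usage (not in the original source): nested holds coalesce
    # trait changes into one update message:
    #
    #     with w.hold_sync():
    #         w.value = 10
    #         w.description = 'done'
    #     # a single send_state() covering both keys happens here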

    def _should_send_property(self, key, value):
        """Check the property lock (property_lock)"""
        to_json = self.trait_metadata(key, 'to_json', self._trait_to_json)
        if (key == self._property_lock[0]
                and to_json(value) == self._property_lock[1]):
            return False
        elif self._send_state_lock > 0:
            self._states_to_send.add(key)
            return False
        else:
            return True

    # Event handlers
    @_show_traceback
    def _handle_msg(self, msg):
        """Called when a msg is received from the front-end"""
        data = msg['content']['data']
        method = data['method']
        if method not in ['backbone', 'custom']:
            self.log.error(
                'Unknown front-end to back-end widget msg with method "%s"' %
                method)

        # Handle backbone sync methods CREATE, PATCH, and UPDATE all in one.
        if method == 'backbone' and 'sync_data' in data:
            sync_data = data['sync_data']
            self._handle_receive_state(sync_data)  # handles all methods

        # Handle a custom msg from the front-end
        elif method == 'custom':
            if 'content' in data:
                self._handle_custom_msg(data['content'])

    def _handle_receive_state(self, sync_data):
        """Called when a state is received from the front-end."""
        for name in self.keys:
            if name in sync_data:
                json_value = sync_data[name]
                from_json = self.trait_metadata(name, 'from_json',
                                                self._trait_from_json)
                with self._lock_property(name, json_value):
                    setattr(self, name, from_json(json_value))

    def _handle_custom_msg(self, content):
        """Called when a custom msg is received."""
        self._msg_callbacks(self, content)

    def _handle_property_changed(self, name, old, new):
        """Called when a property has been changed."""
        # Make sure this isn't information that the front-end just sent us.
        if self._should_send_property(name, new):
            # Send new state to front-end
            self.send_state(key=name)

    def _handle_displayed(self, **kwargs):
        """Called when a view has been displayed for this widget instance"""
        self._display_callbacks(self, **kwargs)

    def _trait_to_json(self, x):
        """Convert a trait value to json

        Traverse lists/tuples and dicts and serialize their values as well.
        Replace any widgets with their model_id
        """
        if isinstance(x, dict):
            return {k: self._trait_to_json(v) for k, v in x.items()}
        elif isinstance(x, (list, tuple)):
            return [self._trait_to_json(v) for v in x]
        elif isinstance(x, Widget):
            return "IPY_MODEL_" + x.model_id
        else:
            return x  # Value must be JSON-able

    def _trait_from_json(self, x):
        """Convert json values to objects

        Replace any strings representing valid model id values to Widget references.
        """
        if isinstance(x, dict):
            return {k: self._trait_from_json(v) for k, v in x.items()}
        elif isinstance(x, (list, tuple)):
            return [self._trait_from_json(v) for v in x]
        elif isinstance(x, string_types) and x.startswith(
                'IPY_MODEL_') and x[10:] in Widget.widgets:
            # we want to support having child widgets at any level in a hierarchy
            # trusting that a widget UUID will not appear out in the wild
            return Widget.widgets[x[10:]]
        else:
            return x

    def _ipython_display_(self, **kwargs):
        """Called when `IPython.display.display` is called on the widget."""
        # Show view.  By sending a display message, the comm is opened and the
        # initial state is sent.
        self._send({"method": "display"})
        self._handle_displayed(**kwargs)

    def _send(self, msg):
        """Sends a message to the model in the front-end."""
        self.comm.send(msg)
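
_trait_to_json and _trait_from_json above define a simple recursive wire
format in which child widgets travel as 'IPY_MODEL_<model_id>' strings. A
standalone sketch of the same round trip (FakeWidget and the helper names
are mine, not part of the Widget API):

WIDGETS = {}  # stand-in for Widget.widgets: model_id -> widget instance

class FakeWidget(object):
    def __init__(self, model_id):
        self.model_id = model_id
        WIDGETS[model_id] = self

def to_json(x):
    if isinstance(x, FakeWidget):
        return 'IPY_MODEL_' + x.model_id
    if isinstance(x, (list, tuple)):
        return [to_json(v) for v in x]
    if isinstance(x, dict):
        return {k: to_json(v) for k, v in x.items()}
    return x  # value must already be JSON-able

def from_json(x):
    if isinstance(x, str) and x.startswith('IPY_MODEL_') and x[10:] in WIDGETS:
        return WIDGETS[x[10:]]
    if isinstance(x, (list, tuple)):
        return [from_json(v) for v in x]
    if isinstance(x, dict):
        return {k: from_json(v) for k, v in x.items()}
    return x

child = FakeWidget('abc123')
assert to_json({'kids': [child]}) == {'kids': ['IPY_MODEL_abc123']}
assert from_json('IPY_MODEL_abc123') is child
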
Example no. 12
class IPythonQtConsoleApp(BaseIPythonApplication):
    name = 'ipython-qtconsole'
    default_config_file_name = 'ipython_config.py'

    description = """
        The IPython QtConsole.
        
        This launches a Console-style application using Qt.  It is not a full
        console, in that launched terminal subprocesses will not be able to accept
        input.
        
        The QtConsole supports various extra features beyond the Terminal IPython
        shell, such as inline plotting with matplotlib, via:
        
            ipython qtconsole --pylab=inline
        
        as well as saving your session as HTML, and printing the output.
        
    """
    examples = _examples

    classes = [
        IPKernelApp, IPythonWidget, ZMQInteractiveShell, ProfileDir, Session
    ]
    flags = Dict(flags)
    aliases = Dict(aliases)

    kernel_argv = List(Unicode)

    # create requested profiles by default, if they don't exist:
    auto_create = CBool(True)
    # connection info:
    ip = Unicode(LOCALHOST,
                 config=True,
                 help="""Set the kernel\'s IP address [default localhost].
        If the IP address is something other than localhost, then
        Consoles on other machines will be able to connect
        to the Kernel, so be careful!""")

    sshserver = Unicode(
        '',
        config=True,
        help="""The SSH server to use to connect to the kernel.""")
    sshkey = Unicode(
        '',
        config=True,
        help="""Path to the ssh key to use for logging in to the ssh server."""
    )

    hb_port = Int(0,
                  config=True,
                  help="set the heartbeat port [default: random]")
    shell_port = Int(0,
                     config=True,
                     help="set the shell (XREP) port [default: random]")
    iopub_port = Int(0,
                     config=True,
                     help="set the iopub (PUB) port [default: random]")
    stdin_port = Int(0,
                     config=True,
                     help="set the stdin (XREQ) port [default: random]")
    connection_file = Unicode(
        '',
        config=True,
        help=
        """JSON file in which to store connection info [default: kernel-<pid>.json]

        This file will contain the IP, ports, and authentication key needed to connect
        clients to this kernel. By default, this file will be created in the security-dir
        of the current profile, but can be specified by absolute path.
        """)

    def _connection_file_default(self):
        return 'kernel-%i.json' % os.getpid()

    existing = Unicode('',
                       config=True,
                       help="""Connect to an already running kernel""")

    stylesheet = Unicode('',
                         config=True,
                         help="path to a custom CSS stylesheet")

    pure = CBool(False,
                 config=True,
                 help="Use a pure Python kernel instead of an IPython kernel.")
    plain = CBool(
        False,
        config=True,
        help=
        "Use a plaintext widget instead of rich text (plain can't print/save)."
    )

    def _pure_changed(self, name, old, new):
        kind = 'plain' if self.plain else 'rich'
        self.config.ConsoleWidget.kind = kind
        if self.pure:
            self.widget_factory = FrontendWidget
        elif self.plain:
            self.widget_factory = IPythonWidget
        else:
            self.widget_factory = RichIPythonWidget

    _plain_changed = _pure_changed

    confirm_exit = CBool(
        True,
        config=True,
        help="""
        Set to display confirmation dialog on exit. You can always use 'exit' or 'quit',
        to force a direct exit without any confirmation.""",
    )

    # the factory for creating a widget
    widget_factory = Any(RichIPythonWidget)

    def parse_command_line(self, argv=None):
        super(IPythonQtConsoleApp, self).parse_command_line(argv)
        if argv is None:
            argv = sys.argv[1:]
        self.kernel_argv = list(argv)  # copy
        # kernel should inherit default config file from frontend
        self.kernel_argv.append("--KernelApp.parent_appname='%s'" % self.name)
        # Scrub frontend-specific flags
        swallow_next = False
        was_flag = False
        # copy again, in case some aliases have the same name as a flag
        # argv = list(self.kernel_argv)
        for a in argv:
            if swallow_next:
                swallow_next = False
                # last arg was an alias, remove the next one
                # *unless* the last alias has a no-arg flag version, in which
                # case, don't swallow the next arg if it's also a flag:
                if not (was_flag and a.startswith('-')):
                    self.kernel_argv.remove(a)
                    continue
            if a.startswith('-'):
                split = a.lstrip('-').split('=')
                alias = split[0]
                if alias in qt_aliases:
                    self.kernel_argv.remove(a)
                    if len(split) == 1:
                        # alias passed with arg via space
                        swallow_next = True
                        # could have been a flag that matches an alias, e.g. `existing`
                        # in which case, we might not swallow the next arg
                        was_flag = alias in qt_flags
                elif alias in qt_flags:
                    # strip flag, but don't swallow next, as flags don't take args
                    self.kernel_argv.remove(a)

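    # Illustrative walk-through (not in the original source): given
    #   argv = ['--existing', 'kernel-1234.json']
    # with 'existing' in qt_aliases, the first pass removes '--existing' and
    # sets swallow_next; the next iteration then also removes
    # 'kernel-1234.json' from kernel_argv -- unless 'existing' is also a
    # no-arg flag in qt_flags and the following token starts with '-'.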
    def init_connection_file(self):
        """find the connection file, and load the info if found.
        
        The current working directory and the current profile's security
        directory will be searched for the file if it is not given by
        absolute path.
        
        When attempting to connect to an existing kernel and the `--existing`
        argument does not match an existing file, it will be interpreted as a
        fileglob, and the matching file in the current profile's security dir
        with the latest access time will be used.
        """
        if self.existing:
            try:
                cf = find_connection_file(self.existing)
            except Exception:
                self.log.critical(
                    "Could not find existing kernel connection file %s",
                    self.existing)
                self.exit(1)
            self.log.info("Connecting to existing kernel: %s" % cf)
            self.connection_file = cf
        # should load_connection_file only be used for existing?
        # as it is now, this allows reusing ports if an existing
        # file is requested
        try:
            self.load_connection_file()
        except Exception:
            self.log.error("Failed to load connection file: %r",
                           self.connection_file,
                           exc_info=True)
            self.exit(1)

    def load_connection_file(self):
        """load ip/port/hmac config from JSON connection file"""
        # this is identical to KernelApp.load_connection_file
        # perhaps it can be centralized somewhere?
        try:
            fname = filefind(self.connection_file,
                             ['.', self.profile_dir.security_dir])
        except IOError:
            self.log.debug("Connection File not found: %s",
                           self.connection_file)
            return
        self.log.debug(u"Loading connection file %s", fname)
        with open(fname) as f:
            s = f.read()
        cfg = json.loads(s)
        if self.ip == LOCALHOST and 'ip' in cfg:
            # not overridden by config or cl_args
            self.ip = cfg['ip']
        for channel in ('hb', 'shell', 'iopub', 'stdin'):
            name = channel + '_port'
            if getattr(self, name) == 0 and name in cfg:
                # not overridden by config or cl_args
                setattr(self, name, cfg[name])
        if 'key' in cfg:
            self.config.Session.key = str_to_bytes(cfg['key'])

    def init_ssh(self):
        """set up ssh tunnels, if needed."""
        if not self.sshserver and not self.sshkey:
            return

        if self.sshkey and not self.sshserver:
            # specifying just the key implies that we are connecting directly
            self.sshserver = self.ip
            self.ip = LOCALHOST

        # build connection dict for tunnels:
        info = dict(ip=self.ip,
                    shell_port=self.shell_port,
                    iopub_port=self.iopub_port,
                    stdin_port=self.stdin_port,
                    hb_port=self.hb_port)

        self.log.info("Forwarding connections to %s via %s" %
                      (self.ip, self.sshserver))

        # tunnels return a new set of ports, which will be on localhost:
        self.ip = LOCALHOST
        try:
            newports = tunnel_to_kernel(info, self.sshserver, self.sshkey)
        except:
            # even catch KeyboardInterrupt
            self.log.error("Could not setup tunnels", exc_info=True)
            self.exit(1)

        self.shell_port, self.iopub_port, self.stdin_port, self.hb_port = newports

        cf = self.connection_file
        base, ext = os.path.splitext(cf)
        base = os.path.basename(base)
        self.connection_file = base + '-ssh' + ext
        self.log.critical("To connect another client via this tunnel, use:")
        self.log.critical("--existing %s" % self.connection_file)

    def _new_connection_file(self):
        return os.path.join(self.profile_dir.security_dir,
                            'kernel-%s.json' % uuid.uuid4())

    def init_kernel_manager(self):
        # Don't let Qt or ZMQ swallow KeyboardInterrupts.
        signal.signal(signal.SIGINT, signal.SIG_DFL)
        sec = self.profile_dir.security_dir
        try:
            cf = filefind(self.connection_file, ['.', sec])
        except IOError:
            # file might not exist
            if self.connection_file == os.path.basename(self.connection_file):
                # just shortname, put it in security dir
                cf = os.path.join(sec, self.connection_file)
            else:
                cf = self.connection_file

        # Create a KernelManager and start a kernel.
        self.kernel_manager = QtKernelManager(
            ip=self.ip,
            shell_port=self.shell_port,
            iopub_port=self.iopub_port,
            stdin_port=self.stdin_port,
            hb_port=self.hb_port,
            connection_file=cf,
            config=self.config,
        )
        # start the kernel
        if not self.existing:
            kwargs = dict(ipython=not self.pure)
            kwargs['extra_arguments'] = self.kernel_argv
            self.kernel_manager.start_kernel(**kwargs)
        elif self.sshserver:
            # ssh, write new connection file
            self.kernel_manager.write_connection_file()
        self.kernel_manager.start_channels()

    def new_frontend_master(self):
        """ Create and return new frontend attached to new kernel, launched on localhost.
        """
        ip = self.ip if self.ip in LOCAL_IPS else LOCALHOST
        kernel_manager = QtKernelManager(
            ip=ip,
            connection_file=self._new_connection_file(),
            config=self.config,
        )
        # start the kernel
        kwargs = dict(ipython=not self.pure)
        kwargs['extra_arguments'] = self.kernel_argv
        kernel_manager.start_kernel(**kwargs)
        kernel_manager.start_channels()
        widget = self.widget_factory(config=self.config, local_kernel=True)
        widget.kernel_manager = kernel_manager
        widget._existing = False
        widget._may_close = True
        widget._confirm_exit = self.confirm_exit
        return widget

    def new_frontend_slave(self, current_widget):
        """Create and return a new frontend attached to an existing kernel.
        
        Parameters
        ----------
        current_widget : IPythonWidget
            The IPythonWidget whose kernel this frontend is to share
        """
        kernel_manager = QtKernelManager(
            connection_file=current_widget.kernel_manager.connection_file,
            config=self.config,
        )
        kernel_manager.load_connection_file()
        kernel_manager.start_channels()
        widget = self.widget_factory(config=self.config, local_kernel=False)
        widget._existing = True
        widget._may_close = False
        widget._confirm_exit = False
        widget.kernel_manager = kernel_manager
        return widget

    def init_qt_elements(self):
        # Create the widget.
        self.app = QtGui.QApplication([])

        base_path = os.path.abspath(os.path.dirname(__file__))
        icon_path = os.path.join(base_path, 'resources', 'icon',
                                 'IPythonConsole.svg')
        self.app.icon = QtGui.QIcon(icon_path)
        QtGui.QApplication.setWindowIcon(self.app.icon)

        local_kernel = (not self.existing) or self.ip in LOCAL_IPS
        self.widget = self.widget_factory(config=self.config,
                                          local_kernel=local_kernel)
        self.widget._existing = self.existing
        self.widget._may_close = not self.existing
        self.widget._confirm_exit = self.confirm_exit

        self.widget.kernel_manager = self.kernel_manager
        self.window = MainWindow(
            self.app,
            confirm_exit=self.confirm_exit,
            new_frontend_factory=self.new_frontend_master,
            slave_frontend_factory=self.new_frontend_slave,
        )
        self.window.log = self.log
        self.window.add_tab_with_frontend(self.widget)
        self.window.init_menu_bar()
        self.window.setWindowTitle('Python' if self.pure else 'IPython')

    def init_colors(self):
        """Configure the coloring of the widget"""
        # Note: This will be dramatically simplified when colors
        # are removed from the backend.

        if self.pure:
            # only IPythonWidget supports styling
            return

        # parse the colors arg down to current known labels
        try:
            colors = self.config.ZMQInteractiveShell.colors
        except AttributeError:
            colors = None
        try:
            style = self.config.IPythonWidget.syntax_style
        except AttributeError:
            style = None

        # find the value for colors:
        if colors:
            colors = colors.lower()
            if colors in ('lightbg', 'light'):
                colors = 'lightbg'
            elif colors in ('dark', 'linux'):
                colors = 'linux'
            else:
                colors = 'nocolor'
        elif style:
            if style == 'bw':
                colors = 'nocolor'
            elif styles.dark_style(style):
                colors = 'linux'
            else:
                colors = 'lightbg'
        else:
            colors = None

        # Configure the style.
        widget = self.widget
        if style:
            widget.style_sheet = styles.sheet_from_template(style, colors)
            widget.syntax_style = style
            widget._syntax_style_changed()
            widget._style_sheet_changed()
        elif colors:
            # use a default style
            widget.set_default_style(colors=colors)
        else:
            # this is redundant for now, but allows the widget's
            # defaults to change
            widget.set_default_style()

        if self.stylesheet:
            # we got an explicit stylesheet
            if os.path.isfile(self.stylesheet):
                with open(self.stylesheet) as f:
                    sheet = f.read()
                widget.style_sheet = sheet
                widget._style_sheet_changed()
            else:
                raise IOError("Stylesheet %r not found." % self.stylesheet)

    @catch_config_error
    def initialize(self, argv=None):
        super(IPythonQtConsoleApp, self).initialize(argv)
        self.init_connection_file()
        default_secure(self.config)
        self.init_ssh()
        self.init_kernel_manager()
        self.init_qt_elements()
        self.init_colors()

    def start(self):

        # draw the window
        self.window.show()

        # Start the application main loop.
        self.app.exec_()
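
load_connection_file above reads a small JSON document. A representative
connection file consistent with the keys that method looks for (all values
below are made up):

import json

example = {
    "ip": "127.0.0.1",
    "hb_port": 50001,
    "shell_port": 50002,
    "iopub_port": 50003,
    "stdin_port": 50004,
    "key": "a0436f6c-1916-498b-8eb9-e81ab9368e84",
}
print(json.dumps(example, indent=2))
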
Example no. 13
class TerminalInteractiveShell(InteractiveShell):

    autoedit_syntax = CBool(False, config=True)
    banner = Str('')
    banner1 = Str(default_banner, config=True)
    banner2 = Str('', config=True)
    confirm_exit = CBool(True, config=True)
    # This display_banner only controls whether or not self.show_banner()
    # is called when mainloop/interact are called.  The default is False
    # because for the terminal based application, the banner behavior
    # is controlled by Global.display_banner, which IPythonApp looks at
    # to determine if *it* should call show_banner() by hand or not.
    display_banner = CBool(False) # This isn't configurable!
    embedded = CBool(False)
    embedded_active = CBool(False)
    editor = Str(get_default_editor(), config=True)
    pager = Str('less', config=True)

    screen_length = Int(0, config=True)
    term_title = CBool(False, config=True)

    def __init__(self, config=None, ipython_dir=None, user_ns=None,
                 user_global_ns=None, custom_exceptions=((),None),
                 usage=None, banner1=None, banner2=None,
                 display_banner=None):

        super(TerminalInteractiveShell, self).__init__(
            config=config, ipython_dir=ipython_dir, user_ns=user_ns,
            user_global_ns=user_global_ns, custom_exceptions=custom_exceptions
        )
        self.init_term_title()
        self.init_usage(usage)
        self.init_banner(banner1, banner2, display_banner)

    #-------------------------------------------------------------------------
    # Things related to the terminal
    #-------------------------------------------------------------------------

    @property
    def usable_screen_length(self):
        if self.screen_length == 0:
            return 0
        else:
            num_lines_bot = self.separate_in.count('\n')+1
            return self.screen_length - num_lines_bot

    def init_term_title(self):
        # Enable or disable the terminal title.
        if self.term_title:
            toggle_set_term_title(True)
            set_term_title('IPython: ' + abbrev_cwd())
        else:
            toggle_set_term_title(False)

    #-------------------------------------------------------------------------
    # Things related to aliases
    #-------------------------------------------------------------------------

    def init_alias(self):
        # The parent class defines aliases that can be safely used with any
        # frontend.
        super(TerminalInteractiveShell, self).init_alias()

        # Now define aliases that only make sense on the terminal, because they
        # need direct access to the console in a way that we can't emulate in
        # GUI or web frontend
        if os.name == 'posix':
            aliases = [('clear', 'clear'), ('more', 'more'), ('less', 'less'),
                       ('man', 'man')]
        elif os.name == 'nt':
            aliases = [('cls', 'cls')]
        else:
            aliases = []

        for name, cmd in aliases:
            self.alias_manager.define_alias(name, cmd)

    #-------------------------------------------------------------------------
    # Things related to the banner and usage
    #-------------------------------------------------------------------------

    def _banner1_changed(self):
        self.compute_banner()

    def _banner2_changed(self):
        self.compute_banner()

    def _term_title_changed(self, name, new_value):
        self.init_term_title()

    def init_banner(self, banner1, banner2, display_banner):
        if banner1 is not None:
            self.banner1 = banner1
        if banner2 is not None:
            self.banner2 = banner2
        if display_banner is not None:
            self.display_banner = display_banner
        self.compute_banner()

    def show_banner(self, banner=None):
        if banner is None:
            banner = self.banner
        self.write(banner)

    def compute_banner(self):
        self.banner = self.banner1
        if self.profile:
            self.banner += '\nIPython profile: %s\n' % self.profile
        if self.banner2:
            self.banner += '\n' + self.banner2

    def init_usage(self, usage=None):
        if usage is None:
            self.usage = interactive_usage
        else:
            self.usage = usage

    #-------------------------------------------------------------------------
    # Mainloop and code execution logic
    #-------------------------------------------------------------------------

    def mainloop(self, display_banner=None):
        """Start the mainloop.

        If an optional banner argument is given, it will override the
        internally created default banner.
        """
        
        with nested(self.builtin_trap, self.display_trap):

            # if you run stuff with -c <cmd>, raw hist is not updated
            # ensure that it's in sync
            self.history_manager.sync_inputs()

            while 1:
                try:
                    self.interact(display_banner=display_banner)
                    #self.interact_with_readline()                
                    # XXX for testing of a readline-decoupled repl loop, call
                    # interact_with_readline above
                    break
                except KeyboardInterrupt:
                    # this should not be necessary, but KeyboardInterrupt
                    # handling seems rather unpredictable...
                    self.write("\nKeyboardInterrupt in interact()\n")

    def interact(self, display_banner=None):
        """Closely emulate the interactive Python console."""

        # batch run -> do not interact        
        if self.exit_now:
            return

        if display_banner is None:
            display_banner = self.display_banner
        if display_banner:
            self.show_banner()

        more = False
        
        # Mark activity in the builtins
        __builtin__.__dict__['__IPYTHON__active'] += 1
        
        if self.has_readline:
            self.readline_startup_hook(self.pre_readline)
        # exit_now is set by a call to %Exit or %Quit, through the
        # ask_exit callback.

        while not self.exit_now:
            self.hooks.pre_prompt_hook()
            if more:
                try:
                    prompt = self.hooks.generate_prompt(True)
                except:
                    self.showtraceback()
                if self.autoindent:
                    self.rl_do_indent = True
                    
            else:
                try:
                    prompt = self.hooks.generate_prompt(False)
                except:
                    self.showtraceback()
            try:
                line = self.raw_input(prompt)
                if self.exit_now:
                    # quick exit on sys.std[in|out] close
                    break
                if self.autoindent:
                    self.rl_do_indent = False
                    
            except KeyboardInterrupt:
                #double-guard against keyboardinterrupts during kbdint handling
                try:
                    self.write('\nKeyboardInterrupt\n')
                    self.resetbuffer()
                    more = False
                except KeyboardInterrupt:
                    pass
            except EOFError:
                if self.autoindent:
                    self.rl_do_indent = False
                    if self.has_readline:
                        self.readline_startup_hook(None)
                self.write('\n')
                self.exit()
            except bdb.BdbQuit:
                warn('The Python debugger has exited with a BdbQuit exception.\n'
                     'Because of how pdb handles the stack, it is impossible\n'
                     'for IPython to properly format this particular exception.\n'
                     'IPython will resume normal operation.')
            except:
                # exceptions here are VERY RARE, but they can be triggered
                # asynchronously by signal handlers, for example.
                self.showtraceback()
            else:
                self.input_splitter.push(line)
                more = self.input_splitter.push_accepts_more()
                if (self.SyntaxTB.last_syntax_error and
                    self.autoedit_syntax):
                    self.edit_syntax_error()
                if not more:
                    source_raw = self.input_splitter.source_raw_reset()[1]
                    self.run_cell(source_raw)
                
        # We are off again...
        __builtin__.__dict__['__IPYTHON__active'] -= 1

        # Turn off the exit flag, so the mainloop can be restarted if desired
        self.exit_now = False

    def raw_input(self, prompt='', continue_prompt=False):
        """Write a prompt and read a line.

        The returned line does not include the trailing newline.
        When the user enters the EOF key sequence, EOFError is raised.

        Optional inputs:

          - prompt(''): a string to be printed to prompt the user.

          - continue_prompt(False): whether this line is the first one or a
          continuation in a sequence of inputs.
        """
        # Code run by the user may have modified the readline completer state.
        # We must ensure that our completer is back in place.

        if self.has_readline:
            self.set_readline_completer()
        
        try:
            line = raw_input_original(prompt).decode(self.stdin_encoding)
        except ValueError:
            warn("\n********\nYou or a %run:ed script called sys.stdin.close()"
                 " or sys.stdout.close()!\nExiting IPython!")
            self.ask_exit()
            return ""

        # Try to be reasonably smart about not re-indenting pasted input more
        # than necessary.  We do this by trimming out the auto-indent initial
        # spaces, if the user's actual input started itself with whitespace.
        if self.autoindent:
            if num_ini_spaces(line) > self.indent_current_nsp:
                line = line[self.indent_current_nsp:]
                self.indent_current_nsp = 0
            
        # store the unfiltered input before the user has any chance to modify
        # it.
        if line.strip():
            if continue_prompt:
                if self.has_readline and self.readline_use:
                    histlen = self.readline.get_current_history_length()
                    if histlen > 1:
                        newhist = self.history_manager.input_hist_raw[-1].rstrip()
                        self.readline.remove_history_item(histlen-1)
                        self.readline.replace_history_item(histlen-2,
                                        newhist.encode(self.stdin_encoding))
            else:
                self.history_manager.input_hist_raw.append('%s\n' % line)                
        elif not continue_prompt:
            self.history_manager.input_hist_raw.append('\n')
        try:
            lineout = self.prefilter_manager.prefilter_lines(line,continue_prompt)
        except:
            # blanket except, in case a user-defined prefilter crashes, so it
            # can't take all of ipython with it.
            self.showtraceback()
            return ''
        else:
            return lineout


    def raw_input(self, prompt=''):
        """Write a prompt and read a line.

        The returned line does not include the trailing newline.
        When the user enters the EOF key sequence, EOFError is raised.

        Optional inputs:

          - prompt(''): a string to be printed to prompt the user.
        """
        # Code run by the user may have modified the readline completer state.
        # We must ensure that our completer is back in place.

        if self.has_readline:
            self.set_readline_completer()
        
        try:
            line = raw_input_original(prompt).decode(self.stdin_encoding)
        except ValueError:
            warn("\n********\nYou or a %run:ed script called sys.stdin.close()"
                 " or sys.stdout.close()!\nExiting IPython!")
            self.ask_exit()
            return ""

        # Try to be reasonably smart about not re-indenting pasted input more
        # than necessary.  We do this by trimming out the auto-indent initial
        # spaces, if the user's actual input started itself with whitespace.
        if self.autoindent:
            if num_ini_spaces(line) > self.indent_current_nsp:
                line = line[self.indent_current_nsp:]
                self.indent_current_nsp = 0
        
        return line

    #-------------------------------------------------------------------------
    # Methods to support auto-editing of SyntaxErrors.
    #-------------------------------------------------------------------------

    def edit_syntax_error(self):
        """The bottom half of the syntax error handler called in the main loop.

        Loop until syntax error is fixed or user cancels.
        """

        while self.SyntaxTB.last_syntax_error:
            # copy and clear last_syntax_error
            err = self.SyntaxTB.clear_err_state()
            if not self._should_recompile(err):
                return
            try:
                # may set last_syntax_error again if a SyntaxError is raised
                self.safe_execfile(err.filename,self.user_ns)
            except:
                self.showtraceback()
            else:
                try:
                    f = file(err.filename)
                    try:
                        # This should be inside a display_trap block and I 
                        # think it is.
                        sys.displayhook(f.read())
                    finally:
                        f.close()
                except:
                    self.showtraceback()

    def _should_recompile(self, e):
        """Utility routine for edit_syntax_error"""

        if e.filename in ('<ipython console>', '<input>', '<string>',
                          '<console>', '<BackgroundJob compilation>', None):
            return False
        try:
            if (self.autoedit_syntax and 
                not self.ask_yes_no('Return to editor to correct syntax error? '
                              '[Y/n] ','y')):
                return False
        except EOFError:
            return False

        def int0(x):
            try:
                return int(x)
            except TypeError:
                return 0
        # always pass integer line and offset values to editor hook
        try:
            self.hooks.fix_error_editor(e.filename,
                int0(e.lineno),int0(e.offset),e.msg)
        except TryNext:
            warn('Could not open editor')
            return False
        return True

    #-------------------------------------------------------------------------
    # Things related to GUI support and pylab
    #-------------------------------------------------------------------------

    def enable_pylab(self, gui=None):
        """Activate pylab support at runtime.

        This turns on support for matplotlib, preloads into the interactive
        namespace all of numpy and pylab, and configures IPython to correctly
        interact with the GUI event loop.  The GUI backend to be used can be
        optionally selected with the optional :param:`gui` argument.

        Parameters
        ----------
        gui : optional, string

          If given, dictates the choice of matplotlib GUI backend to use
          (should be one of IPython's supported backends, 'tk', 'qt', 'wx' or
          'gtk'), otherwise we use the default chosen by matplotlib (as
          dictated by the matplotlib build-time options plus the user's
          matplotlibrc configuration file).
        """
        # We want to prevent the loading of pylab to pollute the user's
        # namespace as shown by the %who* magics, so we execute the activation
        # code in an empty namespace, and we update *both* user_ns and
        # user_ns_hidden with this information.
        ns = {}
        gui = pylab_activate(ns, gui)
        self.user_ns.update(ns)
        self.user_ns_hidden.update(ns)
        # Now we must activate the gui pylab wants to use, and fix %run to take
        # plot updates into account
        enable_gui(gui)
        self.magic_run = self._pylab_magic_run

    #-------------------------------------------------------------------------
    # Things related to exiting
    #-------------------------------------------------------------------------

    def ask_exit(self):
        """ Ask the shell to exit. Can be overiden and used as a callback. """
        self.exit_now = True

    def exit(self):
        """Handle interactive exit.

        This method calls the ask_exit callback."""
        if self.confirm_exit:
            if self.ask_yes_no('Do you really want to exit ([y]/n)?','y'):
                self.ask_exit()
        else:
            self.ask_exit()
            
    #------------------------------------------------------------------------
    # Magic overrides
    #------------------------------------------------------------------------
    # Once the base class stops inheriting from magic, this code needs to be
    # moved into a separate machinery as well.  For now, at least isolate here
    # the magics which this class needs to implement differently from the base
    # class, or that are unique to it.

    def magic_autoindent(self, parameter_s = ''):
        """Toggle autoindent on/off (if available)."""

        self.shell.set_autoindent()
        print "Automatic indentation is:",['OFF','ON'][self.shell.autoindent]

    @testdec.skip_doctest
    def magic_cpaste(self, parameter_s=''):
        """Paste & execute a pre-formatted code block from clipboard.
        
        You must terminate the block with '--' (two minus-signs) alone on the
        line. You can also provide your own sentinel with '%paste -s %%' ('%%' 
        is the new sentinel for this operation)
        
        The block is dedented prior to execution to enable execution of method
        definitions. '>' and '+' characters at the beginning of a line are
        ignored, to allow pasting directly from e-mails, diff files and
        doctests (the '...' continuation prompt is also stripped).  The
        executed block is also assigned to a variable named 'pasted_block' for
        later editing with '%edit pasted_block'.
        
        You can also pass a variable name as an argument, e.g. '%cpaste foo'.
        This assigns the pasted block to variable 'foo' as a string, without
        dedenting or executing it (preceding '>>>' and '+' are still stripped)
        
        '%cpaste -r' re-executes the block previously entered by cpaste.
        
        Do not be alarmed by garbled output on Windows (it's a readline bug). 
        Just press enter and type -- (and press enter again) and the block 
        will be what was just pasted.
        
        IPython statements (magics, shell escapes) are not supported (yet).

        See also
        --------
        paste: automatically pull code from clipboard.
        
        Examples
        --------
        ::
        
          In [8]: %cpaste
          Pasting code; enter '--' alone on the line to stop.
          :>>> a = ["world!", "Hello"]
          :>>> print " ".join(sorted(a))
          :--
          Hello world!
        """
        
        opts, args = self.parse_options(parameter_s, 'rs:', mode='string')
        par = args.strip()
        if 'r' in opts:
            self._rerun_pasted()
            return

        sentinel = opts.get('s', '--')

        block = self._strip_pasted_lines_for_code(
            self._get_pasted_lines(sentinel))

        self._execute_block(block, par)

    def magic_paste(self, parameter_s=''):
        """Paste & execute a pre-formatted code block from clipboard.
        
        The text is pulled directly from the clipboard without user
        intervention and printed back on the screen before execution (unless
        the -q flag is given to force quiet mode).

        The block is dedented prior to execution to enable execution of method
        definitions. '>' and '+' characters at the beginning of a line are
        ignored, to allow pasting directly from e-mails, diff files and
        doctests (the '...' continuation prompt is also stripped).  The
        executed block is also assigned to a variable named 'pasted_block' for
        later editing with '%edit pasted_block'.
        
        You can also pass a variable name as an argument, e.g. '%paste foo'.
        This assigns the pasted block to variable 'foo' as a string, without
        dedenting or executing it (preceding '>>>' and '+' are still stripped)

        Options
        -------
        
          -r: re-executes the block previously entered by cpaste.

          -q: quiet mode: do not echo the pasted text back to the terminal.
        
        IPython statements (magics, shell escapes) are not supported (yet).

        See also
        --------
        cpaste: manually paste code into terminal until you mark its end.
        """
        opts, args = self.parse_options(parameter_s, 'rq', mode='string')
        par = args.strip()
        if 'r' in opts:
            self._rerun_pasted()
            return

        text = self.shell.hooks.clipboard_get()
        block = self._strip_pasted_lines_for_code(text.splitlines())

        # By default, echo back to terminal unless quiet mode is requested
        if 'q' not in opts:
            write = self.shell.write
            write(self.shell.pycolorize(block))
            if not block.endswith('\n'):
                write('\n')
            write("## -- End pasted text --\n")
            
        self._execute_block(block, par)
Example n. 14
class Device(App):
    """Base class for MSMAccelerator devices. These are processes that request data
    from the server, and run on a ZMQ REQ port.

    Current subclasses include Simulator and Modeler.

    When the device boots up, it will send a message to the server with the
    msg_type 'register_{ClassName}', where ClassName is the name of the subclass
    of device that was instantiated. When it receives a return message, the
    method on_startup_message() will be called.

    Note, if you want to interface directly with the ZMQ socket, it's just
    self.socket
    """

    name = 'device'
    path = 'msmaccelerator.core.device.Device'
    short_description = 'Base class for MSMAccelerator devices'
    long_description = '''Contains common code for processes that connect to
        the msmaccelerator server to request data and do stuff'''

    zmq_port = Int(12345,
                   config=True,
                   help='ZeroMQ port to connect to the server on')
    zmq_url = Unicode('127.0.0.1',
                      config=True,
                      help='URL to connect to server with')
    uuid = Bytes(help='Unique identifier for this device')

    def _uuid_default(self):
        return str(uuid.uuid4())

    aliases = dict(zmq_port='Device.zmq_port', zmq_url='Device.zmq_url')

    @property
    def zmq_connection_string(self):
        return 'tcp://%s:%s' % (self.zmq_url, self.zmq_port)
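        # e.g. with the defaults above: 'tcp://127.0.0.1:12345'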

    def start(self):
        self.ctx = zmq.Context()
        self.socket = self.ctx.socket(zmq.REQ)
        # we're using the uuid to set the identity of the socket
        # AND we're going to put it explicitly inside of the header of
        # the messages we send. Within the actual bytes that go over the wire,
        # this means that the uuid will actually be printed twice, but
        # it's not a big deal. It makes it easier for the server application
        # code to have the sender of each message, and for the message json
        # itself to be "complete", instead of having the sender in a separate
        # data structure.
        self.socket.setsockopt(zmq.IDENTITY, self.uuid)
        self.socket.connect(self.zmq_connection_string)

        # send the "here i am message" to the server, and receive a response
        # we're using a "robust" send/recv pattern, basically retrying the
        # request a fixed number of times if no response is heard from the
        # server
        msg = self.send_recv(msg_type='register_%s' % self.__class__.__name__)
        self.on_startup_message(msg)

    def on_startup_message(self, msg):
        """This method is called when the device receives its startup message
        from the server
        """
        raise NotImplementedError(
            'This method should be overridden in a device subclass')

    def send_message(self, msg_type, content=None):
        """Send a message to the server asynchronously.

        Since we're using the request/reply pattern, after calling send
        you need to call recv to get the server's response. Consider using
        the send_recv method instead.

        See Also
        --------
        send_recv
        """
        if content is None:
            content = {}
        self.socket.send_json(
            pack_message(msg_type=msg_type,
                         content=content,
                         sender_id=self.uuid))

    def recv_message(self):
        """Receive a message from the server.

        Note, this method is not async -- it blocks the device until
        the server delivers a message
        """
        raw_msg = self.socket.recv()
        # the payload was serialized with send_json; yaml.safe_load parses
        # json (a subset of yaml) without yaml.load's arbitrary-object risk
        msg = yaml.safe_load(raw_msg)
        return Message(msg)

    def send_recv(self, msg_type, content=None, timeout=10, retries=3):
        """Send a message to the server and receive a response

        This method inplementes the "Lazy-Pirate pattern" for
        relaible request/reply flows described in the ZeroMQ book.

        Parameters
        ----------
        msg_type : str
            The type of the message to send. This is an essential
            part of the MSMAccelerator messaging protocol.
        content : dict
            The contents of the message to send. This is an essential
            part of the MSMAccelerator messaging protocol.
        timeout : int, float, default=10
            The timeout, in seconds, to wait for a response from the
            server.
        retries : int, default=3
            If a response from the server is not received within `timeout`
            seconds, we'll retry sending our payload at most `retries`
            number of times. After that point, if no return message has
            been received, we'll throw an IOError.
        """
        timeout_ms = timeout * 1000

        poller = zmq.Poller()
        for i in range(retries):
            self.send_message(msg_type, content)
            poller.register(self.socket, zmq.POLLIN)
            if poller.poll(timeout_ms):
                return self.recv_message()
            else:
                self.log.error(
                    'No response received from server on '
                    'msg_type=%s. Retrying...', msg_type)
                poller.unregister(self.socket)
                # discard the stuck socket and rebuild it, as the Lazy Pirate
                # pattern prescribes; the identity must be set again on the
                # fresh socket or the server will see a different sender
                self.socket.setsockopt(zmq.LINGER, 0)
                self.socket.close()
                self.socket = self.ctx.socket(zmq.REQ)
                self.socket.setsockopt(zmq.IDENTITY, self.uuid)
                self.socket.connect(self.zmq_connection_string)

        raise IOError('Network timeout. Server is unresponsive.')
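
A minimal subclass sketch (hypothetical; `EchoDevice` is not part of the
example above) showing the register_{ClassName} handshake the Device
docstring describes:

class EchoDevice(Device):
    """Toy device that just prints the server's registration reply."""
    name = 'echo'

    def on_startup_message(self, msg):
        # msg is the server's reply to our 'register_EchoDevice' request
        print 'registered with server:', msg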
Example n. 15
class Client(HasTraits):
    """A semi-synchronous client to the IPython ZMQ cluster
    
    Parameters
    ----------
    
    url_or_file : bytes; zmq url or path to ipcontroller-client.json
        Connection information for the Hub's registration.  If a json connector
        file is given, then likely no further configuration is necessary.
        [Default: use profile]
    profile : bytes
        The name of the Cluster profile to be used to find connector information.
        [Default: 'default']
    context : zmq.Context
        Pass an existing zmq.Context instance, otherwise the client will create its own.
    username : bytes
        set username to be passed to the Session object
    debug : bool
        flag for lots of message printing for debug purposes
 
    #-------------- ssh related args ----------------
    # These are args for configuring the ssh tunnel to be used
    # credentials are used to forward connections over ssh to the Controller
    # Note that the ip given in `addr` needs to be relative to sshserver
    # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
    # and set sshserver as the same machine the Controller is on. However, 
    # the only requirement is that sshserver is able to see the Controller
    # (i.e. is within the same trusted network).
    
    sshserver : str
        A string of the form passed to ssh, i.e. 'server.tld' or '[email protected]:port'
        If keyfile or password is specified, and this is not, it will default to
        the ip given in addr.
    sshkey : str; path to public ssh key file
        This specifies a key to be used in ssh login, default None.
        Regular default ssh keys will be used without specifying this argument.
    password : str 
        Your ssh password to sshserver. Note that if this is left None,
        you will be prompted for it if passwordless key based login is unavailable.
    paramiko : bool
        flag for whether to use paramiko instead of shell ssh for tunneling.
        [default: True on win32, False otherwise]
    
    ------- exec authentication args -------
    If even localhost is untrusted, you can have some protection against
    unauthorized execution by using a key.  Messages are still sent
    as cleartext, so if someone can snoop your loopback traffic this will
    not help against malicious attacks.
    
    exec_key : str
        an authentication key or file containing a key
        default: None
    
    
    Attributes
    ----------
    
    ids : list of int engine IDs
        requesting the ids attribute always synchronizes
        the registration state. To request ids without synchronization,
        use the semi-private _ids attribute.
    
    history : list of msg_ids
        a list of msg_ids, keeping track of all the execution
        messages you have submitted in order.
    
    outstanding : set of msg_ids
        a set of msg_ids that have been submitted, but whose
        results have not yet been received.
    
    results : dict
        a dict of all our results, keyed by msg_id
    
    block : bool
        determines default behavior when block not specified
        in execution methods
    
    Methods
    -------
    
    spin
        flushes incoming results and registration state changes
        control methods spin as well, and requesting `ids` also ensures up-to-date state
    
    wait
        wait on one or more msg_ids
    
    execution methods
        apply
        legacy: execute, run
    
    data movement
        push, pull, scatter, gather
    
    query methods
        queue_status, get_result, purge, result_status
    
    control methods
        abort, shutdown
    
    """

    block = Bool(False)
    outstanding = Set()
    results = Instance('collections.defaultdict', (dict, ))
    metadata = Instance('collections.defaultdict', (Metadata, ))
    history = List()
    debug = Bool(False)
    profile = CUnicode('default')

    _outstanding_dict = Instance('collections.defaultdict', (set, ))
    _ids = List()
    _connected = Bool(False)
    _ssh = Bool(False)
    _context = Instance('zmq.Context')
    _config = Dict()
    _engines = Instance(util.ReverseDict, (), {})
    # _hub_socket=Instance('zmq.Socket')
    _query_socket = Instance('zmq.Socket')
    _control_socket = Instance('zmq.Socket')
    _iopub_socket = Instance('zmq.Socket')
    _notification_socket = Instance('zmq.Socket')
    _mux_socket = Instance('zmq.Socket')
    _task_socket = Instance('zmq.Socket')
    _task_scheme = Str()
    _closed = False
    _ignored_control_replies = Int(0)
    _ignored_hub_replies = Int(0)

    def __init__(self,
                 url_or_file=None,
                 profile='default',
                 cluster_dir=None,
                 ipython_dir=None,
                 context=None,
                 username=None,
                 debug=False,
                 exec_key=None,
                 sshserver=None,
                 sshkey=None,
                 password=None,
                 paramiko=None,
                 timeout=10):
        super(Client, self).__init__(debug=debug, profile=profile)
        if context is None:
            context = zmq.Context.instance()
        self._context = context

        self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
        if self._cd is not None:
            if url_or_file is None:
                url_or_file = pjoin(self._cd.security_dir,
                                    'ipcontroller-client.json')
        assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
            " Please specify at least one of url_or_file or profile."

        try:
            util.validate_url(url_or_file)
        except AssertionError:
            if not os.path.exists(url_or_file):
                if self._cd:
                    url_or_file = os.path.join(self._cd.security_dir,
                                               url_or_file)
                assert os.path.exists(
                    url_or_file
                ), "Not a valid connection file or url: %r" % url_or_file
            with open(url_or_file) as f:
                cfg = json.loads(f.read())
        else:
            cfg = {'url': url_or_file}

        # sync defaults from args, json:
        if sshserver:
            cfg['ssh'] = sshserver
        if exec_key:
            cfg['exec_key'] = exec_key
        exec_key = cfg.get('exec_key', None)
        sshserver = cfg.get('ssh', None)
        location = cfg.setdefault('location', None)
        cfg['url'] = util.disambiguate_url(cfg['url'], location)
        url = cfg['url']

        self._config = cfg

        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to ssh via localhost
            sshserver = url.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                password = False
            else:
                password = getpass("SSH Password for %s: " % sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
        if exec_key is not None and os.path.isfile(exec_key):
            arg = 'keyfile'
        else:
            arg = 'key'
        key_arg = {arg: exec_key}
        if username is None:
            self.session = ss.StreamSession(**key_arg)
        else:
            self.session = ss.StreamSession(username, **key_arg)
        self._query_socket = self._context.socket(zmq.XREQ)
        self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
        if self._ssh:
            tunnel.tunnel_connection(self._query_socket, url, sshserver,
                                     **ssh_kwargs)
        else:
            self._query_socket.connect(url)

        self.session.debug = self.debug

        self._notification_handlers = {
            'registration_notification': self._register_engine,
            'unregistration_notification': self._unregister_engine,
            'shutdown_notification': lambda msg: self.close(),
        }
        self._queue_handlers = {
            'execute_reply': self._handle_execute_reply,
            'apply_reply': self._handle_apply_reply
        }
        self._connect(sshserver, ssh_kwargs, timeout)

    def __del__(self):
        """cleanup sockets, but _not_ context."""
        self.close()

    def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
        if ipython_dir is None:
            ipython_dir = get_ipython_dir()
        if cluster_dir is not None:
            try:
                self._cd = ClusterDir.find_cluster_dir(cluster_dir)
                return
            except ClusterDirError:
                pass
        elif profile is not None:
            try:
                self._cd = ClusterDir.find_cluster_dir_by_profile(
                    ipython_dir, profile)
                return
            except ClusterDirError:
                pass
        self._cd = None

    def _update_engines(self, engines):
        """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
        for k, v in engines.iteritems():
            eid = int(k)
            self._engines[eid] = bytes(v)  # force not unicode
            self._ids.append(eid)
        self._ids = sorted(self._ids)
        if sorted(self._engines.keys()) != range(len(self._engines)) and \
                        self._task_scheme == 'pure' and self._task_socket:
            self._stop_scheduling_tasks()

    def _stop_scheduling_tasks(self):
        """Stop scheduling tasks because an engine has been unregistered
        from a pure ZMQ scheduler.
        """
        self._task_socket.close()
        self._task_socket = None
        msg = "An engine has been unregistered, and we are using pure " +\
              "ZMQ task scheduling.  Task farming will be disabled."
        if self.outstanding:
            msg += " If you were running tasks when this happened, " +\
                   "some `outstanding` msg_ids may never resolve."
        warnings.warn(msg, RuntimeWarning)

    def _build_targets(self, targets):
        """Turn valid target IDs or 'all' into two lists:
        (int_ids, uuids).
        """
        if not self._ids:
            # flush notification socket if no engines yet, just in case
            if not self.ids:
                raise error.NoEnginesRegistered(
                    "Can't build targets without any engines")

        if targets is None:
            targets = self._ids
        elif isinstance(targets, str):
            if targets.lower() == 'all':
                targets = self._ids
            else:
                raise TypeError("%r not valid str target, must be 'all'" %
                                (targets))
        elif isinstance(targets, int):
            if targets < 0:
                targets = self.ids[targets]
            if targets not in self._ids:
                raise IndexError("No such engine: %i" % targets)
            targets = [targets]

        if isinstance(targets, slice):
            indices = range(len(self._ids))[targets]
            ids = self.ids
            targets = [ids[i] for i in indices]

        if not isinstance(targets, (tuple, list, xrange)):
            raise TypeError(
                "targets by int/slice/collection of ints only, not %s" %
                (type(targets)))

        return [self._engines[t] for t in targets], list(targets)
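        # e.g., with three registered engines (hypothetical ids [0, 1, 2]):
        #   _build_targets('all')     -> (all three uuids, [0, 1, 2])
        #   _build_targets(-1)        -> ([uuid_of_engine_2], [2])
        #   _build_targets(slice(2))  -> ([uuid_of_0, uuid_of_1], [0, 1])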

    def _connect(self, sshserver, ssh_kwargs, timeout):
        """setup all our socket connections to the cluster. This is called from
        __init__."""

        # Maybe allow reconnecting?
        if self._connected:
            return
        self._connected = True

        def connect_socket(s, url):
            url = util.disambiguate_url(url, self._config['location'])
            if self._ssh:
                return tunnel.tunnel_connection(s, url, sshserver,
                                                **ssh_kwargs)
            else:
                return s.connect(url)

        self.session.send(self._query_socket, 'connection_request')
        r, w, x = zmq.select([self._query_socket], [], [], timeout)
        if not r:
            raise error.TimeoutError("Hub connection request timed out")
        idents, msg = self.session.recv(self._query_socket, mode=0)
        if self.debug:
            pprint(msg)
        msg = ss.Message(msg)
        content = msg.content
        self._config['registration'] = dict(content)
        if content.status == 'ok':
            if content.mux:
                self._mux_socket = self._context.socket(zmq.XREQ)
                self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._mux_socket, content.mux)
            if content.task:
                self._task_scheme, task_addr = content.task
                self._task_socket = self._context.socket(zmq.XREQ)
                self._task_socket.setsockopt(zmq.IDENTITY,
                                             self.session.session)
                connect_socket(self._task_socket, task_addr)
            if content.notification:
                self._notification_socket = self._context.socket(zmq.SUB)
                connect_socket(self._notification_socket, content.notification)
                self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
            # if content.query:
            #     self._query_socket = self._context.socket(zmq.XREQ)
            #     self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
            #     connect_socket(self._query_socket, content.query)
            if content.control:
                self._control_socket = self._context.socket(zmq.XREQ)
                self._control_socket.setsockopt(zmq.IDENTITY,
                                                self.session.session)
                connect_socket(self._control_socket, content.control)
            if content.iopub:
                self._iopub_socket = self._context.socket(zmq.SUB)
                self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
                self._iopub_socket.setsockopt(zmq.IDENTITY,
                                              self.session.session)
                connect_socket(self._iopub_socket, content.iopub)
            self._update_engines(dict(content.engines))
        else:
            self._connected = False
            raise Exception("Failed to connect!")

    #--------------------------------------------------------------------------
    # handlers and callbacks for incoming messages
    #--------------------------------------------------------------------------

    def _unwrap_exception(self, content):
        """unwrap exception, and remap engine_id to int."""
        e = error.unwrap_exception(content)
        # print e.traceback
        if e.engine_info:
            e_uuid = e.engine_info['engine_uuid']
            eid = self._engines[e_uuid]
            e.engine_info['engine_id'] = eid
        return e

    def _extract_metadata(self, header, parent, content):
        md = {
            'msg_id': parent['msg_id'],
            'received': datetime.now(),
            'engine_uuid': header.get('engine', None),
            'follow': parent.get('follow', []),
            'after': parent.get('after', []),
            'status': content['status'],
        }

        if md['engine_uuid'] is not None:
            md['engine_id'] = self._engines.get(md['engine_uuid'], None)

        if 'date' in parent:
            md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
        if 'started' in header:
            md['started'] = datetime.strptime(header['started'], util.ISO8601)
        if 'date' in header:
            md['completed'] = datetime.strptime(header['date'], util.ISO8601)
        return md

    def _register_engine(self, msg):
        """Register a new engine, and update our connection info."""
        content = msg['content']
        eid = content['id']
        d = {eid: content['queue']}
        self._update_engines(d)

    def _unregister_engine(self, msg):
        """Unregister an engine that has died."""
        content = msg['content']
        eid = int(content['id'])
        if eid in self._ids:
            self._ids.remove(eid)
            uuid = self._engines.pop(eid)

            self._handle_stranded_msgs(eid, uuid)

        if self._task_socket and self._task_scheme == 'pure':
            self._stop_scheduling_tasks()

    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.
        
        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        of the unregistration and later receive the successful result.
        """

        outstanding = self._outstanding_dict[uuid]

        for msg_id in list(outstanding):
            if msg_id in self.results:
                # we already have the result for this msg_id
                continue
            try:
                raise error.EngineError(
                    "Engine %r died while running task %r" % (eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake message:
            parent = {}
            header = {}
            parent['msg_id'] = msg_id
            header['engine'] = uuid
            header['date'] = datetime.now().strftime(util.ISO8601)
            msg = dict(parent_header=parent, header=header, content=content)
            self._handle_apply_reply(msg)

    def _handle_execute_reply(self, msg):
        """Save the reply to an execute_request into our results.
        
        execute messages are never actually used. apply is used instead.
        """

        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            if msg_id in self.history:
                print("got stale result: %s" % msg_id)
            else:
                print("got unknown result: %s" % msg_id)
        else:
            self.outstanding.remove(msg_id)
        self.results[msg_id] = self._unwrap_exception(msg['content'])

    def _handle_apply_reply(self, msg):
        """Save the reply to an apply_request into our results."""
        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            if msg_id in self.history:
                print("got stale result: %s" % msg_id)
                print self.results[msg_id]
                print msg
            else:
                print("got unknown result: %s" % msg_id)
        else:
            self.outstanding.remove(msg_id)
        content = msg['content']
        header = msg['header']

        # construct metadata:
        md = self.metadata[msg_id]
        md.update(self._extract_metadata(header, parent, content))
        # is this redundant?
        self.metadata[msg_id] = md

        e_outstanding = self._outstanding_dict[md['engine_uuid']]
        if msg_id in e_outstanding:
            e_outstanding.remove(msg_id)

        # construct result:
        if content['status'] == 'ok':
            self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
        elif content['status'] == 'aborted':
            self.results[msg_id] = error.TaskAborted(msg_id)
        elif content['status'] == 'resubmitted':
            # TODO: handle resubmission
            pass
        else:
            self.results[msg_id] = self._unwrap_exception(content)

    def _flush_notifications(self):
        """Flush notifications of engine registrations waiting
        in ZMQ queue."""
        msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg = msg[-1]
            msg_type = msg['msg_type']
            handler = self._notification_handlers.get(msg_type, None)
            if handler is None:
                raise Exception("Unhandled message type: %s" % msg.msg_type)
            else:
                handler(msg)
            msg = self.session.recv(self._notification_socket,
                                    mode=zmq.NOBLOCK)

    def _flush_results(self, sock):
        """Flush task or queue results waiting in ZMQ queue."""
        msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg = msg[-1]
            msg_type = msg['msg_type']
            handler = self._queue_handlers.get(msg_type, None)
            if handler is None:
                raise Exception("Unhandled message type: %s" % msg.msg_type)
            else:
                handler(msg)
            msg = self.session.recv(sock, mode=zmq.NOBLOCK)

    def _flush_control(self, sock):
        """Flush replies from the control channel waiting
        in the ZMQ queue.
        
        Currently: ignore them."""
        if self._ignored_control_replies <= 0:
            return
        msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            self._ignored_control_replies -= 1
            if self.debug:
                pprint(msg)
            msg = self.session.recv(sock, mode=zmq.NOBLOCK)

    def _flush_ignored_control(self):
        """flush ignored control replies"""
        while self._ignored_control_replies > 0:
            self.session.recv(self._control_socket)
            self._ignored_control_replies -= 1

    def _flush_ignored_hub_replies(self):
        msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
        while msg is not None:
            msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)

    def _flush_iopub(self, sock):
        """Flush replies from the iopub channel waiting
        in the ZMQ queue.
        """
        msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg = msg[-1]
            parent = msg['parent_header']
            msg_id = parent['msg_id']
            content = msg['content']
            header = msg['header']
            msg_type = msg['msg_type']

            # init metadata:
            md = self.metadata[msg_id]

            if msg_type == 'stream':
                name = content['name']
                s = md[name] or ''
                md[name] = s + content['data']
            elif msg_type == 'pyerr':
                md.update({'pyerr': self._unwrap_exception(content)})
            elif msg_type == 'pyin':
                md.update({'pyin': content['code']})
            else:
                md.update({msg_type: content.get('data', '')})

            # redundant?
            self.metadata[msg_id] = md

            msg = self.session.recv(sock, mode=zmq.NOBLOCK)

    #--------------------------------------------------------------------------
    # len, getitem
    #--------------------------------------------------------------------------

    def __len__(self):
        """len(client) returns # of engines."""
        return len(self.ids)

    def __getitem__(self, key):
        """index access returns DirectView multiplexer objects
        
        Must be int, slice, or list/tuple/xrange of ints"""
        if not isinstance(key, (int, slice, tuple, list, xrange)):
            raise TypeError("key by int/slice/iterable of ints only, not %s" %
                            (type(key)))
        else:
            return self.direct_view(key)

    #--------------------------------------------------------------------------
    # Begin public methods
    #--------------------------------------------------------------------------

    @property
    def ids(self):
        """Always up-to-date ids property."""
        self._flush_notifications()
        # always copy:
        return list(self._ids)

    def close(self):
        if self._closed:
            return
        snames = filter(lambda n: n.endswith('socket'), dir(self))
        for socket in map(lambda name: getattr(self, name), snames):
            if isinstance(socket, zmq.Socket) and not socket.closed:
                socket.close()
        self._closed = True

    def spin(self):
        """Flush any registration notifications and execution results
        waiting in the ZMQ queue.
        """
        if self._notification_socket:
            self._flush_notifications()
        if self._mux_socket:
            self._flush_results(self._mux_socket)
        if self._task_socket:
            self._flush_results(self._task_socket)
        if self._control_socket:
            self._flush_control(self._control_socket)
        if self._iopub_socket:
            self._flush_iopub(self._iopub_socket)
        if self._query_socket:
            self._flush_ignored_hub_replies()

    def wait(self, jobs=None, timeout=-1):
        """waits on one or more `jobs`, for up to `timeout` seconds.
        
        Parameters
        ----------
        
        jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
                ints are indices to self.history
                strs are msg_ids
                default: wait on all outstanding messages
        timeout : float
                a time in seconds, after which to give up.
                default is -1, which means no timeout
        
        Returns
        -------
        
        True : when all msg_ids are done
        False : timeout reached, some msg_ids still outstanding
        """
        tic = time.time()
        if jobs is None:
            theids = self.outstanding
        else:
            if isinstance(jobs, (int, str, AsyncResult)):
                jobs = [jobs]
            theids = set()
            for job in jobs:
                if isinstance(job, int):
                    # index access
                    job = self.history[job]
                elif isinstance(job, AsyncResult):
                    map(theids.add, job.msg_ids)
                    continue
                theids.add(job)
        if not theids.intersection(self.outstanding):
            return True
        self.spin()
        while theids.intersection(self.outstanding):
            if timeout >= 0 and (time.time() - tic) > timeout:
                break
            time.sleep(1e-3)
            self.spin()
        return len(theids.intersection(self.outstanding)) == 0
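        # e.g. (hypothetical): wait(timeout=5) polls spin() roughly every
        # millisecond and returns False if any of the msg_ids are still
        # outstanding after five seconds.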

    #--------------------------------------------------------------------------
    # Control methods
    #--------------------------------------------------------------------------

    @spin_first
    def clear(self, targets=None, block=None):
        """Clear the namespace in target(s)."""
        block = self.block if block is None else block
        targets = self._build_targets(targets)[0]
        for t in targets:
            self.session.send(self._control_socket,
                              'clear_request',
                              content={},
                              ident=t)
        error = False
        if block:
            self._flush_ignored_control()
            for i in range(len(targets)):
                idents, msg = self.session.recv(self._control_socket, 0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            self._ignored_control_replies += len(targets)
        if error:
            raise error

    @spin_first
    def abort(self, jobs=None, targets=None, block=None):
        """Abort specific jobs from the execution queues of target(s).
        
        This is a mechanism to prevent jobs that have already been submitted
        from executing.
        
        Parameters
        ----------
        
        jobs : msg_id, list of msg_ids, or AsyncResult
            The jobs to be aborted
        
        
        """
        block = self.block if block is None else block
        targets = self._build_targets(targets)[0]
        msg_ids = []
        if isinstance(jobs, (basestring, AsyncResult)):
            jobs = [jobs]
        bad_ids = filter(
            lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
        if bad_ids:
            raise TypeError(
                "Invalid msg_id type %r, expected str or AsyncResult" %
                bad_ids[0])
        for j in jobs:
            if isinstance(j, AsyncResult):
                msg_ids.extend(j.msg_ids)
            else:
                msg_ids.append(j)
        content = dict(msg_ids=msg_ids)
        for t in targets:
            self.session.send(self._control_socket,
                              'abort_request',
                              content=content,
                              ident=t)
        error = False
        if block:
            self._flush_ignored_control()
            for i in range(len(targets)):
                idents, msg = self.session.recv(self._control_socket, 0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            self._ignored_control_replies += len(targets)
        if error:
            raise error

    @spin_first
    def shutdown(self, targets=None, restart=False, hub=False, block=None):
        """Terminates one or more engine processes, optionally including the hub."""
        block = self.block if block is None else block
        if hub:
            targets = 'all'
        targets = self._build_targets(targets)[0]
        for t in targets:
            self.session.send(self._control_socket,
                              'shutdown_request',
                              content={'restart': restart},
                              ident=t)
        error = False
        if block or hub:
            self._flush_ignored_control()
            for i in range(len(targets)):
                idents, msg = self.session.recv(self._control_socket, 0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            self._ignored_control_replies += len(targets)

        if hub:
            time.sleep(0.25)
            self.session.send(self._query_socket, 'shutdown_request')
            idents, msg = self.session.recv(self._query_socket, 0)
            if self.debug:
                pprint(msg)
            if msg['content']['status'] != 'ok':
                error = self._unwrap_exception(msg['content'])

        if error:
            raise error

    #--------------------------------------------------------------------------
    # Execution related methods
    #--------------------------------------------------------------------------

    def _maybe_raise(self, result):
        """wrapper for maybe raising an exception if apply failed."""
        if isinstance(result, error.RemoteError):
            raise result

        return result

    def send_apply_message(self,
                           socket,
                           f,
                           args=None,
                           kwargs=None,
                           subheader=None,
                           track=False,
                           ident=None):
        """construct and send an apply message via a socket.
        
        This is the principal method with which all engine execution is performed by views.
        """

        assert not self._closed, "cannot use me anymore, I'm closed!"
        # defaults:
        args = args if args is not None else []
        kwargs = kwargs if kwargs is not None else {}
        subheader = subheader if subheader is not None else {}

        # validate arguments
        if not callable(f):
            raise TypeError("f must be callable, not %s" % type(f))
        if not isinstance(args, (tuple, list)):
            raise TypeError("args must be tuple or list, not %s" % type(args))
        if not isinstance(kwargs, dict):
            raise TypeError("kwargs must be dict, not %s" % type(kwargs))
        if not isinstance(subheader, dict):
            raise TypeError("subheader must be dict, not %s" % type(subheader))

        bufs = util.pack_apply_message(f, args, kwargs)

        msg = self.session.send(socket,
                                "apply_request",
                                buffers=bufs,
                                ident=ident,
                                subheader=subheader,
                                track=track)

        msg_id = msg['msg_id']
        self.outstanding.add(msg_id)
        if ident:
            # possibly routed to a specific engine
            if isinstance(ident, list):
                ident = ident[-1]
            if ident in self._engines.values():
                # save for later, in case of engine death
                self._outstanding_dict[ident].add(msg_id)
        self.history.append(msg_id)
        self.metadata[msg_id]['submitted'] = datetime.now()

        return msg
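        # Views drive this roughly as (hypothetical call):
        #   client.send_apply_message(client._mux_socket, f, (1, 2), {},
        #                             ident=engine_uuid)
        # then wait on msg['msg_id'] via client.wait / client.results.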

    #--------------------------------------------------------------------------
    # construct a View object
    #--------------------------------------------------------------------------

    def load_balanced_view(self, targets=None):
        """construct a DirectView object.
        
        If no arguments are specified, create a LoadBalancedView
        using all engines.
        
        Parameters
        ----------
        
        targets: list,slice,int,etc. [default: use all engines]
            The subset of engines across which to load-balance
        """
        if targets is not None:
            targets = self._build_targets(targets)[1]
        return LoadBalancedView(client=self,
                                socket=self._task_socket,
                                targets=targets)

    def direct_view(self, targets='all'):
        """construct a DirectView object.
        
        If no targets are specified, create a DirectView
        using all engines.
        
        Parameters
        ----------
        
        targets: list,slice,int,etc. [default: use all engines]
            The engines to use for the View
        """
        single = isinstance(targets, int)
        targets = self._build_targets(targets)[1]
        if single:
            targets = targets[0]
        return DirectView(client=self,
                          socket=self._mux_socket,
                          targets=targets)
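    # View-construction sketch (hypothetical client `rc`):
    #   rc[0]                      # DirectView on engine 0 (via __getitem__)
    #   rc[::2]                    # DirectView on every other engine
    #   rc.load_balanced_view()    # LoadBalancedView over all engines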

    #--------------------------------------------------------------------------
    # Query methods
    #--------------------------------------------------------------------------

    @spin_first
    def get_result(self, indices_or_msg_ids=None, block=None):
        """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
        
        If the client already has the results, no request to the Hub will be made.
        
        This is a convenient way to construct AsyncResult objects, which are wrappers
        that include metadata about execution, and allow for awaiting results that
        were not submitted by this Client.
        
        It can also be a convenient way to retrieve the metadata associated with
        blocking execution, since it always retrieves the metadata for each
        msg_id.
        
        Examples
        --------
        ::
        
            In [10]: ar = client.get_result(msg_id)
        
        Parameters
        ----------
        
        indices_or_msg_ids : integer history index, str msg_id, or list of either
            The indices or msg_ids of the results to be retrieved
        
        block : bool
            Whether to wait for the result to be done
        
        Returns
        -------
        
        AsyncResult
            A single AsyncResult object will always be returned.
        
        AsyncHubResult
            A subclass of AsyncResult that retrieves results from the Hub
        
        """
        block = self.block if block is None else block
        if indices_or_msg_ids is None:
            indices_or_msg_ids = -1

        if not isinstance(indices_or_msg_ids, (list, tuple)):
            indices_or_msg_ids = [indices_or_msg_ids]

        theids = []
        for id in indices_or_msg_ids:
            if isinstance(id, int):
                id = self.history[id]
            if not isinstance(id, str):
                raise TypeError("indices must be str or int, not %r" % id)
            theids.append(id)

        local_ids = filter(
            lambda msg_id: msg_id in self.history or msg_id in self.results,
            theids)
        remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)

        if remote_ids:
            ar = AsyncHubResult(self, msg_ids=theids)
        else:
            ar = AsyncResult(self, msg_ids=theids)

        if block:
            ar.wait()

        return ar

    @spin_first
    def result_status(self, msg_ids, status_only=True):
        """Check on the status of the result(s) of the apply request with `msg_ids`.
        
        If status_only is False, then the actual results will be retrieved, else
        only the status of the results will be checked.
        
        Parameters
        ----------
        
        msg_ids : list of msg_ids
            if int:
                Passed as index to self.history for convenience.
        status_only : bool (default: True)
            if False:
                Retrieve the actual results of completed tasks.
        
        Returns
        -------
        
        results : dict
            There will always be the keys 'pending' and 'completed', which will
            be lists of msg_ids that are incomplete or complete. If `status_only`
            is False, then completed results will be keyed by their `msg_id`.
        """
        if not isinstance(msg_ids, (list, tuple)):
            msg_ids = [msg_ids]

        theids = []
        for msg_id in msg_ids:
            if isinstance(msg_id, int):
                msg_id = self.history[msg_id]
            if not isinstance(msg_id, basestring):
                raise TypeError("msg_ids must be str, not %r" % msg_id)
            theids.append(msg_id)

        completed = []
        local_results = {}

        # comment this block out to temporarily disable local shortcut:
        # (iterate over a copy, since we remove from theids as we go)
        for msg_id in list(theids):
            if msg_id in self.results:
                completed.append(msg_id)
                local_results[msg_id] = self.results[msg_id]
                theids.remove(msg_id)

        if theids:  # some not locally cached
            content = dict(msg_ids=theids, status_only=status_only)
            msg = self.session.send(self._query_socket,
                                    "result_request",
                                    content=content)
            zmq.select([self._query_socket], [], [])
            idents, msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
            if self.debug:
                pprint(msg)
            content = msg['content']
            if content['status'] != 'ok':
                raise self._unwrap_exception(content)
            buffers = msg['buffers']
        else:
            content = dict(completed=[], pending=[])

        content['completed'].extend(completed)

        if status_only:
            return content

        failures = []
        # load cached results into result:
        content.update(local_results)
        # update cache with results:
        for msg_id in sorted(theids):
            if msg_id in content['completed']:
                rec = content[msg_id]
                parent = rec['header']
                header = rec['result_header']
                rcontent = rec['result_content']
                iodict = rec['io']
                if isinstance(rcontent, str):
                    rcontent = self.session.unpack(rcontent)

                md = self.metadata[msg_id]
                md.update(self._extract_metadata(header, parent, rcontent))
                md.update(iodict)

                if rcontent['status'] == 'ok':
                    res, buffers = util.unserialize_object(buffers)
                else:
                    print rcontent
                    res = self._unwrap_exception(rcontent)
                    failures.append(res)

                self.results[msg_id] = res
                content[msg_id] = res

        if len(theids) == 1 and failures:
            raise failures[0]

        error.collect_exceptions(failures, "result_status")
        return content

    @spin_first
    def queue_status(self, targets='all', verbose=False):
        """Fetch the status of engine queues.
        
        Parameters
        ----------
        
        targets : int/str/list of ints/strs
                the engines whose states are to be queried.
                default : all
        verbose : bool
                Whether to return lengths only, or lists of ids for each element
        """
        engine_ids = self._build_targets(targets)[1]
        content = dict(targets=engine_ids, verbose=verbose)
        self.session.send(self._query_socket, "queue_request", content=content)
        idents, msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        status = content.pop('status')
        if status != 'ok':
            raise self._unwrap_exception(content)
        content = util.rekey(content)
        if isinstance(targets, int):
            return content[targets]
        else:
            return content

    @spin_first
    def purge_results(self, jobs=[], targets=[]):
        """Tell the Hub to forget results.
        
        Individual results can be purged by msg_id, or the entire
        history of specific targets can be purged.
        
        Parameters
        ----------
        
        jobs : str or list of str or AsyncResult objects
                the msg_ids whose results should be forgotten.
        targets : int/str/list of ints/strs
                The targets, by uuid or int_id, whose entire history is to be purged.
                Use `targets='all'` to scrub everything from the Hub's memory.
                
                default : None
        """
        if not targets and not jobs:
            raise ValueError(
                "Must specify at least one of `targets` and `jobs`")
        if targets:
            targets = self._build_targets(targets)[1]

        # construct msg_ids from jobs
        msg_ids = []
        if isinstance(jobs, (basestring, AsyncResult)):
            jobs = [jobs]
        bad_ids = filter(
            lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
        if bad_ids:
            raise TypeError(
                "Invalid msg_id type %r, expected str or AsyncResult" %
                bad_ids[0])
        for j in jobs:
            if isinstance(j, AsyncResult):
                msg_ids.extend(j.msg_ids)
            else:
                msg_ids.append(j)

        content = dict(targets=targets, msg_ids=msg_ids)
        self.session.send(self._query_socket, "purge_request", content=content)
        idents, msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)

    @spin_first
    def hub_history(self):
        """Get the Hub's history
        
        Just like the Client, the Hub has a history, which is a list of msg_ids.
        This will contain the history of all clients, and, depending on configuration,
        may contain history across multiple cluster sessions.
        
        Any msg_id returned here is a valid argument to `get_result`.
        
        Returns
        -------
        
        msg_ids : list of strs
                list of all msg_ids, ordered by task submission time.
        """

        self.session.send(self._query_socket, "history_request", content={})
        idents, msg = self.session.recv(self._query_socket, 0)

        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)
        else:
            return content['history']

    @spin_first
    def db_query(self, query, keys=None):
        """Query the Hub's TaskRecord database
        
        This will return a list of task record dicts that match `query`
        
        Parameters
        ----------
        
        query : mongodb query dict
            The search dict. See mongodb query docs for details.
        keys : list of strs [optional]
            The subset of keys to be returned.  The default is to fetch everything.
            'msg_id' will *always* be included.
        """
        content = dict(query=query, keys=keys)
        self.session.send(self._query_socket, "db_request", content=content)
        idents, msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)

        records = content['records']
        buffer_lens = content['buffer_lens']
        result_buffer_lens = content['result_buffer_lens']
        buffers = msg['buffers']
        has_bufs = buffer_lens is not None
        has_rbufs = result_buffer_lens is not None
        for i, rec in enumerate(records):
            # relink buffers
            if has_bufs:
                blen = buffer_lens[i]
                rec['buffers'], buffers = buffers[:blen], buffers[blen:]
            if has_rbufs:
                blen = result_buffer_lens[i]
                rec['result_buffers'], buffers = buffers[:blen], buffers[blen:]
            # turn timestamps back into times
            for key in 'submitted started completed resubmitted'.split():
                maybedate = rec.get(key, None)
                if maybedate and util.ISO8601_RE.match(maybedate):
                    rec[key] = datetime.strptime(maybedate, util.ISO8601)

        return records
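
A brief usage sketch of the query methods above (hypothetical session; the
query fields shown are illustrative):

# assumes a running cluster reachable through the 'default' profile
rc = Client(profile='default')
print rc.queue_status(verbose=True)                  # per-engine queue state
recs = rc.db_query({'engine_uuid': {'$ne': None}},   # tasks that ran somewhere
                   keys=['msg_id', 'completed'])
for rec in recs:
    print rec['msg_id'], rec.get('completed')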
Example n. 16
class KernelApp(BaseIPythonApplication):
    name = 'pykernel'
    aliases = Dict(kernel_aliases)
    flags = Dict(kernel_flags)
    classes = [Session]
    # the kernel class, as an importstring
    kernel_class = DottedObjectName('IPython.zmq.pykernel.Kernel')
    kernel = Any()
    # don't restrict this even though current pollers are all Threads
    poller = Any()
    heartbeat = Instance(Heartbeat)
    session = Instance('IPython.zmq.session.Session')
    ports = Dict()

    # inherit config file name from parent:
    parent_appname = Unicode(config=True)

    def _parent_appname_changed(self, name, old, new):
        if self.config_file_specified:
            # it was manually specified, ignore
            return
        self.config_file_name = new.replace('-', '_') + u'_config.py'
        # don't let this count as specifying the config file
        self.config_file_specified = False

    # connection info:
    ip = Unicode(
        LOCALHOST,
        config=True,
        help="Set the IP or interface on which the kernel will listen.")
    hb_port = Int(0,
                  config=True,
                  help="set the heartbeat port [default: random]")
    shell_port = Int(0,
                     config=True,
                     help="set the shell (XREP) port [default: random]")
    iopub_port = Int(0,
                     config=True,
                     help="set the iopub (PUB) port [default: random]")
    stdin_port = Int(0,
                     config=True,
                     help="set the stdin (XREQ) port [default: random]")

    # streams, etc.
    no_stdout = Bool(False,
                     config=True,
                     help="redirect stdout to the null device")
    no_stderr = Bool(False,
                     config=True,
                     help="redirect stderr to the null device")
    outstream_class = DottedObjectName(
        'IPython.zmq.iostream.OutStream',
        config=True,
        help="The importstring for the OutStream factory")
    displayhook_class = DottedObjectName(
        'IPython.zmq.displayhook.ZMQDisplayHook',
        config=True,
        help="The importstring for the DisplayHook factory")

    # polling
    parent = Int(
        0,
        config=True,
        help="""kill this process if its parent dies.  On Windows, the argument
        specifies the HANDLE of the parent process, otherwise it is simply boolean.
        """)
    interrupt = Int(0,
                    config=True,
                    help="""ONLY USED ON WINDOWS
        Interrupt this process when the parent is signalled.
        """)

    def init_crash_handler(self):
        # Install minimal exception handling
        sys.excepthook = FormattedTB(mode='Verbose',
                                     color_scheme='NoColor',
                                     ostream=sys.__stdout__)

    def init_poller(self):
        if sys.platform == 'win32':
            if self.interrupt or self.parent:
                self.poller = ParentPollerWindows(self.interrupt, self.parent)
        elif self.parent:
            self.poller = ParentPollerUnix()

    def _bind_socket(self, s, port):
        iface = 'tcp://%s' % self.ip
        if port <= 0:
            port = s.bind_to_random_port(iface)
        else:
            s.bind(iface + ':%i' % port)
        return port

    def init_sockets(self):
        # Create a context, a session, and the kernel sockets.
        self.log.info("Starting the kernel at pid:", os.getpid())
        context = zmq.Context.instance()
        # Uncomment this to try closing the context.
        # atexit.register(context.term)

        self.shell_socket = context.socket(zmq.XREP)
        self.shell_port = self._bind_socket(self.shell_socket, self.shell_port)
        self.log.debug("shell XREP Channel on port: %i" % self.shell_port)

        self.iopub_socket = context.socket(zmq.PUB)
        self.iopub_port = self._bind_socket(self.iopub_socket, self.iopub_port)
        self.log.debug("iopub PUB Channel on port: %i" % self.iopub_port)

        self.stdin_socket = context.socket(zmq.XREQ)
        self.stdin_port = self._bind_socket(self.stdin_socket, self.stdin_port)
        self.log.debug("stdin XREQ Channel on port: %i" % self.stdin_port)

        self.heartbeat = Heartbeat(context, (self.ip, self.hb_port))
        self.hb_port = self.heartbeat.port
        self.log.debug("Heartbeat REP Channel on port: %i" % self.hb_port)

        # Helper to make it easier to connect to an existing kernel, until we have
        # single-port connection negotiation fully implemented.
        # set log-level to critical, to make sure it is output
        self.log.critical("To connect another client to this kernel, use:")
        self.log.critical(
            "--existing shell={0} iopub={1} stdin={2} hb={3}".format(
                self.shell_port, self.iopub_port, self.stdin_port,
                self.hb_port))

        self.ports = dict(shell=self.shell_port,
                          iopub=self.iopub_port,
                          stdin=self.stdin_port,
                          hb=self.hb_port)

    def init_session(self):
        """create our session object"""
        self.session = Session(config=self.config, username=u'kernel')

    def init_blackhole(self):
        """redirects stdout/stderr to devnull if necessary"""
        if self.no_stdout or self.no_stderr:
            blackhole = file(os.devnull, 'w')
            if self.no_stdout:
                sys.stdout = sys.__stdout__ = blackhole
            if self.no_stderr:
                sys.stderr = sys.__stderr__ = blackhole

    def init_io(self):
        """Redirect input streams and set a display hook."""
        if self.outstream_class:
            outstream_factory = import_item(str(self.outstream_class))
            sys.stdout = outstream_factory(self.session, self.iopub_socket,
                                           u'stdout')
            sys.stderr = outstream_factory(self.session, self.iopub_socket,
                                           u'stderr')
        if self.displayhook_class:
            displayhook_factory = import_item(str(self.displayhook_class))
            sys.displayhook = displayhook_factory(self.session,
                                                  self.iopub_socket)

    def init_kernel(self):
        """Create the Kernel object itself"""
        kernel_factory = import_item(str(self.kernel_class))
        self.kernel = kernel_factory(config=self.config,
                                     session=self.session,
                                     shell_socket=self.shell_socket,
                                     iopub_socket=self.iopub_socket,
                                     stdin_socket=self.stdin_socket,
                                     log=self.log)
        self.kernel.record_ports(self.ports)

    def initialize(self, argv=None):
        super(KernelApp, self).initialize(argv)
        self.init_blackhole()
        self.init_session()
        self.init_poller()
        self.init_sockets()
        self.init_io()
        self.init_kernel()

    def start(self):
        self.heartbeat.start()
        if self.poller is not None:
            self.poller.start()
        try:
            self.kernel.start()
        except KeyboardInterrupt:
            pass
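
A minimal launch sketch, assuming the usual Application lifecycle that initialize()/start() above imply (the exact entry point is an assumption):

    # assumed entry point, following the initialize()/start() lifecycle above
    app = KernelApp.instance()
    app.initialize()   # init_sockets, init_io, init_kernel, ...
    app.start()        # starts heartbeat, poller, and the kernel loop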
Example no. 17
class IPythonConsoleApp(Configurable):
    name = 'ipython-console-mixin'
    default_config_file_name='ipython_config.py'

    description = """
        The IPython Mixin Console.
        
        This class contains the common portions of the console clients
        (QtConsole, ZMQ-based terminal console, etc.).  It is not a full
        console, in that launched terminal subprocesses will not be able
        to accept input.
        
        A console using this mixin supports various extra features beyond
        the single-process Terminal IPython shell, such as connecting to an
        existing kernel, via:
        
            ipython <appname> --existing
        
        as well as tunneling via SSH.
    """

    classes = classes
    flags = Dict(flags)
    aliases = Dict(aliases)
    kernel_manager_class = KernelManager
    kernel_client_class = BlockingKernelClient

    kernel_argv = List(Unicode)
    # frontend flags&aliases to be stripped when building kernel_argv
    frontend_flags = Any(app_flags)
    frontend_aliases = Any(app_aliases)

    # create requested profiles by default, if they don't exist:
    auto_create = CBool(True)
    # connection info:
    
    sshserver = Unicode('', config=True,
        help="""The SSH server to use to connect to the kernel.""")
    sshkey = Unicode('', config=True,
        help="""Path to the ssh key to use for logging in to the ssh server.""")
    
    hb_port = Int(0, config=True,
        help="set the heartbeat port [default: random]")
    shell_port = Int(0, config=True,
        help="set the shell (ROUTER) port [default: random]")
    iopub_port = Int(0, config=True,
        help="set the iopub (PUB) port [default: random]")
    stdin_port = Int(0, config=True,
        help="set the stdin (DEALER) port [default: random]")
    connection_file = Unicode('', config=True,
        help="""JSON file in which to store connection info [default: kernel-<pid>.json]

        This file will contain the IP, ports, and authentication key needed to connect
        clients to this kernel. By default, this file will be created in the security-dir
        of the current profile, but can be specified by absolute path.
        """)
    def _connection_file_default(self):
        return 'kernel-%i.json' % os.getpid()

    existing = CUnicode('', config=True,
        help="""Connect to an already running kernel""")

    confirm_exit = CBool(True, config=True,
        help="""
        Set to display confirmation dialog on exit. You can always use 'exit' or 'quit',
        to force a direct exit without any confirmation.""",
    )


    def build_kernel_argv(self, argv=None):
        """build argv to be passed to kernel subprocess"""
        if argv is None:
            argv = sys.argv[1:]
        self.kernel_argv = swallow_argv(argv, self.frontend_aliases, self.frontend_flags)
        # kernel should inherit default config file from frontend
        self.kernel_argv.append("--IPKernelApp.parent_appname='%s'" % self.name)
    
    def init_connection_file(self):
        """find the connection file, and load the info if found.
        
        The current working directory and the current profile's security
        directory will be searched for the file if it is not given by
        absolute path.
        
        When attempting to connect to an existing kernel and the `--existing`
        argument does not match an existing file, it will be interpreted as a
        fileglob, and the matching file in the current profile's security dir
        with the latest access time will be used.
        
        After this method is called, self.connection_file contains the *full path*
        to the connection file, never just its name.
        """
        if self.existing:
            try:
                cf = find_connection_file(self.existing)
            except Exception:
                self.log.critical("Could not find existing kernel connection file %s", self.existing)
                self.exit(1)
            self.log.info("Connecting to existing kernel: %s" % cf)
            self.connection_file = cf
        else:
            # not existing, check if we are going to write the file
            # and ensure that self.connection_file is a full path, not just the shortname
            try:
                cf = find_connection_file(self.connection_file)
            except Exception:
                # file might not exist
                if self.connection_file == os.path.basename(self.connection_file):
                    # just shortname, put it in security dir
                    cf = os.path.join(self.profile_dir.security_dir, self.connection_file)
                else:
                    cf = self.connection_file
                self.connection_file = cf
        
        # should load_connection_file only be used for existing?
        # as it is now, this allows reusing ports if an existing
        # file is requested
        try:
            self.load_connection_file()
        except Exception:
            self.log.error("Failed to load connection file: %r", self.connection_file, exc_info=True)
            self.exit(1)
    
    def load_connection_file(self):
        """load ip/port/hmac config from JSON connection file"""
        # this is identical to IPKernelApp.load_connection_file
        # perhaps it can be centralized somewhere?
        try:
            fname = filefind(self.connection_file, ['.', self.profile_dir.security_dir])
        except IOError:
            self.log.debug("Connection File not found: %s", self.connection_file)
            return
        self.log.debug(u"Loading connection file %s", fname)
        with open(fname) as f:
            cfg = json.load(f)
        
        self.config.KernelManager.transport = cfg.get('transport', 'tcp')
        self.config.KernelManager.ip = cfg.get('ip', LOCALHOST)
        
        for channel in ('hb', 'shell', 'iopub', 'stdin'):
            name = channel + '_port'
            if getattr(self, name) == 0 and name in cfg:
                # not overridden by config or cl_args
                setattr(self, name, cfg[name])
        if 'key' in cfg:
            self.config.Session.key = str_to_bytes(cfg['key'])
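
    # For reference (illustrative, not read from this source): the JSON
    # connection file loaded above decodes to a flat dict along these lines:
    #   {'transport': 'tcp', 'ip': '127.0.0.1',
    #    'shell_port': 53001, 'iopub_port': 53002, 'stdin_port': 53003,
    #    'hb_port': 53004, 'key': 'a0436f6c-1916-498b-8eb9-e81ab9368e84'}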
    
    def init_ssh(self):
        """set up ssh tunnels, if needed."""
        if not self.existing or (not self.sshserver and not self.sshkey):
            return
        
        self.load_connection_file()
        
        transport = self.config.KernelManager.transport
        ip = self.config.KernelManager.ip
        
        if transport != 'tcp':
            self.log.error("Can only use ssh tunnels with TCP sockets, not %s", transport)
            sys.exit(-1)
        
        if self.sshkey and not self.sshserver:
            # specifying just the key implies that we are connecting directly
            self.sshserver = ip
            ip = LOCALHOST
        
        # build connection dict for tunnels:
        info = dict(ip=ip,
                    shell_port=self.shell_port,
                    iopub_port=self.iopub_port,
                    stdin_port=self.stdin_port,
                    hb_port=self.hb_port
        )
        
        self.log.info("Forwarding connections to %s via %s"%(ip, self.sshserver))
        
        # tunnels return a new set of ports, which will be on localhost:
        self.config.KernelManager.ip = LOCALHOST
        try:
            newports = tunnel_to_kernel(info, self.sshserver, self.sshkey)
        except:
            # even catch KeyboardInterrupt
            self.log.error("Could not setup tunnels", exc_info=True)
            self.exit(1)
        
        self.shell_port, self.iopub_port, self.stdin_port, self.hb_port = newports
        
        cf = self.connection_file
        base, ext = os.path.splitext(cf)
        base = os.path.basename(base)
        self.connection_file = base + '-ssh' + ext
        self.log.critical("To connect another client via this tunnel, use:")
        self.log.critical("--existing %s" % self.connection_file)
    
    def _new_connection_file(self):
        cf = ''
        while not cf:
            # we don't need a 128b id to distinguish kernels, use more readable
            # 48b node segment (12 hex chars).  Users running more than 32k simultaneous
            # kernels can subclass.
            ident = str(uuid.uuid4()).split('-')[-1]
            cf = os.path.join(self.profile_dir.security_dir, 'kernel-%s.json' % ident)
            # only keep if it's actually new.  Protect against unlikely collision
            # in 48b random search space
            cf = cf if not os.path.exists(cf) else ''
        return cf

    def init_kernel_manager(self):
        if self.existing:
            self.kernel_manager = None
            return
        # Don't let Qt or ZMQ swallow KeyboardInterrupts.
        signal.signal(signal.SIGINT, signal.SIG_DFL)

        # Create a KernelManager and start a kernel.
        self.kernel_manager = self.kernel_manager_class(
                                shell_port=self.shell_port,
                                iopub_port=self.iopub_port,
                                stdin_port=self.stdin_port,
                                hb_port=self.hb_port,
                                connection_file=self.connection_file,
                                config=self.config,
        )
        self.kernel_manager.client_factory = self.kernel_client_class
        self.kernel_manager.start_kernel(extra_arguments=self.kernel_argv)
        atexit.register(self.kernel_manager.cleanup_ipc_files)

        if self.sshserver:
            # ssh, write new connection file
            self.kernel_manager.write_connection_file()

        # in case KM defaults / ssh writing changes things:
        km = self.kernel_manager
        self.shell_port=km.shell_port
        self.iopub_port=km.iopub_port
        self.stdin_port=km.stdin_port
        self.hb_port=km.hb_port
        self.connection_file = km.connection_file

        atexit.register(self.kernel_manager.cleanup_connection_file)

    def init_kernel_client(self):
        if self.kernel_manager is not None:
            self.kernel_client = self.kernel_manager.client()
        else:
            self.kernel_client = self.kernel_client_class(
                                shell_port=self.shell_port,
                                iopub_port=self.iopub_port,
                                stdin_port=self.stdin_port,
                                hb_port=self.hb_port,
                                connection_file=self.connection_file,
                                config=self.config,
            )

        self.kernel_client.start_channels()



    def initialize(self, argv=None):
        """
        Classes which mix this class in should call:
               IPythonConsoleApp.initialize(self,argv)
        """
        self.init_connection_file()
        default_secure(self.config)
        self.init_ssh()
        self.init_kernel_manager()
        self.init_kernel_client()
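
A sketch of how a concrete frontend might mix this class in, per the docstring above (the subclass name is hypothetical):

    # hypothetical frontend mixing in IPythonConsoleApp
    class MyConsoleApp(BaseIPythonApplication, IPythonConsoleApp):
        def initialize(self, argv=None):
            super(MyConsoleApp, self).initialize(argv)
            IPythonConsoleApp.initialize(self, argv)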
Example no. 18
class A(HasTraits):
    i = Int()
Example no. 19
class A(HasTraits):
    x = Int(10)
    def _x_default(self):
        return 11
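
Presumably this example is here to show that the dynamic default method wins over the static default:

    a = A()
    print(a.x)  # 11: _x_default takes precedence over Int(10)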
Example no. 20
class A(HasTraits):
    i = Int(0)

    def __init__(self, i):
        super(A, self).__init__()
        self.i = i
Example no. 21
class B(HasTraits):
    b = Int()
Example no. 22
class HistoryTrim(BaseIPythonApplication):
    description = trim_hist_help

    backup = Bool(False,
                  config=True,
                  help="Keep the old history file as history.sqlite.<N>")

    keep = Int(1000,
               config=True,
               help="Number of recent lines to keep in the database.")

    flags = Dict(
        dict(backup=({
            'HistoryTrim': {
                'backup': True
            }
        }, "Keep the old history file as history.sqlite.<N>")))

    def start(self):
        profile_dir = self.profile_dir.location
        hist_file = os.path.join(profile_dir, 'history.sqlite')
        con = sqlite3.connect(hist_file)

        # Grab the recent history from the current database.
        inputs = list(
            con.execute(
                'SELECT session, line, source, source_raw FROM '
                'history ORDER BY session DESC, line DESC LIMIT ?',
                (self.keep + 1, )))
        if len(inputs) <= self.keep:
            print(
                "There are already at most %d entries in the history database."
                % self.keep)
            print("Not doing anything.")
            return

        print("Trimming history to the most recent %d entries." % self.keep)

        inputs.pop()  # Remove the extra element we got to check the length.
        inputs.reverse()
        first_session = inputs[0][0]
        outputs = list(
            con.execute(
                'SELECT session, line, output FROM '
                'output_history WHERE session >= ?', (first_session, )))
        sessions = list(
            con.execute(
                'SELECT session, start, end, num_cmds, remark FROM '
                'sessions WHERE session >= ?', (first_session, )))
        con.close()

        # Create the new history database.
        new_hist_file = os.path.join(profile_dir, 'history.sqlite.new')
        i = 0
        while os.path.exists(new_hist_file):
            # Make sure we don't interfere with an existing file.
            i += 1
            new_hist_file = os.path.join(profile_dir,
                                         'history.sqlite.new' + str(i))
        new_db = sqlite3.connect(new_hist_file)
        new_db.execute("""CREATE TABLE IF NOT EXISTS sessions (session integer
                            primary key autoincrement, start timestamp,
                            end timestamp, num_cmds integer, remark text)""")
        new_db.execute("""CREATE TABLE IF NOT EXISTS history
                        (session integer, line integer, source text, source_raw text,
                        PRIMARY KEY (session, line))""")
        new_db.execute("""CREATE TABLE IF NOT EXISTS output_history
                        (session integer, line integer, output text,
                        PRIMARY KEY (session, line))""")
        new_db.commit()

        with new_db:
            # Add the recent history into the new database.
            new_db.executemany('insert into sessions values (?,?,?,?,?)',
                               sessions)
            new_db.executemany('insert into history values (?,?,?,?)', inputs)
            new_db.executemany('insert into output_history values (?,?,?)',
                               outputs)
        new_db.close()

        if self.backup:
            i = 1
            backup_hist_file = os.path.join(profile_dir,
                                            'history.sqlite.old.%d' % i)
            while os.path.exists(backup_hist_file):
                i += 1
                backup_hist_file = os.path.join(profile_dir,
                                                'history.sqlite.old.%d' % i)
            os.rename(hist_file, backup_hist_file)
            print("Backed up longer history file to", backup_hist_file)
        else:
            os.remove(hist_file)

        os.rename(new_hist_file, hist_file)
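
Hedged usage note: the subcommand wiring is an assumption; the trait and flag names come from the class above.

    # assumed shell invocation:
    #   ipython history trim --HistoryTrim.keep=500 --backup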
Example no. 23
class A(HasTraits):
    i = Int(config_key='MY_VALUE')
Example no. 24
class IPClusterEngines(BaseParallelApplication):

    name = u'ipcluster'
    description = engines_help
    examples = _engines_examples
    usage = None
    config_file_name = Unicode(default_config_file_name)
    default_log_level = logging.INFO
    classes = List()

    def _classes_default(self):
        from IPython.parallel.apps import launcher
        launchers = launcher.all_launchers
        eslaunchers = [l for l in launchers if 'EngineSet' in l.__name__]
        return [ProfileDir] + eslaunchers

    n = Int(
        num_cpus(),
        config=True,
        help="""The number of engines to start. The default is to use one for
        each CPU on your machine.""")

    engine_launcher_class = DottedObjectName(
        'LocalEngineSetLauncher',
        config=True,
        help="""The class for launching a set of Engines. Change this value
        to use various batch systems to launch your engines, such as PBS, SGE,
        MPIExec, etc.  Each launcher class has its own set of configuration
        options, for making sure it will work in your environment.
        
        You can also write your own launcher, and specify its absolute import
        path, as in 'mymodule.launcher.FTLEnginesLauncher'.
        
        Examples include:
        
            LocalEngineSetLauncher : start engines locally as subprocesses [default]
            MPIExecEngineSetLauncher : use mpiexec to launch in an MPI environment
            PBSEngineSetLauncher : use PBS (qsub) to submit engines to a batch queue
            SGEEngineSetLauncher : use SGE (qsub) to submit engines to a batch queue
            SSHEngineSetLauncher : use SSH to start the engines
                                Note that SSH does *not* move the connection files
                                around, so you will likely have to do this manually
                                unless the machines are on a shared file system.
            WindowsHPCEngineSetLauncher : use Windows HPC
        """)
    daemonize = Bool(
        False,
        config=True,
        help="""Daemonize the ipcluster program. This implies --log-to-file.
        Not available on Windows.
        """)

    def _daemonize_changed(self, name, old, new):
        if new:
            self.log_to_file = True

    aliases = Dict(engine_aliases)
    flags = Dict(engine_flags)
    _stopping = False

    def initialize(self, argv=None):
        super(IPClusterEngines, self).initialize(argv)
        self.init_signal()
        self.init_launchers()

    def init_launchers(self):
        self.engine_launcher = self.build_launcher(self.engine_launcher_class)
        self.engine_launcher.on_stop(lambda r: self.loop.stop())

    def init_signal(self):
        # Setup signals
        signal.signal(signal.SIGINT, self.sigint_handler)

    def build_launcher(self, clsname):
        """import and instantiate a Launcher based on importstring"""
        if '.' not in clsname:
            # not a module, presume it's the raw name in apps.launcher
            clsname = 'IPython.parallel.apps.launcher.' + clsname
        # print repr(clsname)
        try:
            klass = import_item(clsname)
        except (ImportError, KeyError):
            self.log.fatal("Could not import launcher class: %r" % clsname)
            self.exit(1)

        launcher = klass(work_dir=u'.', config=self.config, log=self.log)
        return launcher

    def start_engines(self):
        self.log.info("Starting %i engines" % self.n)
        self.engine_launcher.start(self.n, self.profile_dir.location)

    def stop_engines(self):
        self.log.info("Stopping Engines...")
        if self.engine_launcher.running:
            d = self.engine_launcher.stop()
            return d
        else:
            return None

    def stop_launchers(self, r=None):
        if not self._stopping:
            self._stopping = True
            self.log.error("IPython cluster: stopping")
            self.stop_engines()
            # Wait a few seconds to let things shut down.
            dc = ioloop.DelayedCallback(self.loop.stop, 4000, self.loop)
            dc.start()

    def sigint_handler(self, signum, frame):
        self.log.debug("SIGINT received, stopping launchers...")
        self.stop_launchers()

    def start_logging(self):
        # Remove old log files of the controller and engine
        if self.clean_logs:
            log_dir = self.profile_dir.log_dir
            for f in os.listdir(log_dir):
                if re.match(r'ip(engine|controller)z-\d+\.(log|err|out)', f):
                    os.remove(os.path.join(log_dir, f))
        # This will remove old log files for ipcluster itself
        # super(IPBaseParallelApplication, self).start_logging()

    def start(self):
        """Start the app for the engines subcommand."""
        self.log.info("IPython cluster: started")
        # First see if the cluster is already running

        # Now log and daemonize
        self.log.info('Starting engines with [daemon=%r]' % self.daemonize)
        # TODO: Get daemonize working on Windows or as a Windows Server.
        if self.daemonize:
            if os.name == 'posix':
                daemonize()

        dc = ioloop.DelayedCallback(self.start_engines, 0, self.loop)
        dc.start()
        # Now write the new pid file AFTER our new forked pid is active.
        # self.write_pid_file()
        try:
            self.loop.start()
        except KeyboardInterrupt:
            pass
        except zmq.ZMQError as e:
            if e.errno == errno.EINTR:
                pass
            else:
                raise
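
A minimal profile-config sketch based on the help text above (the file location and values are assumptions):

    # in the active profile's ipcluster_config.py:
    c = get_config()
    c.IPClusterEngines.n = 8
    c.IPClusterEngines.engine_launcher_class = 'MPIExecEngineSetLauncher'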
Example no. 25
class A(HasTraits):
    i = Int()
    x = Float()
Example no. 26
class WinHPCJob(Configurable):

    job_id = Unicode('')
    job_name = Unicode('MyJob', config=True)
    min_cores = Int(1, config=True)
    max_cores = Int(1, config=True)
    min_sockets = Int(1, config=True)
    max_sockets = Int(1, config=True)
    min_nodes = Int(1, config=True)
    max_nodes = Int(1, config=True)
    unit_type = Unicode("Core", config=True)
    auto_calculate_min = Bool(True, config=True)
    auto_calculate_max = Bool(True, config=True)
    run_until_canceled = Bool(False, config=True)
    is_exclusive = Bool(False, config=True)
    username = Unicode(find_username(), config=True)
    job_type = Unicode('Batch', config=True)
    priority = Enum(
        ('Lowest', 'BelowNormal', 'Normal', 'AboveNormal', 'Highest'),
        default_value='Highest',
        config=True)
    requested_nodes = Unicode('', config=True)
    project = Unicode('IPython', config=True)
    xmlns = Unicode('http://schemas.microsoft.com/HPCS2008/scheduler/')
    version = Unicode("2.000")
    tasks = List([])

    @property
    def owner(self):
        return self.username

    def _write_attr(self, root, attr, key):
        s = as_str(getattr(self, attr, ''))
        if s:
            root.set(key, s)

    def as_element(self):
        # We have to add the _A_-style prefixes to get the attribute order
        # that the MSFT XML parser expects.
        root = ET.Element('Job')
        self._write_attr(root, 'version', '_A_Version')
        self._write_attr(root, 'job_name', '_B_Name')
        self._write_attr(root, 'unit_type', '_C_UnitType')
        self._write_attr(root, 'min_cores', '_D_MinCores')
        self._write_attr(root, 'max_cores', '_E_MaxCores')
        self._write_attr(root, 'min_sockets', '_F_MinSockets')
        self._write_attr(root, 'max_sockets', '_G_MaxSockets')
        self._write_attr(root, 'min_nodes', '_H_MinNodes')
        self._write_attr(root, 'max_nodes', '_I_MaxNodes')
        self._write_attr(root, 'run_until_canceled', '_J_RunUntilCanceled')
        self._write_attr(root, 'is_exclusive', '_K_IsExclusive')
        self._write_attr(root, 'username', '_L_UserName')
        self._write_attr(root, 'job_type', '_M_JobType')
        self._write_attr(root, 'priority', '_N_Priority')
        self._write_attr(root, 'requested_nodes', '_O_RequestedNodes')
        self._write_attr(root, 'auto_calculate_max', '_P_AutoCalculateMax')
        self._write_attr(root, 'auto_calculate_min', '_Q_AutoCalculateMin')
        self._write_attr(root, 'project', '_R_Project')
        self._write_attr(root, 'owner', '_S_Owner')
        self._write_attr(root, 'xmlns', '_T_xmlns')
        dependencies = ET.SubElement(root, "Dependencies")
        etasks = ET.SubElement(root, "Tasks")
        for t in self.tasks:
            etasks.append(t.as_element())
        return root

    def tostring(self):
        """Return the string representation of the job description XML."""
        root = self.as_element()
        indent(root)
        txt = ET.tostring(root, encoding="utf-8")
        # Now remove the tokens used to order the attributes.
        txt = re.sub(r'_[A-Z]_', '', txt)
        txt = '<?xml version="1.0" encoding="utf-8"?>\n' + txt
        return txt

    def write(self, filename):
        """Write the XML job description to a file."""
        txt = self.tostring()
        with open(filename, 'w') as f:
            f.write(txt)

    def add_task(self, task):
        """Add a task to the job.

        Parameters
        ----------
        task : :class:`WinHPCTask`
            The task object to add.
        """
        self.tasks.append(task)
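
A short usage sketch built only from the methods above (the `task` object is assumed to be a WinHPCTask instance):

    job = WinHPCJob(job_name='MyIPythonJob', min_cores=4, max_cores=4)
    job.add_task(task)      # `task`: a WinHPCTask instance, assumed to exist
    job.write('myjob.xml')  # serialized via as_element()/tostring()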
Example no. 27
class TupleTrait(HasTraits):

    value = Tuple(Int(allow_none=True))
Example no. 28
class Kernel(SessionFactory):

    #---------------------------------------------------------------------------
    # Kernel interface
    #---------------------------------------------------------------------------
    
    # kwargs:
    int_id = Int(-1, config=True)
    user_ns = Dict(config=True)
    exec_lines = List(config=True)
    
    control_stream = Instance(zmqstream.ZMQStream)
    task_stream = Instance(zmqstream.ZMQStream)
    iopub_stream = Instance(zmqstream.ZMQStream)
    client = Instance('IPython.parallel.Client')
    
    # internals
    shell_streams = List()
    compiler = Instance(CommandCompiler, (), {})
    completer = Instance(KernelCompleter)
    
    aborted = Set()
    shell_handlers = Dict()
    control_handlers = Dict()
    
    def _set_prefix(self):
        self.prefix = "engine.%s"%self.int_id
    
    def _connect_completer(self):
        self.completer = KernelCompleter(self.user_ns)
    
    def __init__(self, **kwargs):
        super(Kernel, self).__init__(**kwargs)
        self._set_prefix()
        self._connect_completer()
        
        self.on_trait_change(self._set_prefix, 'id')
        self.on_trait_change(self._connect_completer, 'user_ns')
        
        # Build dict of handlers for message types
        for msg_type in ['execute_request', 'complete_request', 'apply_request', 
                'clear_request']:
            self.shell_handlers[msg_type] = getattr(self, msg_type)
        
        for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
            self.control_handlers[msg_type] = getattr(self, msg_type)
        
        self._initial_exec_lines()
    
    def _wrap_exception(self, method=None):
        e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
        content=wrap_exception(e_info)
        return content
    
    def _initial_exec_lines(self):
        s = _Passer()
        content = dict(silent=True, user_variables=[], user_expressions={})
        for line in self.exec_lines:
            self.log.debug("executing initialization: %s"%line)
            content.update({'code':line})
            msg = self.session.msg('execute_request', content)
            self.execute_request(s, [], msg)
        
        
    #-------------------- control handlers -----------------------------
    def abort_queues(self):
        for stream in self.shell_streams:
            if stream:
                self.abort_queue(stream)
    
    def abort_queue(self, stream):
        while True:
            try:
                msg = self.session.recv(stream, zmq.NOBLOCK,content=True)
            except zmq.ZMQError as e:
                if e.errno == zmq.EAGAIN:
                    break
                else:
                    return
            else:
                if msg is None:
                    return
                else:
                    idents,msg = msg
                
            self.log.info("Aborting:")
            self.log.info(str(msg))
            msg_type = msg['msg_type']
            reply_type = msg_type.split('_')[0] + '_reply'
            # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
            # self.reply_socket.send(ident,zmq.SNDMORE)
            # self.reply_socket.send_json(reply_msg)
            reply_msg = self.session.send(stream, reply_type, 
                        content={'status' : 'aborted'}, parent=msg, ident=idents)[0]
            self.log.debug(str(reply_msg))
            # We need to wait a bit for requests to come in. This can probably
            # be set shorter for true asynchronous clients.
            time.sleep(0.05)
    
    def abort_request(self, stream, ident, parent):
        """abort a specifig msg by id"""
        msg_ids = parent['content'].get('msg_ids', None)
        if isinstance(msg_ids, basestring):
            msg_ids = [msg_ids]
        if not msg_ids:
            self.abort_queues()
        for mid in msg_ids:
            self.aborted.add(str(mid))
        
        content = dict(status='ok')
        reply_msg = self.session.send(stream, 'abort_reply', content=content, 
                parent=parent, ident=ident)
        self.log.debug(str(reply_msg))
    
    def shutdown_request(self, stream, ident, parent):
        """kill ourself.  This should really be handled in an external process"""
        try:
            self.abort_queues()
        except:
            content = self._wrap_exception('shutdown')
        else:
            content = dict(parent['content'])
            content['status'] = 'ok'
        msg = self.session.send(stream, 'shutdown_reply',
                                content=content, parent=parent, ident=ident)
        self.log.debug(str(msg))
        dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
        dc.start()
    
    def dispatch_control(self, msg):
        idents,msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.unpack_message(msg, content=True, copy=False)
        except:
            self.log.error("Invalid Message", exc_info=True)
            return
        
        header = msg['header']
        msg_id = header['msg_id']
        
        handler = self.control_handlers.get(msg['msg_type'], None)
        if handler is None:
            self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
        else:
            handler(self.control_stream, idents, msg)
    

    #-------------------- queue helpers ------------------------------
    
    def check_dependencies(self, dependencies):
        if not dependencies:
            return True
        if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
            anyorall = dependencies[0]
            dependencies = dependencies[1]
        else:
            anyorall = 'all'
        results = self.client.get_results(dependencies,status_only=True)
        if results['status'] != 'ok':
            return False
        
        if anyorall == 'any':
            if not results['completed']:
                return False
        else:
            if results['pending']:
                return False
        
        return True
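
    # Illustrative note (an assumption drawn from the branch above):
    # `dependencies` is either a flat list of msg_ids (implicitly 'all'),
    # or a two-element form such as ['any', [msg_id1, msg_id2]], meaning
    # the work may proceed once any one of those results has completed.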
    
    def check_aborted(self, msg_id):
        return msg_id in self.aborted
    
    #-------------------- queue handlers -----------------------------
    
    def clear_request(self, stream, idents, parent):
        """Clear our namespace."""
        self.user_ns = {}
        msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent, 
                content = dict(status='ok'))
        self._initial_exec_lines()
    
    def execute_request(self, stream, ident, parent):
        self.log.debug('execute request %s'%parent)
        try:
            code = parent[u'content'][u'code']
        except:
            self.log.error("Got bad msg: %s"%parent, exc_info=True)
            return
        self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
                            ident='%s.pyin'%self.prefix)
        started = datetime.now().strftime(ISO8601)
        try:
            comp_code = self.compiler(code, '<zmq-kernel>')
            # allow for not overriding displayhook
            if hasattr(sys.displayhook, 'set_parent'):
                sys.displayhook.set_parent(parent)
                sys.stdout.set_parent(parent)
                sys.stderr.set_parent(parent)
            exec comp_code in self.user_ns, self.user_ns
        except:
            exc_content = self._wrap_exception('execute')
            # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
            self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
                            ident='%s.pyerr'%self.prefix)
            reply_content = exc_content
        else:
            reply_content = {'status' : 'ok'}
        
        reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent, 
                    ident=ident, subheader = dict(started=started))
        self.log.debug(str(reply_msg))
        if reply_msg['content']['status'] == u'error':
            self.abort_queues()

    def complete_request(self, stream, ident, parent):
        matches = {'matches' : self.complete(parent),
                   'status' : 'ok'}
        completion_msg = self.session.send(stream, 'complete_reply',
                                           matches, parent, ident)
        # print >> sys.__stdout__, completion_msg

    def complete(self, msg):
        return self.completer.complete(msg.content.line, msg.content.text)
    
    def apply_request(self, stream, ident, parent):
        # flush previous reply, so this request won't block it
        stream.flush(zmq.POLLOUT)
        
        try:
            content = parent[u'content']
            bufs = parent[u'buffers']
            msg_id = parent['header']['msg_id']
            # bound = parent['header'].get('bound', False)
        except:
            self.log.error("Got bad msg: %s"%parent, exc_info=True)
            return
        # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
        # self.iopub_stream.send(pyin_msg)
        # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
        sub = {'dependencies_met' : True, 'engine' : self.ident,
                'started': datetime.now().strftime(ISO8601)}
        try:
            # allow for not overriding displayhook
            if hasattr(sys.displayhook, 'set_parent'):
                sys.displayhook.set_parent(parent)
                sys.stdout.set_parent(parent)
                sys.stderr.set_parent(parent)
            # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
            working = self.user_ns
            prefix = "_"+str(msg_id).replace("-","")+"_"
            
            f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
            # if bound:
            #     bound_ns = Namespace(working)
            #     args = [bound_ns]+list(args)

            # mangled names, to avoid clobbering the user namespace
            fname = prefix+"f"
            argname = prefix+"args"
            kwargname = prefix+"kwargs"
            resultname = prefix+"result"
            
            ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
            # print ns
            working.update(ns)
            code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
            try:
                exec code in working,working
                result = working.get(resultname)
            finally:
                for key in ns.iterkeys():
                    working.pop(key)
            # if bound:
            #     working.update(bound_ns)
            
            packed_result,buf = serialize_object(result)
            result_buf = [packed_result]+buf
        except:
            exc_content = self._wrap_exception('apply')
            # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
            self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
                                ident='%s.pyerr'%self.prefix)
            reply_content = exc_content
            result_buf = []
            
            if exc_content['ename'] == 'UnmetDependency':
                sub['dependencies_met'] = False
        else:
            reply_content = {'status' : 'ok'}
        
        # put 'ok'/'error' status in header, for scheduler introspection:
        sub['status'] = reply_content['status']
        
        reply_msg = self.session.send(stream, u'apply_reply', reply_content, 
                    parent=parent, ident=ident,buffers=result_buf, subheader=sub)
        
        # flush i/o
        # should this be before reply_msg is sent, like in the single-kernel code, 
        # or should nothing get in the way of real results?
        sys.stdout.flush()
        sys.stderr.flush()
    
    def dispatch_queue(self, stream, msg):
        self.control_stream.flush()
        idents,msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.unpack_message(msg, content=True, copy=False)
        except:
            self.log.error("Invalid Message", exc_info=True)
            return
        
        header = msg['header']
        msg_id = header['msg_id']
        if self.check_aborted(msg_id):
            self.aborted.remove(msg_id)
            # is it safe to assume a msg_id will not be resubmitted?
            reply_type = msg['msg_type'].split('_')[0] + '_reply'
            status = {'status' : 'aborted'}
            reply_msg = self.session.send(stream, reply_type, subheader=status,
                        content=status, parent=msg, ident=idents)
            return
        handler = self.shell_handlers.get(msg['msg_type'], None)
        if handler is None:
            self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type'])
        else:
            handler(stream, idents, msg)
    
    def start(self):
        #### stream mode:
        if self.control_stream:
            self.control_stream.on_recv(self.dispatch_control, copy=False)
            self.control_stream.on_err(printer)
        
        def make_dispatcher(stream):
            def dispatcher(msg):
                return self.dispatch_queue(stream, msg)
            return dispatcher
        
        for s in self.shell_streams:
            s.on_recv(make_dispatcher(s), copy=False)
            s.on_err(printer)
        
        if self.iopub_stream:
            self.iopub_stream.on_err(printer)
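
Design note on make_dispatcher in start() above: the closure factory ensures each callback captures its own stream; a bare lambda in the loop would late-bind `s`. A tiny standalone illustration:

    funcs = [lambda: s for s in ('a', 'b')]
    print([f() for f in funcs])    # ['b', 'b'] -- the loop variable is late-bound
    funcs = [(lambda v: (lambda: v))(s) for s in ('a', 'b')]
    print([f() for f in funcs])    # ['a', 'b'] -- each closure froze its own value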
Example no. 29
class B(HasTraits):
    count = Int()
Example no. 30
class HubFactory(RegistrationFactory):
    """The Configurable for setting up a Hub."""

    # name of a scheduler scheme
    scheme = Str('leastload', config=True)

    # port-pairs for monitoredqueues:
    hb = Instance(list, config=True)

    def _hb_default(self):
        return select_random_ports(2)

    mux = Instance(list, config=True)

    def _mux_default(self):
        return select_random_ports(2)

    task = Instance(list, config=True)

    def _task_default(self):
        return select_random_ports(2)

    control = Instance(list, config=True)

    def _control_default(self):
        return select_random_ports(2)

    iopub = Instance(list, config=True)

    def _iopub_default(self):
        return select_random_ports(2)

    # single ports:
    mon_port = Instance(int, config=True)

    def _mon_port_default(self):
        return select_random_ports(1)[0]

    notifier_port = Instance(int, config=True)

    def _notifier_port_default(self):
        return select_random_ports(1)[0]

    ping = Int(1000, config=True)  # ping frequency

    engine_ip = CStr('127.0.0.1', config=True)
    engine_transport = CStr('tcp', config=True)

    client_ip = CStr('127.0.0.1', config=True)
    client_transport = CStr('tcp', config=True)

    monitor_ip = CStr('127.0.0.1', config=True)
    monitor_transport = CStr('tcp', config=True)

    monitor_url = CStr('')

    db_class = CStr('IPython.parallel.controller.dictdb.DictDB', config=True)

    # not configurable
    db = Instance('IPython.parallel.controller.dictdb.BaseDB')
    heartmonitor = Instance(
        'IPython.parallel.controller.heartmonitor.HeartMonitor')
    subconstructors = List()
    _constructed = Bool(False)

    def _ip_changed(self, name, old, new):
        self.engine_ip = new
        self.client_ip = new
        self.monitor_ip = new
        self._update_monitor_url()

    def _update_monitor_url(self):
        self.monitor_url = "%s://%s:%i" % (self.monitor_transport,
                                           self.monitor_ip, self.mon_port)

    def _transport_changed(self, name, old, new):
        self.engine_transport = new
        self.client_transport = new
        self.monitor_transport = new
        self._update_monitor_url()

    def __init__(self, **kwargs):
        super(HubFactory, self).__init__(**kwargs)
        self._update_monitor_url()
        # self.on_trait_change(self._sync_ips, 'ip')
        # self.on_trait_change(self._sync_transports, 'transport')
        self.subconstructors.append(self.construct_hub)

    def construct(self):
        assert not self._constructed, "already constructed!"

        for subc in self.subconstructors:
            subc()

        self._constructed = True

    def start(self):
        assert self._constructed, "must be constructed by self.construct() first!"
        self.heartmonitor.start()
        self.log.info("Heartmonitor started")

    def construct_hub(self):
        """construct"""
        client_iface = "%s://%s:" % (self.client_transport,
                                     self.client_ip) + "%i"
        engine_iface = "%s://%s:" % (self.engine_transport,
                                     self.engine_ip) + "%i"

        ctx = self.context
        loop = self.loop

        # Registrar socket
        q = ZMQStream(ctx.socket(zmq.XREP), loop)
        q.bind(client_iface % self.regport)
        self.log.info("Hub listening on %s for registration." %
                      (client_iface % self.regport))
        if self.client_ip != self.engine_ip:
            q.bind(engine_iface % self.regport)
            self.log.info("Hub listening on %s for registration." %
                          (engine_iface % self.regport))

        ### Engine connections ###

        # heartbeat
        hpub = ctx.socket(zmq.PUB)
        hpub.bind(engine_iface % self.hb[0])
        hrep = ctx.socket(zmq.XREP)
        hrep.bind(engine_iface % self.hb[1])
        self.heartmonitor = HeartMonitor(loop=loop,
                                         pingstream=ZMQStream(hpub, loop),
                                         pongstream=ZMQStream(hrep, loop),
                                         period=self.ping,
                                         logname=self.log.name)

        ### Client connections ###
        # Notifier socket
        n = ZMQStream(ctx.socket(zmq.PUB), loop)
        n.bind(client_iface % self.notifier_port)

        ### build and launch the queues ###

        # monitor socket
        sub = ctx.socket(zmq.SUB)
        sub.setsockopt(zmq.SUBSCRIBE, "")
        sub.bind(self.monitor_url)
        sub.bind('inproc://monitor')
        sub = ZMQStream(sub, loop)

        # connect the db
        self.log.info('Hub using DB backend: %r' % (self.db_class.split('.')[-1]))
        # cdir = self.config.Global.cluster_dir
        self.db = import_item(self.db_class)(session=self.session.session,
                                             config=self.config)
        time.sleep(.25)

        # build connection dicts
        self.engine_info = {
            'control': engine_iface % self.control[1],
            'mux': engine_iface % self.mux[1],
            'heartbeat':
            (engine_iface % self.hb[0], engine_iface % self.hb[1]),
            'task': engine_iface % self.task[1],
            'iopub': engine_iface % self.iopub[1],
            # 'monitor' : engine_iface%self.mon_port,
        }

        self.client_info = {
            'control': client_iface % self.control[0],
            'mux': client_iface % self.mux[0],
            'task': (self.scheme, client_iface % self.task[0]),
            'iopub': client_iface % self.iopub[0],
            'notification': client_iface % self.notifier_port
        }
        self.log.debug("Hub engine addrs: %s" % self.engine_info)
        self.log.debug("Hub client addrs: %s" % self.client_info)
        self.hub = Hub(loop=loop,
                       session=self.session,
                       monitor=sub,
                       heartmonitor=self.heartmonitor,
                       query=q,
                       notifier=n,
                       db=self.db,
                       engine_info=self.engine_info,
                       client_info=self.client_info,
                       logname=self.log.name)