Esempio n. 1
0
class TaskScheduler(SessionFactory):
    """Python TaskScheduler object.
    
    This is the simplest object that supports msg_id based
    DAG dependencies. *Only* task msg_ids are checked, not
    msg_ids of jobs submitted via the MUX queue.
    
    """
    
    # Per-engine cap on outstanding tasks.  The default of 0 makes every
    # ``if self.hwm`` test in this class falsy, so no cap is applied.
    hwm = Int(0, config=True, shortname='hwm',
        help="""specify the High Water Mark (HWM) for the downstream
        socket in the Task scheduler. This is the maximum number
        of allowed outstanding tasks on each engine."""
    )
    # Name of the scheme function; _scheme_name_changed looks it up in the
    # module globals.  The help text now states the default that is actually
    # declared below ('leastload') instead of the stale "Python LRU".
    scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
        'leastload', config=True, shortname='scheme', allow_none=False,
        help="""select the task scheduler scheme  [default: leastload]
        Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
    )
    def _scheme_name_changed(self, old, new):
        self.log.debug("Using scheme %r"%new)
        self.scheme = globals()[new]
    
    # input arguments:
    # `scheme` is called by submit_task() with a list of loads; its return
    # value is used as an index into that list to pick the engine.
    scheme = Instance(FunctionType) # function for determining the destination
    def _scheme_default(self):
        """Default value for ``scheme``: the module-level ``leastload``."""
        return leastload
    client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
    engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
    notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
    mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
    
    # internals:
    graph = Dict() # dict by msg_id of the set of msg_ids that depend on key
    retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
    # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
    depending = Dict() # dict by msg_id of [raw_msg, targets, after, follow, timeout]
    pending = Dict() # dict by engine_uuid of {msg_id: (raw_msg, targets, MET, follow, timeout)}
    completed = Dict() # dict by engine_uuid of the set of completed msg_ids
    failed = Dict() # dict by engine_uuid of the set of failed msg_ids
    destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
    clients = Dict() # dict by msg_id for who submitted the task
    targets = List() # list of target IDENTs
    loads = List() # list of engine loads, index-aligned with `targets`
    # full = Set() # set of IDENTs that have HWM outstanding tasks
    all_completed = Set() # set of all completed tasks
    all_failed = Set() # set of all failed tasks
    all_done = Set() # set of all finished tasks=union(completed,failed)
    all_ids = Set() # set of all submitted task IDs
    blacklist = Dict() # dict by msg_id of the set of engines where a job has encountered UnmetDependency
    auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback') # created in start(); runs audit_timeouts
    
    ident = CBytes() # ZMQ identity. This should just be self.session.session
                     # but ensure Bytes
    def _ident_default(self):
        """Default for ``ident``: ``self.session.session`` passed through ``asbytes``."""
        return asbytes(self.session.session)
    
    def start(self):
        """Wire up the stream callbacks and start the periodic timeout audit.

        Client submissions are *not* hooked up here; that happens in
        ``resume_receiving`` once the first engine registers.
        """
        # replies coming back from the engines
        self.engine_stream.on_recv(self.dispatch_result, copy=False)

        # hub notifications, dispatched on their msg_type
        handlers = {}
        handlers['registration_notification'] = self._register_engine
        handlers['unregistration_notification'] = self._unregister_engine
        self._notification_handlers = handlers
        self.notifier_stream.on_recv(self.dispatch_notification)

        # NOTE(review): the interval is 2e3; if PeriodicCallback takes
        # milliseconds (as tornado's does) this audits every 2 s, not at the
        # "1 Hz" an earlier comment claimed -- confirm the intended rate.
        self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop)
        self.auditor.start()

        started = "Scheduler started [%s]" % self.scheme_name
        self.log.info(started)
    
    def resume_receiving(self):
        """Start handing client submissions to ``dispatch_submission``."""
        stream = self.client_stream
        stream.on_recv(self.dispatch_submission, copy=False)
    
    def stop_receiving(self):
        """Detach the client-stream handler while no engines are available.

        With no handler installed the submissions are not consumed here; the
        intent is that they stay queued in ZMQ until receiving resumes.
        """
        stream = self.client_stream
        stream.on_recv(None)
    
    #-----------------------------------------------------------------------
    # [Un]Registration Handling
    #-----------------------------------------------------------------------
    
    def dispatch_notification(self, msg):
        """Dispatch a register/unregister notification to its handler.

        `msg` is the raw multipart message received on the notifier stream.
        Messages that cannot be split or unpacked are logged and dropped.
        The handler registered for the message's ``msg_type`` is called with
        the ``content['queue']`` value converted by ``asbytes``.

        A failing handler never propagates out of this callback, but the
        failure is now logged together with its traceback so that the cause
        (for example an unknown engine id) can be diagnosed.
        """
        try:
            idents,msg = self.session.feed_identities(msg)
        except ValueError:
            self.log.warn("task::Invalid Message: %r",msg)
            return
        try:
            msg = self.session.unpack_message(msg)
        except ValueError:
            self.log.warn("task::Unauthorized message from: %r"%idents)
            return

        msg_type = msg['msg_type']

        handler = self._notification_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("Unhandled message type: %r"%msg_type)
            return

        try:
            handler(asbytes(msg['content']['queue']))
        except Exception:
            # deliberately best-effort: keep the scheduler running, but
            # record the traceback instead of discarding it
            self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
    
    def _register_engine(self, uid):
        """New engine with ident `uid` became available."""
        # head of the line:
        self.targets.insert(0,uid)
        self.loads.insert(0,0)

        # initialize sets
        self.completed[uid] = set()
        self.failed[uid] = set()
        self.pending[uid] = {}
        if len(self.targets) == 1:
            self.resume_receiving()
        # rescan the graph:
        self.update_graph(None)

    def _unregister_engine(self, uid):
        """Existing engine with ident `uid` became unavailable."""
        if len(self.targets) == 1:
            # this was our only engine
            self.stop_receiving()
        
        # handle any potentially finished tasks:
        self.engine_stream.flush()
        
        # don't pop destinations, because they might be used later
        # map(self.destinations.pop, self.completed.pop(uid))
        # map(self.destinations.pop, self.failed.pop(uid))

        # prevent this engine from receiving work
        idx = self.targets.index(uid)
        self.targets.pop(idx)
        self.loads.pop(idx)
        
        # wait 5 seconds before cleaning up pending jobs, since the results might
        # still be incoming
        if self.pending[uid]:
            dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
            dc.start()
        else:
            self.completed.pop(uid)
            self.failed.pop(uid)

    
    def handle_stranded_tasks(self, engine):
        """Deal with jobs resident in an engine that died.

        For every task still pending on `engine`, a fake ``apply_reply``
        carrying an ``EngineError`` is built and pushed through
        ``dispatch_result``, exactly as if the engine had replied with a
        failure.  Afterwards the engine's completed/failed sets are dropped.
        """
        lost = self.pending[engine]
        # Iterate over a snapshot: dispatch_result() below removes entries
        # from this very dict, which must not happen while iterating it.
        for msg_id in list(lost):
            if msg_id not in self.pending[engine]:
                # prevent double-handling of messages
                continue

            raw_msg = lost[msg_id][0]
            idents,msg = self.session.feed_identities(raw_msg, copy=False)
            parent = self.session.unpack(msg[1].bytes)
            # reply identities: the dead engine first, then the first
            # identity of the original submission
            idents = [engine, idents[0]]

            # build fake error reply; the exception is raised and caught so
            # that wrap_exception() can capture it
            try:
                raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
            except error.EngineError:
                content = error.wrap_exception()
            msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
            # must be a real list: downstream code slices and concatenates it
            raw_reply = list(map(zmq.Message, self.session.serialize(msg, ident=idents)))
            # and dispatch it
            self.dispatch_result(raw_reply)

        # finally scrub completed/failed lists
        self.completed.pop(engine)
        self.failed.pop(engine)
    
    
    #-----------------------------------------------------------------------
    # Job Submission
    #-----------------------------------------------------------------------
    def dispatch_submission(self, raw_msg):
        """Dispatch job submission to appropriate handlers.

        Reads targets, retries, the ``after``/``follow`` dependencies and the
        timeout from the task header, forwards the raw frames to the monitor
        stream, and then does one of three things: fails the task
        (invalid or unreachable dependency), tries to run it right away via
        ``maybe_run``, or parks it with ``save_unmet``.
        """
        # ensure targets up to date:
        self.notifier_stream.flush()
        try:
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.unpack_message(msg, content=False, copy=False)
        except Exception:
            self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
            return
        
        
        # send to monitor
        self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
        
        header = msg['header']
        msg_id = header['msg_id']
        self.all_ids.add(msg_id)
        
        # get targets as a set of bytes objects
        # from a list of unicode objects
        targets = header.get('targets', [])
        targets = map(asbytes, targets)
        targets = set(targets)
            
        retries = header.get('retries', 0)
        self.retries[msg_id] = retries
        
        # time dependencies
        after = header.get('after', None)
        if after:
            after = Dependency(after)
            if after.all:
                # drop the ids that have already finished in a way this
                # dependency accepts, keeping the same flags
                if after.success:
                    after = Dependency(after.difference(self.all_completed),
                                success=after.success,
                                failure=after.failure,
                                all=after.all,
                    )
                if after.failure:
                    after = Dependency(after.difference(self.all_failed),
                                success=after.success,
                                failure=after.failure,
                                all=after.all,
                    )
            if after.check(self.all_completed, self.all_failed):
                # recast as empty set, if `after` already met,
                # to prevent unnecessary set comparisons
                after = MET
        else:
            after = MET
        
        # location dependencies
        follow = Dependency(header.get('follow', []))
        
        # turn timeouts into datetime objects:
        # (the header value is passed as the `seconds` argument of timedelta)
        timeout = header.get('timeout', None)
        if timeout:
            timeout = datetime.now() + timedelta(0,timeout,0)
        
        # same element order as the entries stored in self.depending
        args = [raw_msg, targets, after, follow, timeout]
        
        # validate and reduce dependencies:
        for dep in after,follow:
            if not dep: # empty dependency
                continue
            # check valid:
            # (a task may not depend on itself or on ids never submitted)
            if msg_id in dep or dep.difference(self.all_ids):
                # fail_unreachable only acts on tasks found in self.depending
                self.depending[msg_id] = args
                return self.fail_unreachable(msg_id, error.InvalidDependency)
            # check if unreachable:
            if dep.unreachable(self.all_completed, self.all_failed):
                self.depending[msg_id] = args
                return self.fail_unreachable(msg_id)
        
        if after.check(self.all_completed, self.all_failed):
            # time deps already met, try to run
            if not self.maybe_run(msg_id, *args):
                # can't run yet
                if msg_id not in self.all_failed:
                    # could have failed as unreachable
                    self.save_unmet(msg_id, *args)
        else:
            self.save_unmet(msg_id, *args)
    
    def audit_timeouts(self):
        """Audit all waiting tasks for expired timeouts.

        Every task in ``self.depending`` whose timeout (the last element of
        its entry) lies in the past is failed with ``error.TaskTimeout``.
        """
        now = datetime.now()
        # Snapshot the ids: fail_unreachable() removes entries from
        # self.depending, possibly several at once when a failure cascades.
        for msg_id in list(self.depending):
            # must recheck, in case one failure cascaded to another:
            if msg_id not in self.depending:
                continue
            # entries are [raw_msg, targets, after, follow, timeout]
            raw_msg, targets, after, follow, timeout = self.depending[msg_id]
            if timeout and timeout < now:
                self.fail_unreachable(msg_id, error.TaskTimeout)
                
    def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
        """Fail a waiting task and send the error reply to its client.

        `msg_id` must be present in ``self.depending``; otherwise an error is
        logged and nothing else happens.  `why` is the exception class put
        into the reply (``ImpossibleDependency`` by default).  The task is
        recorded as done and failed, the reply is sent to the client and to
        the monitor stream, and the graph is updated so that dependents of
        this task are re-evaluated.
        """
        if msg_id not in self.depending:
            self.log.error("msg %r already failed!", msg_id)
            return
        raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
        for mid in follow.union(after):
            if mid in self.graph:
                # discard, not remove: a task that is failed before it was
                # ever saved with save_unmet() is not in these sets
                self.graph[mid].discard(msg_id)

        # FIXME: unpacking a message I've already unpacked, but didn't save:
        idents,msg = self.session.feed_identities(raw_msg, copy=False)
        header = self.session.unpack(msg[1].bytes)

        # raise and catch so that wrap_exception() can capture the exception
        try:
            raise why()
        except Exception:
            content = error.wrap_exception()

        self.all_done.add(msg_id)
        self.all_failed.add(msg_id)

        msg = self.session.send(self.client_stream, 'apply_reply', content, 
                                                parent=header, ident=idents)
        self.session.send(self.mon_stream, msg, ident=[b'outtask']+idents)

        self.update_graph(msg_id, success=False)
    
    def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
        """Check location dependencies, and run the task if they are met.

        Returns True when the task was handed to ``submit_task``.  Returns
        False when no engine currently qualifies; in that case the task may
        additionally have been failed via ``fail_unreachable`` if it can be
        shown that no engine will ever qualify.
        """
        blacklist = self.blacklist.setdefault(msg_id, set())
        if follow or targets or blacklist or self.hwm:
            # we need a can_run filter
            def can_run(idx):
                # check hwm
                if self.hwm and self.loads[idx] == self.hwm:
                    return False
                target = self.targets[idx]
                # check blacklist
                if target in blacklist:
                    return False
                # check targets
                if targets and target not in targets:
                    return False
                # check follow
                return follow.check(self.completed[target], self.failed[target])

            # Build a real list: it is truth-tested here and indexed in
            # submit_task(), neither of which works on a lazy filter object.
            indices = [idx for idx in range(len(self.targets)) if can_run(idx)]

            if not indices:
                # couldn't run
                if follow.all:
                    # check follow for impossibility
                    dests = set()
                    relevant = set()
                    if follow.success:
                        relevant = self.all_completed
                    if follow.failure:
                        relevant = relevant.union(self.all_failed)
                    # A finished dependency without a destination never ran
                    # on any engine (it was failed by fail_unreachable), so
                    # it is in no engine's completed/failed set and an
                    # all-of `follow` on it cannot be satisfied anywhere.
                    never_ran = False
                    for m in follow.intersection(relevant):
                        if m in self.destinations:
                            dests.add(self.destinations[m])
                        else:
                            never_ran = True
                    # dependencies spread over several engines cannot all be
                    # co-located either
                    if never_ran or len(dests) > 1:
                        self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
                        self.fail_unreachable(msg_id)
                        return False
                if targets:
                    # check blacklist+targets for impossibility
                    targets.difference_update(blacklist)
                    if not targets or not targets.intersection(self.targets):
                        self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
                        self.fail_unreachable(msg_id)
                        return False
                return False
        else:
            # no restriction at all: any engine will do
            indices = None

        self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
        return True
            
    def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
        """Park a task until its dependencies are met.

        The task is stored in ``self.depending`` and registered in
        ``self.graph`` under every dependency id that has not finished yet.
        """
        self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
        # ids from either dependency, minus those already finished
        unfinished = after.union(follow).difference(self.all_done)
        for dep_id in unfinished:
            self.graph.setdefault(dep_id, set()).add(msg_id)
    
    def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
        """Send a task to one engine, chosen by ``self.scheme``.

        When `indices` is non-empty only the engines at those positions of
        ``self.targets`` are candidates; otherwise all engines are.
        """
        restricted = bool(indices)
        if restricted:
            candidate_loads = [self.loads[i] for i in indices]
        else:
            candidate_loads = self.loads
        # the scheme answers with a position inside the list it was given
        choice = self.scheme(candidate_loads)
        idx = indices[choice] if restricted else choice
        target = self.targets[idx]

        # the engine identity goes out first, followed by the task frames
        self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
        self.engine_stream.send_multipart(raw_msg, copy=False)

        # account for the new job; `target` was read before this call
        self.add_job(idx)
        self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)

        # tell the Hub where the task went
        content = {'msg_id': msg_id, 'engine_id': target.decode('ascii')}
        self.session.send(self.mon_stream, 'task_destination', content=content, 
                        ident=[b'tracktask',self.ident])
        
    
    #-----------------------------------------------------------------------
    # Result Handling
    #-----------------------------------------------------------------------
    def dispatch_result(self, raw_msg):
        """Route an engine reply to the result or unmet-dependency handler."""
        try:
            idents,msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.unpack_message(msg, content=False, copy=False)
            engine = idents[0]
            # engines that are no longer registered get no load update
            if engine in self.targets:
                self.finish_job(self.targets.index(engine))
        except Exception:
            self.log.error("task::Invaid result: %r", raw_msg, exc_info=True)
            return

        header = msg['header']
        parent = msg['parent_header']
        if not header.get('dependencies_met', True):
            # the engine declined the task
            self.handle_unmet_dependency(idents, parent)
            return

        success = (header['status'] == 'ok')
        msg_id = parent['msg_id']
        remaining = self.retries[msg_id]
        if success or not remaining > 0:
            # final outcome
            del self.retries[msg_id]
            # relay to client and update graph
            self.handle_result(idents, parent, raw_msg, success)
            # send to Hub monitor
            self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
        else:
            # failed with retries left: spend one and resubmit through the
            # unmet-dependency path
            self.retries[msg_id] = remaining - 1
            self.handle_unmet_dependency(idents, parent)
        
    def handle_result(self, idents, parent, raw_msg, success=True):
        """Relay a final task result to the client and record the outcome."""
        engine, client = idents[0], idents[1]

        # the reply goes back with the first two frames swapped
        # (client first, then engine) for the XREP-XREP mirror
        raw_msg[:2] = [client,engine]
        self.client_stream.send_multipart(raw_msg, copy=False)

        # bookkeeping
        msg_id = parent['msg_id']
        self.blacklist.pop(msg_id, None)
        self.pending[engine].pop(msg_id)
        if success:
            per_engine, overall = self.completed, self.all_completed
        else:
            per_engine, overall = self.failed, self.all_failed
        per_engine[engine].add(msg_id)
        overall.add(msg_id)
        self.all_done.add(msg_id)
        self.destinations[msg_id] = engine

        # let tasks that depend on this one move forward
        self.update_graph(msg_id, success)
        
    def handle_unmet_dependency(self, idents, parent):
        """Blacklist the replying engine for this task and try to resubmit."""
        engine = idents[0]
        msg_id = parent['msg_id']

        # remember that this engine turned the task down
        self.blacklist.setdefault(msg_id, set()).add(engine)

        job = self.pending[engine].pop(msg_id)
        raw, targets, after, follow, timeout = job

        if self.blacklist[msg_id] == targets:
            # every requested target has refused the task
            self.depending[msg_id] = job
            self.fail_unreachable(msg_id)
        else:
            resubmitted = self.maybe_run(msg_id, *job)
            if not resubmitted and msg_id not in self.all_failed:
                # put it back in our dependency tree
                self.save_unmet(msg_id, *job)

        if not self.hwm:
            return
        # dead engines get no load check
        if engine in self.targets:
            position = self.targets.index(engine)
            if self.loads[position] == self.hwm-1:
                # the engine is one below the HWM: rescan waiting tasks
                self.update_graph(None)
        
        
    
    def update_graph(self, dep_id=None, success=True):
        """dep_id just finished. Update our dependency
        graph and submit any jobs that just became runable.

        Called with dep_id=None to update entire graph for hwm, but without finishing
        a task.

        Waiting tasks whose dependencies have become unreachable are failed;
        tasks whose time dependencies are met are offered to ``maybe_run``
        and, when submitted, removed from ``self.depending`` and the graph.
        """
        # update any jobs that depended on the dependency
        jobs = self.graph.pop(dep_id, [])

        # recheck *all* jobs if
        # a) we have HWM and an engine just become no longer full
        # or b) dep_id was given as None
        if dep_id is None or self.hwm and any(load == self.hwm-1 for load in self.loads):
            jobs = self.depending.keys()

        # Work on a snapshot: failing or submitting a job changes
        # self.depending, and a failure can cascade to other jobs of this
        # very batch through the recursive update_graph() call.
        for msg_id in list(jobs):
            if msg_id not in self.depending:
                # already handled by a cascaded failure
                continue
            raw_msg, targets, after, follow, timeout = self.depending[msg_id]

            if after.unreachable(self.all_completed, self.all_failed)\
                    or follow.unreachable(self.all_completed, self.all_failed):
                self.fail_unreachable(msg_id)

            elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
                if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):

                    self.depending.pop(msg_id)
                    for mid in follow.union(after):
                        if mid in self.graph:
                            # discard: the job need not be registered under
                            # every one of its dependencies
                            self.graph[mid].discard(msg_id)
    
    #----------------------------------------------------------------------
    # methods to be overridden by subclasses
    #----------------------------------------------------------------------
    
    def add_job(self, idx):
        """Record that the engine at ``self.targets[idx]`` just got a job.

        Meant to be overridden by subclasses.  By default the load counts
        outstanding jobs, and the engine is moved to the back of the
        parallel targets/loads lists (simple LRU ordering).
        """
        self.loads[idx] += 1
        # rotate the same position to the end of both lists
        self.targets.append(self.targets.pop(idx))
        self.loads.append(self.loads.pop(idx))
            
    
    def finish_job(self, idx):
        """Record that the engine at ``self.targets[idx]`` finished a job.

        Meant to be overridden by subclasses; by default the engine's load
        is decremented by one.
        """
        self.loads[idx] = self.loads[idx] - 1
Esempio n. 2
0
 class C(HasTraits):
     # ``Instance`` trait declared with the class ``Foo`` itself as argument.
     inst = Instance(Foo)
Esempio n. 3
0
 class A(HasTraits):
     # NOTE(review): ``Instance`` is given an *instance* (``Foo()``) here,
     # not the class ``Foo``.  Presumably deliberate (e.g. a snippet that
     # demonstrates or tests the failure case) -- confirm before changing.
     inst = Instance(Foo())
Esempio n. 4
0
class NotebookApp(BaseIPythonApplication):

    name = 'ipython-notebook'

    description = """
        The IPython HTML Notebook.
        
        This launches a Tornado based HTML Notebook Server that serves up an
        HTML5/Javascript Notebook client.
    """
    examples = _examples
    aliases = aliases
    flags = flags

    classes = [
        KernelManager,
        ProfileDir,
        Session,
        MappingKernelManager,
        ContentsManager,
        FileContentsManager,
        NotebookNotary,
        KernelSpecManager,
    ]
    flags = Dict(flags)
    aliases = Dict(aliases)

    subcommands = dict(list=(NbserverListApp,
                             NbserverListApp.description.splitlines()[0]), )

    _log_formatter_cls = LogFormatter

    def _log_level_default(self):
        return logging.INFO

    def _log_datefmt_default(self):
        """Exclude date from default date format"""
        return "%H:%M:%S"

    def _log_format_default(self):
        """override default log format to include time"""
        return u"%(color)s[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s]%(end_color)s %(message)s"

    # create requested profiles by default, if they don't exist:
    auto_create = Bool(True)

    # file to be opened in the notebook server
    file_to_run = Unicode('', config=True)

    # Network related information

    allow_origin = Unicode('',
                           config=True,
                           help="""Set the Access-Control-Allow-Origin header
        
        Use '*' to allow any origin to access your server.
        
        Takes precedence over allow_origin_pat.
        """)

    allow_origin_pat = Unicode(
        '',
        config=True,
        help=
        """Use a regular expression for the Access-Control-Allow-Origin header
        
        Requests from an origin matching the expression will get replies with:
        
            Access-Control-Allow-Origin: origin
        
        where `origin` is the origin of the request.
        
        Ignored if allow_origin is set.
        """)

    allow_credentials = Bool(
        False,
        config=True,
        help="Set the Access-Control-Allow-Credentials: true header")

    default_url = Unicode('/tree',
                          config=True,
                          help="The default URL to redirect to from `/`")

    ip = Unicode('localhost',
                 config=True,
                 help="The IP address the notebook server will listen on.")

    def _ip_default(self):
        """Return localhost if available, 127.0.0.1 otherwise.
        
        On some (horribly broken) systems, localhost cannot be bound.
        """
        s = socket.socket()
        try:
            s.bind(('localhost', 0))
        except socket.error as e:
            self.log.warn(
                "Cannot bind to localhost, using 127.0.0.1 as default ip\n%s",
                e)
            return '127.0.0.1'
        else:
            s.close()
            return 'localhost'

    def _ip_changed(self, name, old, new):
        if new == u'*': self.ip = u''

    port = Integer(8888,
                   config=True,
                   help="The port the notebook server will listen on.")
    port_retries = Integer(
        50,
        config=True,
        help=
        "The number of additional ports to try if the specified port is not available."
    )

    certfile = Unicode(
        u'',
        config=True,
        help="""The full path to an SSL/TLS certificate file.""")

    keyfile = Unicode(
        u'',
        config=True,
        help="""The full path to a private key file for usage with SSL/TLS.""")

    cookie_secret_file = Unicode(
        config=True, help="""The file where the cookie secret is stored.""")

    def _cookie_secret_file_default(self):
        if self.profile_dir is None:
            return ''
        return os.path.join(self.profile_dir.security_dir,
                            'notebook_cookie_secret')

    cookie_secret = Bytes(b'',
                          config=True,
                          help="""The random bytes used to secure cookies.
        By default this is a new random number every time you start the Notebook.
        Set it to a value in a config file to enable logins to persist across server sessions.
        
        Note: Cookie secrets should be kept private, do not share config files with
        cookie_secret stored in plaintext (you can read the value from a file).
        """)

    def _cookie_secret_default(self):
        if os.path.exists(self.cookie_secret_file):
            with io.open(self.cookie_secret_file, 'rb') as f:
                return f.read()
        else:
            secret = base64.encodestring(os.urandom(1024))
            self._write_cookie_secret_file(secret)
            return secret

    def _write_cookie_secret_file(self, secret):
        """write my secret to my secret_file"""
        self.log.info("Writing notebook server cookie secret to %s",
                      self.cookie_secret_file)
        with io.open(self.cookie_secret_file, 'wb') as f:
            f.write(secret)
        try:
            os.chmod(self.cookie_secret_file, 0o600)
        except OSError:
            self.log.warn("Could not set permissions on %s",
                          self.cookie_secret_file)

    password = Unicode(u'',
                       config=True,
                       help="""Hashed password to use for web authentication.

                      To generate, type in a python/IPython shell:

                        from IPython.lib import passwd; passwd()

                      The string should be of the form type:salt:hashed-password.
                      """)

    open_browser = Bool(True,
                        config=True,
                        help="""Whether to open in a browser after starting.
                        The specific browser used is platform dependent and
                        determined by the python standard library `webbrowser`
                        module, unless it is overridden using the --browser
                        (NotebookApp.browser) configuration option.
                        """)

    browser = Unicode(u'',
                      config=True,
                      help="""Specify what command to use to invoke a web
                      browser when opening the notebook. If not specified, the
                      default browser will be determined by the `webbrowser`
                      standard library module, which allows setting of the
                      BROWSER environment variable to override it.
                      """)

    webapp_settings = Dict(config=True,
                           help="DEPRECATED, use tornado_settings")

    def _webapp_settings_changed(self, name, old, new):
        self.log.warn(
            "\n    webapp_settings is deprecated, use tornado_settings.\n")
        self.tornado_settings = new

    tornado_settings = Dict(
        config=True,
        help="Supply overrides for the tornado.web.Application that the "
        "IPython notebook uses.")

    ssl_options = Dict(config=True,
                       help="""Supply SSL options for the tornado HTTPServer.
            See the tornado docs for details.""")

    jinja_environment_options = Dict(
        config=True,
        help="Supply extra arguments that will be passed to Jinja environment."
    )

    enable_mathjax = Bool(
        True,
        config=True,
        help="""Whether to enable MathJax for typesetting math/TeX

        MathJax is the javascript library IPython uses to render math/LaTeX. It is
        very large, so you may want to disable it if you have a slow internet
        connection, or for offline use of the notebook.

        When disabled, equations etc. will appear as their untransformed TeX source.
        """)

    def _enable_mathjax_changed(self, name, old, new):
        """set mathjax url to empty if mathjax is disabled"""
        if not new:
            self.mathjax_url = u''

    base_url = Unicode('/',
                       config=True,
                       help='''The base URL for the notebook server.

                               Leading and trailing slashes can be omitted,
                               and will automatically be added.
                               ''')

    def _base_url_changed(self, name, old, new):
        if not new.startswith('/'):
            self.base_url = '/' + new
        elif not new.endswith('/'):
            self.base_url = new + '/'

    base_project_url = Unicode('/',
                               config=True,
                               help="""DEPRECATED use base_url""")

    def _base_project_url_changed(self, name, old, new):
        self.log.warn("base_project_url is deprecated, use base_url")
        self.base_url = new

    extra_static_paths = List(
        Unicode,
        config=True,
        help="""Extra paths to search for serving static files.
        
        This allows adding javascript/css to be available from the notebook server machine,
        or overriding individual files in the IPython""")

    def _extra_static_paths_default(self):
        return [os.path.join(self.profile_dir.location, 'static')]

    @property
    def static_file_path(self):
        """Static search path: the extra paths followed by the default
        static files location."""
        search_path = self.extra_static_paths + [DEFAULT_STATIC_FILES_PATH]
        return search_path

    extra_template_paths = List(
        Unicode,
        config=True,
        help="""Extra paths to search for serving jinja templates.

        Can be used to override templates from IPython.html.templates.""")

    def _extra_template_paths_default(self):
        return []

    @property
    def template_file_path(self):
        """Template search path: the extra paths, then the default locations."""
        configured = self.extra_template_paths
        return configured + DEFAULT_TEMPLATE_PATH_LIST

    # Extra directories for notebook extensions; nbextensions_path puts them
    # ahead of the per-user and system directories.
    extra_nbextensions_path = List(
        Unicode,
        config=True,
        help="""extra paths to look for Javascript notebook extensions""")

    @property
    def nbextensions_path(self):
        """Directories searched for Javascript notebook extensions.

        Order: the configured extra paths, the ``nbextensions`` directory
        under the IPython dir, then the system-wide directories.
        """
        configured = self.extra_nbextensions_path
        per_user = [os.path.join(get_ipython_dir(), 'nbextensions')]
        return configured + per_user + SYSTEM_NBEXTENSIONS_DIRS

    # Optional websocket origin; empty string by default.
    websocket_url = Unicode("",
                            config=True,
                            help="""The base URL for websockets,
        if it differs from the HTTP server (hint: it almost certainly doesn't).
        
        Should be in the form of an HTTP origin: ws[s]://hostname[:port]
        """)
    # Computed by _mathjax_url_default; reset to empty by _mathjax_url_changed
    # when MathJax is disabled.
    mathjax_url = Unicode("", config=True, help="""The url for MathJax.js.""")

    def _mathjax_url_default(self):
        """Work out the default MathJax URL.

        Returns an empty string when MathJax is disabled.  Otherwise the
        nbextensions path and then the static file path are searched for
        ``mathjax/MathJax.js``; the first hit is served locally.  If neither
        contains it, the CDN URL is returned.
        """
        if not self.enable_mathjax:
            return u''
        static_url_prefix = self.tornado_settings.get(
            "static_url_prefix", url_path_join(self.base_url, "static"))

        # (url prefix, directories to search) pairs, in priority order
        candidates = [
            (url_path_join(self.base_url, "nbextensions"),
             self.nbextensions_path),
            (static_url_prefix, self.static_file_path),
        ]
        for url_prefix, search_path in candidates:
            self.log.debug("searching for local mathjax in %s", search_path)
            try:
                mathjax = filefind(os.path.join('mathjax', 'MathJax.js'),
                                   search_path)
            except IOError:
                # not in this location; try the next candidate
                continue
            url = url_path_join(url_prefix, u"mathjax/MathJax.js")
            self.log.info("Serving local MathJax from %s at %s", mathjax, url)
            return url

        # nothing found locally: fall back to the CDN
        url = u"https://cdn.mathjax.org/mathjax/latest/MathJax.js"
        self.log.info("Using MathJax from CDN: %s", url)
        return url

    def _mathjax_url_changed(self, name, old, new):
        if new and not self.enable_mathjax:
            # enable_mathjax=False overrides mathjax_url
            self.mathjax_url = u''
        else:
            self.log.info("Using MathJax: %s", new)

    # Configurable classes for the manager objects; init_configurables
    # instantiates each of them.
    contents_manager_class = Type(default_value=FileContentsManager,
                                  klass=ContentsManager,
                                  config=True,
                                  help='The notebook manager class to use.')
    kernel_manager_class = Type(default_value=MappingKernelManager,
                                config=True,
                                help='The kernel manager class to use.')
    session_manager_class = Type(default_value=SessionManager,
                                 config=True,
                                 help='The session manager class to use.')
    cluster_manager_class = Type(default_value=ClusterManager,
                                 config=True,
                                 help='The cluster manager class to use.')

    config_manager_class = Type(default_value=ConfigManager,
                                config=True,
                                help='The config manager class to use')

    # The kernel spec manager instance, assigned in init_configurables.
    kernel_spec_manager = Instance(KernelSpecManager, allow_none=True)

    kernel_spec_manager_class = Type(default_value=KernelSpecManager,
                                     config=True,
                                     help="""
        The kernel spec manager class to use. Should be a subclass
        of `IPython.kernel.kernelspec.KernelSpecManager`.

        The Api of KernelSpecManager is provisional and might change
        without warning between this version of IPython and the next stable one.
        """)

    # Login/logout handler classes; init_webapp calls
    # login_handler_class.validate_security before creating the HTTP server.
    login_handler_class = Type(
        default_value=LoginHandler,
        klass=web.RequestHandler,
        config=True,
        help='The login handler class to use.',
    )

    logout_handler_class = Type(
        default_value=LogoutHandler,
        klass=web.RequestHandler,
        config=True,
        help='The logout handler class to use.',
    )

    # Passed to the HTTP server as ``xheaders`` in init_webapp.
    # NOTE: adjacent string literals are joined verbatim, so every fragment of
    # the help text must carry its own separating space.
    trust_xheaders = Bool(
        False,
        config=True,
        help=
        ("Whether to trust or not X-Scheme/X-Forwarded-Proto and X-Real-Ip/X-Forwarded-For headers "
         "sent by the upstream reverse proxy. Necessary if the proxy handles SSL"
         ))

    # Path of the per-process server info JSON file (see _info_file_default).
    info_file = Unicode()

    def _info_file_default(self):
        info_file = "nbserver-%s.json" % os.getpid()
        return os.path.join(self.profile_dir.security_dir, info_file)

    # Removed option kept as a trait: any change to it makes _pylab_changed
    # log an error and exit.
    pylab = Unicode('disabled',
                    config=True,
                    help="""
        DISABLED: use %pylab or %matplotlib in the notebook to enable matplotlib.
        """)

    def _pylab_changed(self, name, old, new):
        """when --pylab is specified, display a warning and exit"""
        if new != 'warn':
            backend = ' %s' % new
        else:
            backend = ''
        self.log.error(
            "Support for specifying --pylab on the command line has been removed."
        )
        self.log.error(
            "Please use `%pylab{0}` or `%matplotlib{0}` in the notebook itself."
            .format(backend))
        self.exit(1)

    # Validated by _notebook_dir_changed, which also copies the value into the
    # FileContentsManager and MappingKernelManager config sections.
    notebook_dir = Unicode(
        config=True, help="The directory to use for notebooks and kernels.")

    def _notebook_dir_default(self):
        if self.file_to_run:
            return os.path.dirname(os.path.abspath(self.file_to_run))
        else:
            return py3compat.getcwd()

    def _notebook_dir_changed(self, name, old, new):
        """Do a bit of validation of the notebook dir."""
        if not os.path.isabs(new):
            # If we receive a non-absolute path, make it absolute.
            self.notebook_dir = os.path.abspath(new)
            return
        if not os.path.isdir(new):
            raise TraitError("No such notebook dir: %r" % new)

        # setting App.notebook_dir implies setting notebook and kernel dirs as well
        self.config.FileContentsManager.root_dir = new
        self.config.MappingKernelManager.root_dir = new

    # Module names imported by init_server_extensions.
    server_extensions = List(
        Unicode(),
        config=True,
        help=(
            "Python modules to load as notebook server extensions. "
            "This is an experimental API, and may change in future releases."))

    # When true, init_server_extensions re-raises extension load errors
    # instead of only logging a warning.
    reraise_server_extension_failures = Bool(
        False,
        config=True,
        help="Reraise exceptions encountered loading server extensions?",
    )

    def parse_command_line(self, argv=None):
        """Parse the command line, treating the first extra argument as a path.

        A directory becomes ``notebook_dir`` and a regular file becomes
        ``file_to_run``.  A path that does not exist is logged as critical
        and ``self.exit(1)`` is called.
        """
        super(NotebookApp, self).parse_command_line(argv)

        if not self.extra_args:
            return

        first_arg = self.extra_args[0]
        target = os.path.abspath(first_arg)
        self.argv.remove(first_arg)
        if not os.path.exists(target):
            self.log.critical("No such file or directory: %s", target)
            self.exit(1)

        # Use config here, to ensure that it takes higher priority than
        # anything that comes from the profile.
        overrides = Config()
        if os.path.isdir(target):
            overrides.NotebookApp.notebook_dir = target
        elif os.path.isfile(target):
            overrides.NotebookApp.file_to_run = target
        self.update_config(overrides)

    def init_configurables(self):
        """Instantiate the manager objects from their configured classes.

        Every manager gets this app as ``parent``.  The session manager is
        built after the kernel and contents managers because it receives
        both of them.
        """
        logger = self.log
        profile = self.profile_dir

        self.kernel_spec_manager = self.kernel_spec_manager_class(
            parent=self, ipython_dir=self.ipython_dir)
        self.kernel_manager = self.kernel_manager_class(
            parent=self, log=logger, connection_dir=profile.security_dir)
        self.contents_manager = self.contents_manager_class(
            parent=self, log=logger)
        self.session_manager = self.session_manager_class(
            parent=self,
            log=logger,
            kernel_manager=self.kernel_manager,
            contents_manager=self.contents_manager)
        self.cluster_manager = self.cluster_manager_class(
            parent=self, log=logger)

        self.config_manager = self.config_manager_class(
            parent=self, log=logger, profile_dir=profile.location)

    def init_logging(self):
        """Route tornado's loggers through this application's logger."""
        # Tornado uses a root logger that self.log is a child of.  The logging
        # module dispatches each message to a logger and all of its ancestors
        # until propagate is False, so turn propagation off here to avoid
        # duplicated log lines.
        self.log.propagate = False

        for tornado_log in (app_log, access_log, gen_log):
            # consistent log output name (NotebookApp instead of
            # tornado.access, etc.)
            tornado_log.name = self.log.name

        # hook up tornado 3's loggers to our app handlers
        tornado_root = logging.getLogger('tornado')
        tornado_root.propagate = True
        tornado_root.parent = self.log
        tornado_root.setLevel(self.log.level)

    def init_webapp(self):
        """Initialize the tornado web application and its HTTP server.

        Copies the cross-origin traits into ``tornado_settings``, makes sure
        ``default_url`` starts with ``base_url``, builds the
        ``NotebookWebApplication`` and the ``HTTPServer``, then tries to
        listen on each port yielded by ``random_ports`` in turn.  When no
        port can be bound, a critical message is logged and ``self.exit(1)``
        is called.
        """
        self.tornado_settings['allow_origin'] = self.allow_origin
        if self.allow_origin_pat:
            self.tornado_settings['allow_origin_pat'] = re.compile(
                self.allow_origin_pat)
        self.tornado_settings['allow_credentials'] = self.allow_credentials
        # ensure default_url starts with base_url
        if not self.default_url.startswith(self.base_url):
            self.default_url = url_path_join(self.base_url, self.default_url)

        self.web_app = NotebookWebApplication(
            self, self.kernel_manager, self.contents_manager,
            self.cluster_manager, self.session_manager,
            self.kernel_spec_manager, self.config_manager, self.log,
            self.base_url, self.default_url, self.tornado_settings,
            self.jinja_environment_options)
        ssl_options = self.ssl_options
        if self.certfile:
            ssl_options['certfile'] = self.certfile
        if self.keyfile:
            ssl_options['keyfile'] = self.keyfile
        if not ssl_options:
            # None indicates no SSL config
            ssl_options = None
        else:
            # Default to TLSv1 so that SSLv3 (discouraged) is not used, but
            # do not clobber an ssl_version the user set in ssl_options.
            ssl_options.setdefault('ssl_version', ssl.PROTOCOL_TLSv1)
        self.login_handler_class.validate_security(self,
                                                   ssl_options=ssl_options)
        self.http_server = httpserver.HTTPServer(self.web_app,
                                                 ssl_options=ssl_options,
                                                 xheaders=self.trust_xheaders)

        success = None
        for port in random_ports(self.port, self.port_retries + 1):
            try:
                self.http_server.listen(port, self.ip)
            except socket.error as e:
                if e.errno == errno.EADDRINUSE:
                    self.log.info(
                        'The port %i is already in use, trying another random port.'
                        % port)
                    continue
                elif e.errno in (errno.EACCES,
                                 getattr(errno, 'WSAEACCES', errno.EACCES)):
                    # WSAEACCES only exists on Windows; fall back to EACCES
                    self.log.warn("Permission to listen on port %i denied" %
                                  port)
                    continue
                else:
                    # any other socket error is unexpected: let it propagate
                    raise
            else:
                # remember the port that was actually bound
                self.port = port
                success = True
                break
        if not success:
            self.log.critical(
                'ERROR: the notebook server could not be started because '
                'no available port could be found.')
            self.exit(1)

    @property
    def display_url(self):
        """URL for display; an empty ``ip`` is shown as a placeholder text."""
        shown_ip = self.ip or '[all ip addresses on your system]'
        return self._url(shown_ip)

    @property
    def connection_url(self):
        """URL for connecting; an empty ``ip`` is replaced by ``localhost``."""
        host = self.ip or 'localhost'
        return self._url(host)

    def _url(self, ip):
        proto = 'https' if self.certfile else 'http'
        return "%s://%s:%i%s" % (proto, ip, self.port, self.base_url)

    def init_terminals(self):
        """Enable terminal support on the web app if ``.terminal`` can be imported.

        An ``ImportError`` is only logged: at debug level on win32, as a
        warning elsewhere.
        """
        try:
            from .terminal import initialize
            initialize(self.web_app)
            self.web_app.settings['terminals_available'] = True
        except ImportError as e:
            if sys.platform == 'win32':
                report = self.log.debug
            else:
                report = self.log.warn
            report("Terminals not available (error was %s)", e)

    def init_signal(self):
        """Install the signal handlers used by the server.

        SIGINT gets the confirmation handler (not on Windows), SIGTERM stops
        the server, and SIGUSR1 / SIGINFO print the notebook info when the
        platform defines them.
        """
        if not sys.platform.startswith('win'):
            signal.signal(signal.SIGINT, self._handle_sigint)
        signal.signal(signal.SIGTERM, self._signal_stop)
        # SIGUSR1 is missing on Windows; SIGINFO only exists on BSD-based
        # systems
        for optional_name in ('SIGUSR1', 'SIGINFO'):
            if hasattr(signal, optional_name):
                signal.signal(getattr(signal, optional_name),
                              self._signal_info)

    def _handle_sigint(self, sig, frame):
        """SIGINT handler spawns confirmation dialog"""
        # register more forceful signal handler for ^C^C case
        signal.signal(signal.SIGINT, self._signal_stop)
        # request confirmation dialog in bg thread, to avoid
        # blocking the App
        thread = threading.Thread(target=self._confirm_exit)
        thread.daemon = True
        thread.start()

    def _restore_sigint_handler(self):
        """callback for restoring original SIGINT handler"""
        signal.signal(signal.SIGINT, self._handle_sigint)

    def _confirm_exit(self):
        """Ask on stdin whether to shut down after a ^C.

        Waits up to 5s for a line on stdin.  An answer starting with 'y'
        (and containing no 'n') stops the current IOLoop.  Otherwise the
        original SIGINT handler is scheduled to be restored.

        This doesn't work on Windows.
        """
        self.log.info('interrupted')
        print(self.notebook_info())
        sys.stdout.write("Shutdown this notebook server (y/[n])? ")
        sys.stdout.flush()
        readable, _writable, _errored = select.select([sys.stdin], [], [], 5)
        if not readable:
            print("No answer for 5s:", end=' ')
        else:
            answer = sys.stdin.readline().lower()
            if answer.startswith('y') and 'n' not in answer:
                self.log.critical("Shutdown confirmed")
                ioloop.IOLoop.current().stop()
                return
        print("resuming operation...")
        # no answer, or answer is no: set it back to original SIGINT handler.
        # use IOLoop.add_callback because signal.signal must be called
        # from main thread
        ioloop.IOLoop.current().add_callback(self._restore_sigint_handler)

    def _signal_stop(self, sig, frame):
        """Signal handler: log the signal and stop the current IOLoop."""
        self.log.critical("received signal %s, stopping", sig)
        loop = ioloop.IOLoop.current()
        loop.stop()

    def _signal_info(self, sig, frame):
        print(self.notebook_info())

    def init_components(self):
        """Check the components submodule; update it if missing, warn if unclean."""
        state = submodule.check_submodule_status()
        if state == 'missing':
            self.log.warn(
                "components submodule missing, running `git submodule update`")
            submodule.update_submodules(submodule.ipython_parent())
            return
        if state != 'unclean':
            # any other status needs no action
            return
        self.log.warn(
            "components submodule unclean, you may see 404s on static/components"
        )
        self.log.warn(
            "run `setup.py submodule` or `git submodule update` to update")

    def init_server_extensions(self):
        """Import each configured server extension and run its load hook.

        For every module name in ``server_extensions`` the module is imported
        and, if it defines ``load_jupyter_server_extension``, that function
        is called with this app.  Failures are logged as warnings unless
        ``reraise_server_extension_failures`` is set, in which case they
        propagate.

        The extension API is experimental, and may change in future releases.
        """
        for modulename in self.server_extensions:
            try:
                extension = importlib.import_module(modulename)
                loader = getattr(extension, 'load_jupyter_server_extension',
                                 None)
                if loader is None:
                    continue
                loader(self)
            except Exception:
                if self.reraise_server_extension_failures:
                    raise
                self.log.warn("Error loading server extension %s",
                              modulename,
                              exc_info=True)

    @catch_config_error
    def initialize(self, argv=None):
        """Run the parent initialization, then every ``init_*`` step in order."""
        super(NotebookApp, self).initialize(argv)
        startup_steps = (
            self.init_logging,
            self.init_configurables,
            self.init_components,
            self.init_webapp,
            self.init_terminals,
            self.init_signal,
            self.init_server_extensions,
        )
        for step in startup_steps:
            step()

    def cleanup_kernels(self):
        """Shut down every kernel via the kernel manager.

        The kernels will shutdown themselves when this process no longer
        exists, but explicit shutdown allows the KernelManagers to cleanup
        the connection files.
        """
        self.log.info('Shutting down kernels')
        manager = self.kernel_manager
        manager.shutdown_all()

    def notebook_info(self):
        """Return a summary: contents info, kernel count and the server URL."""
        summary = self.contents_manager.info_string() + "\n"
        kernel_count = len(self.kernel_manager._kernels)
        summary += "%d active kernels \n" % kernel_count
        summary += "The IPython Notebook is running at: %s" % self.display_url
        return summary

    def server_info(self):
        """Return a JSONable dict describing this server process."""
        info = {}
        info['url'] = self.connection_url
        # an empty ip is reported as localhost
        info['hostname'] = self.ip or 'localhost'
        info['port'] = self.port
        info['secure'] = bool(self.certfile)
        info['base_url'] = self.base_url
        info['notebook_dir'] = os.path.abspath(self.notebook_dir)
        info['pid'] = os.getpid()
        return info

    def write_server_info_file(self):
        """Dump ``server_info()`` as indented JSON into ``info_file``."""
        with open(self.info_file, 'w') as info_fp:
            json.dump(self.server_info(), info_fp, indent=2)

    def remove_server_info_file(self):
        """Delete ``info_file``, tolerating a file that is already gone.

        Any ``OSError`` other than ENOENT is re-raised.
        """
        try:
            os.unlink(self.info_file)
        except OSError as e:
            if e.errno == errno.ENOENT:
                return
            raise

    def start(self):
        """Run the notebook server until its IOLoop stops.

        This method takes no arguments, so all configuration and
        initialization must be done before calling it.  If a subapp is set,
        its ``start()`` result is returned instead.  Kernels are shut down
        and the server info file is removed when the loop exits.
        """
        if self.subapp is not None:
            return self.subapp.start()

        log_info = self.log.info
        for banner_line in self.notebook_info().split("\n"):
            log_info(banner_line)
        log_info(
            "Use Control-C to stop this server and shut down all kernels (twice to skip confirmation)."
        )

        self.write_server_info_file()

        if self.open_browser or self.file_to_run:
            try:
                browser = webbrowser.get(self.browser or None)
            except webbrowser.Error as e:
                self.log.warn('No web browser found: %s.' % e)
                browser = None

            if not self.file_to_run:
                uri = 'tree'
            else:
                if not os.path.exists(self.file_to_run):
                    self.log.critical("%s does not exist" % self.file_to_run)
                    self.exit(1)

                relpath = os.path.relpath(self.file_to_run, self.notebook_dir)
                uri = url_path_join('notebooks', *relpath.split(os.sep))
            if browser:
                def open_in_browser():
                    # the URL is resolved when the thread runs, not here
                    browser.open(
                        url_path_join(self.connection_url, uri), new=2)

                threading.Thread(target=open_in_browser).start()

        self.io_loop = ioloop.IOLoop.current()
        if sys.platform.startswith('win'):
            # add no-op to wake every 5s
            # to handle signals that may be ignored by the inner loop
            wakeup = ioloop.PeriodicCallback(lambda: None, 5000)
            wakeup.start()
        try:
            self.io_loop.start()
        except KeyboardInterrupt:
            log_info("Interrupted...")
        finally:
            self.cleanup_kernels()
            self.remove_server_info_file()

    def stop(self):
        """Schedule a shutdown of the HTTP server and the IOLoop.

        The work is handed to ``io_loop.add_callback`` rather than done
        directly.
        """
        def shutdown():
            self.http_server.stop()
            self.io_loop.stop()

        self.io_loop.add_callback(shutdown)
Esempio n. 5
0
 # Minimal HasTraits example: `inst` is an Instance trait of Foo declared with
 # (10,) -- presumably the args used to build the default value; TODO confirm.
 class A(HasTraits):
     inst = Instance(Foo, (10,))
Esempio n. 6
0
class Session(Configurable):
    """Object for handling serialization and sending of messages.

    The Session object handles building messages and sending them
    with ZMQ sockets or ZMQStream objects.  Objects can communicate with each
    other over the network via Session objects, and only need to work with the
    dict-based IPython message spec. The Session will handle
    serialization/deserialization, security, and metadata.

    Sessions support configurable serialization via packer/unpacker traits,
    and signing with HMAC digests via the key/keyfile traits.

    Parameters
    ----------

    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        importstrings for methods to serialize message parts.  If just
        'json' or 'pickle', predefined JSON and pickle packers will be used.
        Otherwise, the entire importstring must be used.

        The functions must accept at least valid JSON input, and output *bytes*.

        For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        You can also set the pack/unpack callables for serialization directly.
    session : bytes
        the ID of this Session object.  The default is to generate a new UUID.
    username : unicode
        username added to message headers.  The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature.  If unset, messages
        will not be signed or checked.
    keyfile : filepath
        The file containing a key.  If this is set, `key` will be initialized
        to the contents of the file.

    """

    # Flag for extra debug output.
    debug = Bool(False, config=True, help="""Debug output in the Session""")

    # Name of the serializer; _packer_changed maps it to the actual
    # pack/unpack callables.
    packer = DottedObjectName(
        'json',
        config=True,
        help="""The name of the packer for serializing messages.
            Should be one of 'json', 'pickle', or an import name
            for a custom callable serializer.""")

    def _packer_changed(self, name, old, new):
        """Select the pack (and possibly unpack) callable for a packer name.

        'json' and 'pickle' (case-insensitive) set both ``pack`` and
        ``unpack`` to the predefined pair.  Any other value is imported and
        set as ``pack`` only.
        """
        kind = new.lower()
        if kind == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
        elif kind == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
        else:
            # custom serializer given as an import string
            self.pack = import_item(str(new))

    # Name of the deserializer; _unpacker_changed maps it to the actual
    # pack/unpack callables.
    unpacker = DottedObjectName(
        'json',
        config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""")

    def _unpacker_changed(self, name, old, new):
        """Select the unpack (and possibly pack) callable for an unpacker name.

        'json' and 'pickle' (case-insensitive) set both ``pack`` and
        ``unpack`` to the predefined pair.  Any other value is imported and
        set as ``unpack`` only.
        """
        kind = new.lower()
        if kind == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
        elif kind == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
        else:
            # custom deserializer given as an import string
            self.unpack = import_item(str(new))

    # Session id as text; _session_default generates a uuid4 and
    # _session_changed keeps bsession in sync.
    session = CUnicode(u'',
                       config=True,
                       help="""The UUID identifying this session.""")

    def _session_default(self):
        """Generate a fresh uuid4 session id and store its ascii bytes in ``bsession``."""
        session_id = unicode(uuid.uuid4())
        self.bsession = session_id.encode('ascii')
        return session_id

    def _session_changed(self, name, old, new):
        self.bsession = self.session.encode('ascii')

    # bsession is the session as bytes; it is written by _session_default and
    # _session_changed.
    bsession = CBytes(b'')

    # The default is taken from the USER environment variable when the class
    # body runs, falling back to u'username'.
    username = Unicode(
        os.environ.get('USER', u'username'),
        config=True,
        help="""Username for the Session. Default is your system username.""")

    # message signature related traits:

    # Signing key; _key_changed builds the HMAC object in `auth` from it, or
    # clears `auth` when the key is empty.
    key = CBytes(b'',
                 config=True,
                 help="""execution key, for extra authentication.""")

    def _key_changed(self, name, old, new):
        if new:
            self.auth = hmac.HMAC(new)
        else:
            self.auth = None

    # HMAC object derived from `key` (set in _key_changed); sign() treats
    # None as "no signing".
    auth = Instance(hmac.HMAC)
    # Set trait; not used within the code visible here.
    digest_history = Set()

    # Path to a key file; _keyfile_changed reads it and assigns `key`.
    keyfile = Unicode('',
                      config=True,
                      help="""path to file containing execution key.""")

    def _keyfile_changed(self, name, old, new):
        with open(new, 'rb') as f:
            self.key = f.read().strip()

    # serialization traits:

    # _pack_changed rejects non-callable values.
    pack = Any(default_packer)  # the actual packer function

    def _pack_changed(self, name, old, new):
        if not callable(new):
            raise TypeError("packer must be callable, not %s" % type(new))

    unpack = Any(default_unpacker)  # the actual unpacker function

    def _unpack_changed(self, name, old, new):
        # unpacker is not checked - it is assumed to be
        if not callable(new):
            raise TypeError("unpacker must be callable, not %s" % type(new))

    def __init__(self, **kwargs):
        """Create a Session object.

        Parameters
        ----------

        debug : bool
            Whether to trigger extra debugging statements.
        packer/unpacker : str : 'json', 'pickle' or import_string
            Import strings for the functions that serialize message parts.
            With just 'json' or 'pickle' the predefined JSON and pickle
            packers are used; otherwise the entire import string must be
            given.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            The pack/unpack callables can also be set directly.
        session : unicode (must be ascii)
            The ID of this Session object.  A new UUID is generated by
            default.
        bsession : bytes
            The session as bytes.
        username : unicode
            Username added to message headers.  The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature.  If unset,
            messages will not be signed or checked.
        keyfile : filepath
            The file containing a key.  If this is set, `key` will be
            initialized to the contents of the file.
        """
        super(Session, self).__init__(**kwargs)
        # validate the pack/unpack pair; this may wrap them for datetimes
        self._check_packers()
        # the packed form of an empty dict
        packed_empty = self.pack({})
        self.none = packed_empty
        # touch the trait so _session_default() runs if necessary and
        # bsession is defined
        self.session

    @property
    def msg_id(self):
        """A new uuid4 string on every access."""
        fresh = uuid.uuid4()
        return str(fresh)

    def _check_packers(self):
        """check packers for binary data and datetime support."""
        pack = self.pack
        unpack = self.unpack

        # check simple serialization
        msg = dict(a=[1, 'hi'])
        try:
            packed = pack(msg)
        except Exception:
            raise ValueError("packer could not serialize a simple message")

        # ensure packed message is bytes
        if not isinstance(packed, bytes):
            raise ValueError("message packed to %r, but bytes are required" %
                             type(packed))

        # check that unpack is pack's inverse
        try:
            unpacked = unpack(packed)
        except Exception:
            raise ValueError("unpacker could not handle the packer's output")

        # check datetime support
        msg = dict(t=datetime.now())
        try:
            unpacked = unpack(pack(msg))
        except Exception:
            self.pack = lambda o: pack(squash_dates(o))
            self.unpack = lambda s: extract_dates(unpack(s))

    def msg_header(self, msg_type):
        """Build a header via the module-level ``msg_header`` helper.

        Passes a fresh message id, ``msg_type``, the username and the
        session id.
        """
        new_id = self.msg_id
        return msg_header(new_id, msg_type, self.username, self.session)

    def msg(self,
            msg_type,
            content=None,
            parent=None,
            subheader=None,
            header=None):
        """Return the nested message dict.

        This format is different from what is sent over the wire.  The
        serialize/unserialize methods convert this nested message dict to
        the wire format, which is a list of message parts.

        A header is generated from ``msg_type`` unless one is passed in.
        ``subheader`` entries are merged into the header in place, after
        ``msg_id`` and ``msg_type`` have been copied out of it.
        """
        if header is None:
            header = self.msg_header(msg_type)
        message = {
            'header': header,
            'msg_id': header['msg_id'],
            'msg_type': header['msg_type'],
            'parent_header': {} if parent is None else extract_header(parent),
            'content': {} if content is None else content,
        }
        header.update({} if subheader is None else subheader)
        return message

    def sign(self, msg_list):
        """Sign a message with HMAC digest. If no auth, return b''.

        Parameters
        ----------
        msg_list : list
            The [p_header,p_parent,p_content] part of the message list.

        Returns
        -------
        The hex digest passed through ``str_to_bytes``, or ``b''`` when
        ``self.auth`` is None.
        """
        if self.auth is None:
            return b''
        # work on a copy so self.auth itself is never updated
        digest = self.auth.copy()
        for part in msg_list:
            digest.update(part)
        return str_to_bytes(digest.hexdigest())

    def serialize(self, msg, ident=None):
        """Serialize the message components to bytes.

        This is roughly the inverse of unserialize. The serialize/unserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg : dict or Message
            The nested message dict as returned by the self.msg method.
        ident : list or single identity, optional
            Routing prefix placed before the delimiter; a list is extended
            into the result, any other non-None value is added as one entry.

        Returns
        -------
        msg_list : list
            The list of bytes objects to be sent with the format:
            [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content].
            In this list, the p_* entities are the packed or serialized
            versions, so if JSON is used, these are utf8 encoded JSON
            strings.

        Raises
        ------
        TypeError
            If the content is not None, a dict, bytes or unicode.
        """
        content = msg.get('content', {})
        if content is None:
            packed_content = self.none
        elif isinstance(content, dict):
            packed_content = self.pack(content)
        elif isinstance(content, bytes):
            # content is already packed, as in a relayed message
            packed_content = content
        elif isinstance(content, unicode):
            # should be bytes, but JSON often spits out unicode
            packed_content = content.encode('utf8')
        else:
            raise TypeError("Content incorrect type: %s" % type(content))

        real_message = [
            self.pack(msg['header']),
            self.pack(msg['parent_header']),
            packed_content,
        ]

        if isinstance(ident, list):
            # accept list of idents
            to_send = list(ident)
        elif ident is not None:
            to_send = [ident]
        else:
            to_send = []
        to_send.append(DELIM)
        to_send.append(self.sign(real_message))
        to_send.extend(real_message)

        return to_send

    def send(self,
             stream,
             msg_or_type,
             content=None,
             parent=None,
             ident=None,
             buffers=None,
             subheader=None,
             track=False,
             header=None):
        """Build a message if necessary, then serialize and send it.

        The frames put on the wire are::

            [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
             buffer1,buffer2,...]

        ``self.serialize`` produces everything up to ``p_content``; every
        entry of ``buffers`` is then sent as an additional frame.

        Parameters
        ----------
        stream : zmq.Socket or ZMQStream
            The socket-like object used to send the data.
        msg_or_type : str or Message/dict
            Either a msg_type from which a new message is built with
            ``self.msg``, or an already built Message/dict which is sent
            as-is (``content``, ``parent``, ``subheader`` and ``header`` are
            then ignored).
        content : dict or None
            The content of the message.
        parent : Message or dict or None
            The parent or parent header describing the parent of this message.
        ident : bytes or list of bytes
            The zmq.IDENTITY routing path.
        buffers : list or None
            The already-serialized buffers to be appended to the message.
        subheader : dict or None
            Extra header keys for this message's header.
        track : bool
            Whether to track.  Only for use with Sockets, because ZMQStream
            objects cannot track messages.
        header : dict or None
            The header dict for the message.

        Returns
        -------
        msg : dict
            The message that was sent.  The value returned by the last send
            call is stored under ``msg['tracker']``.

        Raises
        ------
        TypeError
            If ``stream`` is neither a Socket nor a ZMQStream, or if ``track``
            is requested on a ZMQStream.
        """
        if not isinstance(stream, (zmq.Socket, ZMQStream)):
            raise TypeError("stream must be Socket or ZMQStream, not %r" %
                            type(stream))
        if track and isinstance(stream, ZMQStream):
            raise TypeError("ZMQStream cannot track messages")

        if isinstance(msg_or_type, (Message, dict)):
            # Already a complete message: send it unchanged.
            msg = msg_or_type
        else:
            msg = self.msg(msg_or_type, content=content, parent=parent,
                           subheader=subheader, header=header)

        if buffers is None:
            buffers = []
        to_send = self.serialize(msg, ident)

        # When buffers follow, the serialized part is not the last frame:
        # mark it SNDMORE and leave tracking to the final buffer.
        if buffers:
            flag = zmq.SNDMORE
            head_track = False
        else:
            flag = 0
            head_track = track

        if track:
            tracker = stream.send_multipart(to_send, flag, copy=False,
                                            track=head_track)
        else:
            tracker = stream.send_multipart(to_send, flag, copy=False)

        # Every buffer except the last one announces that more frames follow.
        for frame in buffers[:-1]:
            stream.send(frame, flag, copy=False)
        if buffers:
            final_frame = buffers[-1]
            if track:
                tracker = stream.send(final_frame, copy=False, track=track)
            else:
                tracker = stream.send(final_frame, copy=False)

        if self.debug:
            for item in (msg, to_send, buffers):
                pprint.pprint(item)

        msg['tracker'] = tracker
        return msg

    def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
        """Send an already serialized message along an ident path.

        The routing idents, the delimiter and a freshly computed signature
        are prepended to ``msg_list`` and the complete frame list is sent.

        Parameters
        ----------
        stream : ZMQStream or Socket
            The ZMQ stream or socket to use for sending the message.
        msg_list : list
            The serialized list of messages to send. This only includes the
            [p_header,p_parent,p_content,buffer1,buffer2,...] portion of
            the message.
        flags : int
            Flags passed through to ``stream.send_multipart``.
        copy : bool
            Passed through to ``stream.send_multipart``.
        ident : ident or list
            A single ident or a list of idents to use in sending.
        """
        to_send = []
        if isinstance(ident, bytes):
            # A single ident is treated as a one-element routing path.
            ident = [ident]
        if ident is not None:
            to_send.extend(ident)

        to_send.append(DELIM)
        to_send.append(self.sign(msg_list))
        to_send.extend(msg_list)
        # Send the fully framed message; sending ``msg_list`` alone would
        # drop the idents, the delimiter and the signature.
        stream.send_multipart(to_send, flags, copy=copy)

    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
        """Receive and unpack a message.

        Parameters
        ----------
        socket : ZMQStream or Socket
            The socket or stream to use in receiving.  For a ZMQStream the
            underlying ``socket`` attribute is used.
        mode : int
            Flags passed to ``recv_multipart``.
        content : bool
            Passed to ``self.unserialize``.
        copy : bool
            Passed to ``recv_multipart``, ``self.feed_identities`` and
            ``self.unserialize``.

        Returns
        -------
        [idents], msg
            [idents] is a list of idents and msg is a nested message dict of
            same format as self.msg returns.  ``(None, None)`` is returned
            when the receive fails with EAGAIN.

        Raises
        ------
        zmq.ZMQError
            For any receive error other than EAGAIN.
        """
        if isinstance(socket, ZMQStream):
            socket = socket.socket
        try:
            msg_list = socket.recv_multipart(mode, copy=copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_multipart won't return None.
                return None, None
            raise
        # split multipart message into identity list and message dict
        # invalid large messages can cause very expensive string comparisons
        idents, msg_list = self.feed_identities(msg_list, copy)
        # TODO: handle unserialize errors.  Until then they propagate to the
        # caller unchanged, keeping their original traceback (a catch-all
        # ``raise e`` would reset it on Python 2).
        return idents, self.unserialize(msg_list, content=content, copy=copy)

    def feed_identities(self, msg_list, copy=True):
        """Separate the routing identities from the message proper.

        Everything before the first DELIM frame is the identity prefix and
        everything after it is the message.  An identity equal to DELIM would
        break the split.

        Parameters
        ----------
        msg_list : a list of Message or bytes objects
            The message to be split.
        copy : bool
            True if the items are bytes, False if they are Message objects
            exposing their payload through ``.bytes``.

        Returns
        -------
        (idents, msg_list) : two lists
            idents is a list of bytes, one per ZMQ identity.  msg_list holds
            the remaining items, of the form
            [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...], in the
            same type they were passed in.

        Raises
        ------
        ValueError
            If no DELIM frame is present.
        """
        if copy:
            # list.index raises ValueError by itself when DELIM is missing.
            split_at = msg_list.index(DELIM)
            return msg_list[:split_at], msg_list[split_at + 1:]

        for split_at, frame in enumerate(msg_list):
            if frame.bytes == DELIM:
                break
        else:
            # Loop ran to completion (or the list was empty): no delimiter.
            raise ValueError("DELIM not in msg_list")

        prefix = msg_list[:split_at]
        remainder = msg_list[split_at + 1:]
        return [frame.bytes for frame in prefix], remainder

    def unserialize(self, msg_list, content=True, copy=True):
        """Unserialize a msg_list to a nested message dict.

        This is roughly the inverse of serialize. The serialize/unserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg_list : list of bytes or Message objects
            The list of message parts of the form [HMAC,p_header,p_parent,
            p_content,buffer1,buffer2,...].
        content : bool (True)
            Whether to unpack the content dict (True), or leave it packed
            (False).
        copy : bool (True)
            True if the parts are bytes.  When False, the first four parts
            are Message objects and are replaced in ``msg_list`` (in place)
            by their ``.bytes``.

        Returns
        -------
        msg : dict
            The nested message dict with keys header, msg_id, msg_type,
            parent_header, content and buffers.

        Raises
        ------
        TypeError
            If ``msg_list`` has fewer than four parts.
        ValueError
            When ``self.auth`` is set and the signature is empty, has been
            seen before, or does not match the signed parts.
        """
        minlen = 4
        # Check the length before touching any element so that a truncated
        # message is reported as malformed instead of failing with an
        # IndexError in the code below.
        if len(msg_list) < minlen:
            raise TypeError(
                "malformed message, must have at least %i elements" % minlen)
        message = {}
        if not copy:
            for i in range(minlen):
                msg_list[i] = msg_list[i].bytes
        if self.auth is not None:
            signature = msg_list[0]
            if not signature:
                raise ValueError("Unsigned Message")
            # A signature that was already seen is rejected as a duplicate.
            if signature in self.digest_history:
                raise ValueError("Duplicate Signature: %r" % signature)
            self.digest_history.add(signature)
            # The signature covers header, parent header and content.
            check = self.sign(msg_list[1:4])
            if not signature == check:
                raise ValueError("Invalid Signature: %r" % signature)
        header = self.unpack(msg_list[1])
        message['header'] = header
        message['msg_id'] = header['msg_id']
        message['msg_type'] = header['msg_type']
        message['parent_header'] = self.unpack(msg_list[2])
        if content:
            message['content'] = self.unpack(msg_list[3])
        else:
            message['content'] = msg_list[3]

        message['buffers'] = msg_list[4:]
        return message
Esempio n. 7
0
class AliasManager(Configurable):
    """Define, look up and remove aliases.

    Aliases are ``Alias`` objects stored in the shell's table of line
    magics; entries of that table that are not ``Alias`` instances are
    ignored by this class.
    """

    default_aliases = List(default_aliases(), config=True)
    user_aliases = List(default_value=[], config=True)
    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')

    def __init__(self, shell=None, **kwargs):
        super(AliasManager, self).__init__(shell=shell, **kwargs)
        # Keep a direct handle on the shell's line-magic table.
        self.linemagics = self.shell.magics_manager.magics['line']
        self.init_aliases()

    def init_aliases(self):
        """Define the default aliases followed by the user aliases."""
        combined = self.default_aliases + self.user_aliases
        for name, cmd in combined:
            self.soft_define_alias(name, cmd)

    @property
    def aliases(self):
        """List of ``(name, cmd)`` pairs for every defined alias."""
        pairs = []
        for name, magic in self.linemagics.items():
            if isinstance(magic, Alias):
                pairs.append((name, magic.cmd))
        return pairs

    def soft_define_alias(self, name, cmd):
        """Define an alias; report an AliasError instead of raising it."""
        try:
            self.define_alias(name, cmd)
        except AliasError as err:
            error("Invalid alias: %s" % err)

    def define_alias(self, name, cmd):
        """Create an alias and register it as a line magic.

        An :exc:`AliasError` is raised if the alias fails validation.
        """
        alias = Alias(shell=self.shell, name=name, cmd=cmd)
        manager = self.shell.magics_manager
        manager.register_function(alias, magic_kind='line', magic_name=name)

    def get_alias(self, name):
        """Return the alias called `name`, or None if there is none."""
        candidate = self.linemagics.get(name, None)
        if isinstance(candidate, Alias):
            return candidate
        return None

    def is_alias(self, name):
        """Return True if `name` is currently defined as an alias."""
        found = self.get_alias(name)
        return found is not None

    def undefine_alias(self, name):
        """Remove the alias `name`; raise ValueError if it is not one."""
        if not self.is_alias(name):
            raise ValueError('%s is not an alias' % name)
        del self.linemagics[name]

    def clear_aliases(self):
        """Remove every alias."""
        for name, _cmd in self.aliases:
            self.undefine_alias(name)

    def retrieve_alias(self, name):
        """Return the command `name` expands to; ValueError if no alias."""
        alias = self.get_alias(name)
        if not alias:
            raise ValueError('%s is not an alias' % name)
        return alias.cmd
class MagicsManager(Configurable):
    """Object that handles all magic-related functionality for IPython.

    Callables are kept in ``magics`` (the dispatch table) and the objects
    that supplied them are kept in ``registry``.
    """
    # Non-configurable class attributes

    # A two-level dict, first keyed by magic type, then by magic function, and
    # holding the actual callable object as value.  This is the dict used for
    # magic function dispatch
    magics = Dict

    # A registry of the original objects that we've been given holding magics.
    registry = Dict

    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')

    auto_magic = Bool(
        True,
        config=True,
        help=
        "Automatically call line magics without requiring explicit % prefix")

    # Status messages, indexed by the boolean value of ``auto_magic``.
    _auto_status = [
        'Automagic is OFF, % prefix IS needed for line magics.',
        'Automagic is ON, % prefix IS NOT needed for line magics.'
    ]

    user_magics = Instance('IPython.core.magics.UserMagics')

    def __init__(self, shell=None, config=None, user_magics=None, **traits):

        super(MagicsManager, self).__init__(shell=shell,
                                            config=config,
                                            user_magics=user_magics,
                                            **traits)
        self.magics = dict(line={}, cell={})
        # Let's add the user_magics to the registry for uniformity, so *all*
        # registered magic containers can be found there.
        self.registry[user_magics.__class__.__name__] = user_magics

    def auto_status(self):
        """Return descriptive string with automagic status."""
        return self._auto_status[self.auto_magic]

    def lsmagic_info(self):
        """Return a list of dicts describing every registered magic.

        Each dict has the keys 'name', 'type' and 'class'.  'class' is the
        class name of the object the magic is bound to, or 'Other' when the
        callable has no ``__self__``.
        """
        magic_list = []
        for m_type in self.magics:
            for m_name, mgc in list(self.magics[m_type].items()):
                try:
                    magic_list.append({
                        'name': m_name,
                        'type': m_type,
                        'class': mgc.__self__.__class__.__name__
                    })
                except AttributeError:
                    magic_list.append({
                        'name': m_name,
                        'type': m_type,
                        'class': 'Other'
                    })
        return magic_list

    def lsmagic(self):
        """Return a dict of currently available magic functions.

        The return dict has the keys 'line' and 'cell', corresponding to the
        two types of magics we support.  Each value is a dict mapping magic
        names to the callables implementing them.  This is the dispatch
        table itself, not a copy.
        """
        return self.magics

    def lsmagic_docs(self, brief=False, missing=''):
        """Return dict of documentation of magic functions.

        The return dict has the keys 'line' and 'cell', corresponding to the
        two types of magics we support. Each value is a dict keyed by magic
        name whose value is the function docstring. If a docstring is
        unavailable, the value of `missing` is used instead.

        If brief is True, only the first line of each docstring will be returned.
        """
        docs = {}
        for m_type in self.magics:
            m_docs = {}
            for m_name, m_func in self.magics[m_type].items():
                if m_func.__doc__:
                    if brief:
                        m_docs[m_name] = m_func.__doc__.split('\n', 1)[0]
                    else:
                        m_docs[m_name] = m_func.__doc__.rstrip()
                else:
                    m_docs[m_name] = missing
            docs[m_type] = m_docs
        return docs

    def register(self, *magic_objects):
        """Register one or more instances of Magics.

        Take one or more classes or instances of classes that subclass the main
        `core.Magic` class, and register them with IPython to use the magic
        functions they provide.  The registration process will then ensure that
        any methods that have decorated to provide line and/or cell magics will
        be recognized with the `%x`/`%%x` syntax as a line/cell magic
        respectively.

        If classes are given, they will be instantiated with the default
        constructor.  If your classes need a custom constructor, you should
        instantiate them first and pass the instance.

        The provided arguments can be an arbitrary mix of classes and instances.

        Parameters
        ----------
        magic_objects : one or more classes or instances

        Raises
        ------
        ValueError
            If an object's ``registered`` attribute is false.
        """
        # Start by validating them to ensure they have all had their magic
        # methods registered at the instance level
        for m in magic_objects:
            if not m.registered:
                # Interpolate the offending object; a one-element tuple keeps
                # the formatting safe whatever type ``m`` is.
                raise ValueError("Class of magics %r was constructed without "
                                 "the @register_magics class decorator"
                                 % (m,))
            if type(m) in (type, MetaHasTraits):
                # If we're given an uninstantiated class
                m = m(shell=self.shell)

            # Now that we have an instance, we can register it and update the
            # table of callables
            self.registry[m.__class__.__name__] = m
            for mtype in magic_kinds:
                self.magics[mtype].update(m.magics[mtype])

    def register_function(self, func, magic_kind='line', magic_name=None):
        """Expose a standalone function as magic function for IPython.

        This will create an IPython magic (line, cell or both) from a
        standalone function.  The functions should have the following
        signatures:

        * For line magics: `def f(line)`
        * For cell magics: `def f(line, cell)`
        * For a function that does both: `def f(line, cell=None)`

        In the latter case, the function will be called with `cell==None` when
        invoked as `%f`, and with cell as a string when invoked as `%%f`.

        Parameters
        ----------
        func : callable
          Function to be registered as a magic.

        magic_kind : str
          Kind of magic, one of 'line', 'cell' or 'line_cell'

        magic_name : optional str
          If given, the name the magic will have in the IPython namespace.  By
          default, the name of the function itself is used.
        """

        # Create the new method in the user_magics and register it in the
        # global table
        validate_type(magic_kind)
        magic_name = func.__name__ if magic_name is None else magic_name
        setattr(self.user_magics, magic_name, func)
        record_magic(self.magics, magic_kind, magic_name, func)

    def define_magic(self, name, func):
        """[Deprecated] Expose own function as magic function for IPython.

        `func` is bound to ``self.user_magics`` as a method and recorded as
        a line magic called `name`.

        Example::

            def foo_impl(self, parameter_s=''):
                'My very own magic!. (Use docstrings, IPython reads them).'
                print 'Magic function. Passed parameter is between < >:'
                print '<%s>' % parameter_s
                print 'The self object is:', self

            ip.define_magic('foo',foo_impl)
        """
        meth = types.MethodType(func, self.user_magics)
        setattr(self.user_magics, name, meth)
        record_magic(self.magics, 'line', name, meth)
Esempio n. 9
0
class FrontendWidget(HistoryConsoleWidget, BaseFrontendMixin):
    """ A Qt frontend for a generic Python kernel.
    """

    # The text to show when the kernel is (re)started.
    banner = Unicode()

    # An option and corresponding signal for overriding the default kernel
    # interrupt behavior.
    custom_interrupt = Bool(False)
    custom_interrupt_requested = QtCore.Signal()

    # An option and corresponding signals for overriding the default kernel
    # restart behavior.
    custom_restart = Bool(False)
    custom_restart_kernel_died = QtCore.Signal(float)
    custom_restart_requested = QtCore.Signal()

    # Whether to automatically show calltips on open-parentheses.
    enable_calltips = Bool(
        True,
        config=True,
        help="Whether to draw information calltips on open-parentheses.")

    clear_on_kernel_restart = Bool(
        True,
        config=True,
        help="Whether to clear the console when the kernel is restarted")

    confirm_restart = Bool(
        True,
        config=True,
        help="Whether to ask for user confirmation when restarting kernel")

    # Emitted when a user visible 'execute_request' has been submitted to the
    # kernel from the FrontendWidget. Contains the code to be executed.
    executing = QtCore.Signal(object)

    # Emitted when a user-visible 'execute_reply' has been received from the
    # kernel and processed by the FrontendWidget. Contains the response message.
    executed = QtCore.Signal(object)

    # Emitted when an exit request has been received from the kernel.
    exit_requested = QtCore.Signal(object)

    # Protected class variables.
    _transform_prompt = staticmethod(transform_classic_prompt)
    _CallTipRequest = namedtuple('_CallTipRequest', ['id', 'pos'])
    _CompletionRequest = namedtuple('_CompletionRequest', ['id', 'pos'])
    _ExecutionRequest = namedtuple('_ExecutionRequest', ['id', 'kind'])
    _input_splitter_class = InputSplitter
    _local_kernel = False
    _highlighter = Instance(FrontendHighlighter)

    #---------------------------------------------------------------------------
    # 'object' interface
    #---------------------------------------------------------------------------

    def __init__(self, *args, **kw):
        """Set up the widget's helpers, the raw-copy action and signals.

        All arguments are forwarded to the parent constructor.  The keyword
        'local_kernel', when present, sets ``_local_kernel``.
        """
        super(FrontendWidget, self).__init__(*args, **kw)
        # FIXME: remove this when PySide min version is updated past 1.0.7
        # forcefully disable calltips if PySide is < 1.0.7, because they crash
        if qt.QT_API == qt.QT_API_PYSIDE:
            import PySide
            pyside_too_old = PySide.__version_info__ < (1, 0, 7)
            if pyside_too_old:
                self.log.warn(
                    "PySide %s < 1.0.7 detected, disabling calltips" %
                    PySide.__version__)
                self.enable_calltips = False

        # FrontendWidget protected variables.
        control = self._control
        self._bracket_matcher = BracketMatcher(control)
        self._call_tip_widget = CallTipWidget(control)
        self._completion_lexer = CompletionLexer(PythonLexer())
        self._copy_raw_action = QtGui.QAction('Copy (Raw Text)', None)
        self._hidden = False
        self._highlighter = FrontendHighlighter(self)
        self._input_splitter = self._input_splitter_class(input_mode='cell')
        self._kernel_manager = None
        # Outstanding requests; execute requests are keyed by msg_id.
        self._request_info = {'execute': {}}
        self._callback_dict = {}

        # Configure the ConsoleWidget.
        self.tab_width = 4
        self._set_continuation_prompt('... ')

        # Configure the CallTipWidget: start with the current font and follow
        # later font changes.
        call_tips = self._call_tip_widget
        call_tips.setFont(self.font)
        self.font_changed.connect(call_tips.setFont)

        # Configure actions: Ctrl+Shift+C triggers copy_raw, and the action is
        # only enabled while copy_available reports True.
        raw_copy = self._copy_raw_action
        shortcut = QtCore.Qt.CTRL | QtCore.Qt.SHIFT | QtCore.Qt.Key_C
        raw_copy.setEnabled(False)
        raw_copy.setShortcut(QtGui.QKeySequence(shortcut))
        raw_copy.setShortcutContext(QtCore.Qt.WidgetWithChildrenShortcut)
        raw_copy.triggered.connect(self.copy_raw)
        self.copy_available.connect(raw_copy.setEnabled)
        self.addAction(raw_copy)

        # Connect signal handlers.
        self._control.document().contentsChange.connect(
            self._document_contents_change)

        # Set flag for whether we are connected via localhost.
        self._local_kernel = kw.get('local_kernel',
                                    FrontendWidget._local_kernel)

    #---------------------------------------------------------------------------
    # 'ConsoleWidget' public interface
    #---------------------------------------------------------------------------

    def copy(self):
        """ Copy the selected text to the clipboard with prompts removed.

        If the page control has focus the copy is delegated to it.  If
        neither control has focus, only a debug message is logged.
        """
        if self._page_control is not None and self._page_control.hasFocus():
            self._page_control.copy()
            return

        if not self._control.hasFocus():
            self.log.debug("frontend widget : unknown copy target")
            return

        selected = self._control.textCursor().selection().toPlainText()
        if selected:
            # Run every selected line through the prompt transformer.
            cleaned = map(self._transform_prompt, selected.splitlines())
            selected = '\n'.join(cleaned)
            QtGui.QApplication.clipboard().setText(selected)

    #---------------------------------------------------------------------------
    # 'ConsoleWidget' abstract interface
    #---------------------------------------------------------------------------

    def _is_complete(self, source, interactive):
        """ Returns whether 'source' can be completely processed and a new
            prompt created. When triggered by an Enter/Return key press,
            'interactive' is True; otherwise, it is False.
        """
        complete = self._input_splitter.push(source)
        if interactive:
            complete = not self._input_splitter.push_accepts_more()
        return complete

    def _execute(self, source, hidden):
        """ Execute 'source'. If 'hidden', do not show any output.

        See parent class :meth:`execute` docstring for full details.
        """
        msg_id = self.kernel_manager.shell_channel.execute(source, hidden)
        self._request_info['execute'][msg_id] = self._ExecutionRequest(
            msg_id, 'user')
        self._hidden = hidden
        if not hidden:
            self.executing.emit(source)

    def _prompt_started_hook(self):
        """ Called immediately after a new prompt is displayed.
        """
        if not self._reading:
            self._highlighter.highlighting_on = True

    def _prompt_finished_hook(self):
        """ Called immediately after a prompt is finished, i.e. when some input
            will be processed and a new prompt displayed.
        """
        # Flush all state from the input splitter so the next round of
        # reading input starts with a clean buffer.
        self._input_splitter.reset()

        if not self._reading:
            self._highlighter.highlighting_on = False

    def _tab_pressed(self):
        """ Called when the tab key is pressed. Returns whether to continue
            processing the event.
        """
        # Perform tab completion if:
        # 1) The cursor is in the input buffer.
        # 2) There is a non-whitespace character before the cursor.
        text = self._get_input_buffer_cursor_line()
        if text is None:
            return False
        complete = bool(text[:self._get_input_buffer_cursor_column()].strip())
        if complete:
            self._complete()
        return not complete

    #---------------------------------------------------------------------------
    # 'ConsoleWidget' protected interface
    #---------------------------------------------------------------------------

    def _context_menu_make(self, pos):
        """ Reimplemented to add the raw-copy action to the context menu.

        The action is inserted before the first menu entry whose shortcut
        exactly matches the standard Paste key sequence; if there is no such
        entry the menu is returned unchanged.
        """
        menu = super(FrontendWidget, self)._context_menu_make(pos)
        for entry in menu.actions():
            match = entry.shortcut().matches(QtGui.QKeySequence.Paste)
            if match == QtGui.QKeySequence.ExactMatch:
                menu.insertAction(entry, self._copy_raw_action)
                break
        return menu

    def request_interrupt_kernel(self):
        """ Interrupt the kernel, but only while something is executing.
        """
        if not self._executing:
            return
        self.interrupt_kernel()

    def request_restart_kernel(self):
        """ Call restart_kernel with the standard confirmation message and
            now=False.
        """
        self.restart_kernel('Are you sure you want to restart the kernel?',
                            now=False)

    def _event_filter_console_keypress(self, event):
        """ Reimplemented for execution interruption and smart backspace.

        With the control key down, C requests a kernel interrupt (only while
        executing) and Period requests a kernel restart.  Without control or
        Alt, Backspace may delete four spaces at once.  Every key press not
        consumed here is passed to the parent implementation.
        """
        key = event.key()
        if self._control_key_down(event.modifiers(), include_command=False):

            if key == QtCore.Qt.Key_C and self._executing:
                self.request_interrupt_kernel()
                return True

            if key == QtCore.Qt.Key_Period:
                self.request_restart_kernel()
                return True

        elif (not (event.modifiers() & QtCore.Qt.AltModifier)
              and key == QtCore.Qt.Key_Backspace):

            # Smart backspace: remove four characters in one backspace if:
            # 1) everything left of the cursor is whitespace
            # 2) the four characters immediately left of the cursor are spaces
            column = self._get_input_buffer_cursor_column()
            cursor = self._control.textCursor()
            if column > 3 and not cursor.hasSelection():
                left_text = self._get_input_buffer_cursor_line()[:column]
                if left_text.endswith('    ') and not left_text.strip():
                    cursor.movePosition(QtGui.QTextCursor.Left,
                                        QtGui.QTextCursor.KeepAnchor, 4)
                    cursor.removeSelectedText()
                    return True

        parent = super(FrontendWidget, self)
        return parent._event_filter_console_keypress(event)

    def _insert_continuation_prompt(self, cursor):
        """ Reimplemented for auto-indentation: after the parent inserts the
            continuation prompt, add as many spaces as the input splitter's
            ``indent_spaces``.
        """
        super(FrontendWidget, self)._insert_continuation_prompt(cursor)
        indentation = ' ' * self._input_splitter.indent_spaces
        cursor.insertText(indentation)

    #---------------------------------------------------------------------------
    # 'BaseFrontendMixin' abstract interface
    #---------------------------------------------------------------------------

    def _handle_complete_reply(self, rep):
        """ Handle replies for tab completion.

        The reply is acted on only if it answers the recorded 'complete'
        request and the cursor is still at the position of that request.
        """
        self.log.debug("complete: %s", rep.get('content', ''))
        cursor = self._get_cursor()
        info = self._request_info.get('complete')
        is_current = (info
                      and info.id == rep['parent_header']['msg_id']
                      and info.pos == cursor.position())
        if not is_current:
            return

        # Move left over the dotted context so that the completion items
        # are applied from its start.
        context_text = '.'.join(self._get_context())
        cursor.movePosition(QtGui.QTextCursor.Left, n=len(context_text))
        self._complete_with_items(cursor, rep['content']['matches'])

    def _silent_exec_callback(self, expr, callback):
        """Silently execute `expr` in the kernel and call `callback` with reply

        the `expr` is evaluated silently in the kernel (without) output in
        the frontend. Call `callback` with the
        `repr <http://docs.python.org/library/functions.html#repr> `_ as first argument

        Parameters
        ----------
        expr : string
            valid string to be executed by the kernel.
        callback : function
            function accepting one argument, as a string. The string will be
            the `repr` of the result of evaluating `expr`

        The `callback` is called with the `repr()` of the result of `expr` as
        first argument. To get the object, do `eval()` on the passed value.

        See Also
        --------
        _handle_exec_callback : private method, deal with calling callback with reply

        """

        # generate uuid, which would be used as an indication of whether or
        # not the unique request originated from here (can use msg id ?)
        local_uuid = str(uuid.uuid1())
        msg_id = self.kernel_manager.shell_channel.execute(
            '', silent=True, user_expressions={local_uuid: expr})
        self._callback_dict[local_uuid] = callback
        self._request_info['execute'][msg_id] = self._ExecutionRequest(
            msg_id, 'silent_exec_callback')

    def _handle_exec_callback(self, msg):
        """Execute `callback` corresponding to `msg` reply, after ``_silent_exec_callback``

        Parameters
        ----------
        msg : raw message send by the kernel containing an `user_expressions`
                and having a 'silent_exec_callback' kind.

        Notes
        -----
        This function will look for a `callback` associated with the
        corresponding message id. Association has been made by
        `_silent_exec_callback`. `callback` is then called with the `repr()`
        of the value of corresponding `user_expressions` as argument.
        `callback` is then removed from the known list so that any message
        coming again with the same id won't trigger it.

        """

        user_exp = msg['content'].get('user_expressions')
        if not user_exp:
            return
        for expression in user_exp:
            if expression in self._callback_dict:
                self._callback_dict.pop(expression)(user_exp[expression])

    def _handle_execute_reply(self, msg):
        """ Handles replies for code execution.
        """
        self.log.debug("execute: %s", msg.get('content', ''))
        msg_id = msg['parent_header']['msg_id']
        info = self._request_info['execute'].get(msg_id)
        # unset reading flag, because if execute finished, raw_input can't
        # still be pending.
        self._reading = False
        if info and info.kind == 'user' and not self._hidden:
            # Make sure that all output from the SUB channel has been processed
            # before writing a new prompt.
            self.kernel_manager.sub_channel.flush()

            # Reset the ANSI style information to prevent bad text in stdout
            # from messing up our colors. We're not a true terminal so we're
            # allowed to do this.
            if self.ansi_codes:
                self._ansi_processor.reset_sgr()

            content = msg['content']
            status = content['status']
            if status == 'ok':
                self._process_execute_ok(msg)
            elif status == 'error':
                self._process_execute_error(msg)
            elif status == 'aborted':
                self._process_execute_abort(msg)

            self._show_interpreter_prompt_for_reply(msg)
            self.executed.emit(msg)
            self._request_info['execute'].pop(msg_id)
        elif info and info.kind == 'silent_exec_callback' and not self._hidden:
            self._handle_exec_callback(msg)
            self._request_info['execute'].pop(msg_id)
        else:
            super(FrontendWidget, self)._handle_execute_reply(msg)

    def _handle_input_request(self, msg):
        """ Handle requests for raw_input.
        """
        self.log.debug("input: %s", msg.get('content', ''))
        if self._hidden:
            raise RuntimeError(
                'Request for raw input during hidden execution.')

        # Make sure that all output from the SUB channel has been processed
        # before entering readline mode.
        self.kernel_manager.sub_channel.flush()

        def callback(line):
            self.kernel_manager.stdin_channel.input(line)

        if self._reading:
            self.log.debug(
                "Got second input request, assuming first was interrupted.")
            self._reading = False
        self._readline(msg['content']['prompt'], callback=callback)

    def _handle_kernel_died(self, since_last_heartbeat):
        """ Handle the kernel's death by asking if the user wants to restart.
        """
        self.log.debug("kernel died: %s", since_last_heartbeat)
        if self.custom_restart:
            self.custom_restart_kernel_died.emit(since_last_heartbeat)
        else:
            message = 'The kernel heartbeat has been inactive for %.2f ' \
                'seconds. Do you want to restart the kernel? You may ' \
                'first want to check the network connection.' % \
                since_last_heartbeat
            self.restart_kernel(message, now=True)

    def _handle_object_info_reply(self, rep):
        """ Handle replies for call tips.
        """
        self.log.debug("oinfo: %s", rep.get('content', ''))
        cursor = self._get_cursor()
        info = self._request_info.get('call_tip')
        if info and info.id == rep['parent_header']['msg_id'] and \
                info.pos == cursor.position():
            # Get the information for a call tip.  For now we format the call
            # line as string, later we can pass False to format_call and
            # syntax-highlight it ourselves for nicer formatting in the
            # calltip.
            content = rep['content']
            # if this is from pykernel, 'docstring' will be the only key
            if content.get('ismagic', False):
                # Don't generate a call-tip for magics. Ideally, we should
                # generate a tooltip, but not on ( like we do for actual
                # callables.
                call_info, doc = None, None
            else:
                call_info, doc = call_tip(content, format_call=True)
            if call_info or doc:
                self._call_tip_widget.show_call_info(call_info, doc)

    def _handle_pyout(self, msg):
        """ Handle display hook output.
        """
        self.log.debug("pyout: %s", msg.get('content', ''))
        if not self._hidden and self._is_from_this_session(msg):
            text = msg['content']['data']
            self._append_plain_text(text + '\n', before_prompt=True)

    def _handle_stream(self, msg):
        """ Handle stdout, stderr, and stdin.
        """
        self.log.debug("stream: %s", msg.get('content', ''))
        if not self._hidden and self._is_from_this_session(msg):
            # Most consoles treat tabs as being 8 space characters. Convert tabs
            # to spaces so that output looks as expected regardless of this
            # widget's tab width.
            text = msg['content']['data'].expandtabs(8)

            self._append_plain_text(text, before_prompt=True)
            self._control.moveCursor(QtGui.QTextCursor.End)

    def _handle_shutdown_reply(self, msg):
        """ Handle shutdown signal, only if from other console.
        """
        self.log.debug("shutdown: %s", msg.get('content', ''))
        if not self._hidden and not self._is_from_this_session(msg):
            if self._local_kernel:
                if not msg['content']['restart']:
                    self.exit_requested.emit(self)
                else:
                    # we just got notified of a restart!
                    time.sleep(0.25)  # wait 1/4 sec to reset
                    # lest the request for a new prompt
                    # goes to the old kernel
                    self.reset()
            else:  # remote kernel, prompt on Kernel shutdown/reset
                title = self.window().windowTitle()
                if not msg['content']['restart']:
                    reply = QtGui.QMessageBox.question(
                        self, title, "Kernel has been shutdown permanently. "
                        "Close the Console?", QtGui.QMessageBox.Yes,
                        QtGui.QMessageBox.No)
                    if reply == QtGui.QMessageBox.Yes:
                        self.exit_requested.emit(self)
                else:
                    # XXX: remove message box in favor of using the
                    # clear_on_kernel_restart setting?
                    reply = QtGui.QMessageBox.question(
                        self, title,
                        "Kernel has been reset. Clear the Console?",
                        QtGui.QMessageBox.Yes, QtGui.QMessageBox.No)
                    if reply == QtGui.QMessageBox.Yes:
                        time.sleep(0.25)  # wait 1/4 sec to reset
                        # lest the request for a new prompt
                        # goes to the old kernel
                        self.reset()

    def _started_channels(self):
        """ Called when the KernelManager channels have started listening or
            when the frontend is assigned an already listening KernelManager.
        """
        self.reset(clear=True)

    #---------------------------------------------------------------------------
    # 'FrontendWidget' public interface
    #---------------------------------------------------------------------------

    def copy_raw(self):
        """ Copy the current selection to the clipboard as-is.

            The copy is delegated directly to the underlying control, so no
            attempt is made to remove prompts or otherwise alter the text.
        """
        self._control.copy()

    def execute_file(self, path, hidden=False):
        """ Attempts to execute file with 'path'. If 'hidden', no output is
            shown.

            The file is run by executing an ``execfile(<repr of path>)``
            statement through ``self.execute``.
        """
        command = 'execfile(%r)' % path
        self.execute(command, hidden=hidden)

    def interrupt_kernel(self):
        """ Attempts to interrupt the running kernel.

        With ``custom_interrupt`` set, only the ``custom_interrupt_requested``
        signal is emitted; otherwise the kernel manager is asked to interrupt
        the kernel it owns. If there is no such kernel, a notice is printed
        instead.

        In both interrupting cases the _reading flag is unset, to avoid
        runtime errors if raw_input is called again.
        """
        if self.custom_interrupt:
            self._reading = False
            self.custom_interrupt_requested.emit()
            return

        if self.kernel_manager.has_kernel:
            self._reading = False
            self.kernel_manager.interrupt_kernel()
            return

        self._append_plain_text('Kernel process is either remote or '
                                'unspecified. Cannot interrupt.\n')

    def reset(self, clear=False):
        """ Resets the widget after a kernel (re)start.

        If ``clear`` or the ``clear_on_kernel_restart`` configuration setting
        is True, the widget is cleared and the banner re-written (similar to
        ``%clear``). Otherwise the traces of the previous kernel are kept and
        only a visual indication of the restart is printed. In both cases a
        pending execution is aborted, the reading flag is unset and a fresh
        prompt is shown.
        """
        if self._executing:
            self._executing = False
            self._request_info['execute'] = {}
        self._reading = False
        self._highlighter.highlighting_on = False

        wipe = self.clear_on_kernel_restart or clear
        if not wipe:
            self._append_plain_text("# restarting kernel...")
            self._append_html("<hr><br>")
            # XXX: Reprinting the full banner may be too much, but once #1680
            # is addressed, that will mitigate it.
            #self._append_plain_text(self.banner)
        else:
            self._control.clear()
            self._append_plain_text(self.banner)

        # Move the output marker for stdout/stderr, so that startup messages
        # appear after the banner:
        self._append_before_prompt_pos = self._get_cursor().position()
        self._show_interpreter_prompt()

    def restart_kernel(self, message, now=False):
        """ Attempts to restart the running kernel.

        ``message`` is the question shown in the confirmation dialog (only
        used when ``confirm_restart`` is set); ``now`` is passed through to
        the kernel manager's ``restart_kernel``.
        """
        # FIXME: now should be configurable via a checkbox in the dialog.  Right
        # now at least the heartbeat path sets it to True and the manual restart
        # to False.  But those should just be the pre-selected states of a
        # checkbox that the user could override if so desired.  But I don't know
        # enough Qt to go implementing the checkbox now.

        if self.custom_restart:
            self.custom_restart_requested.emit()
            return

        if not self.kernel_manager.has_kernel:
            self._append_plain_text(
                'Kernel process is either remote or '
                'unspecified. Cannot restart.\n',
                before_prompt=True)
            return

        # The heart beat channel is paused to prevent further warnings.
        self.kernel_manager.hb_channel.pause()

        # Unless confirm_restart is off, ask the user whether to restart.
        do_restart = True
        if self.confirm_restart:
            answer = QtGui.QMessageBox.question(
                self, 'Restart kernel?', message,
                QtGui.QMessageBox.Yes | QtGui.QMessageBox.No)
            do_restart = answer == QtGui.QMessageBox.Yes

        if not do_restart:
            # The user declined, so the heartbeat is un-paused here. (On
            # acceptance it is un-paused automatically when the kernel is
            # restarted.)
            self.kernel_manager.hb_channel.unpause()
            return

        try:
            self.kernel_manager.restart_kernel(now=now)
        except RuntimeError:
            self._append_plain_text(
                'Kernel started externally. '
                'Cannot restart.\n',
                before_prompt=True)
            return
        self.reset()

    #---------------------------------------------------------------------------
    # 'FrontendWidget' protected interface
    #---------------------------------------------------------------------------

    def _call_tip(self):
        """ Shows a call tip, if appropriate, at the current cursor location.
        """
        # Decide if it makes sense to show a call tip
        if not self.enable_calltips:
            return False
        cursor = self._get_cursor()
        cursor.movePosition(QtGui.QTextCursor.Left)
        if cursor.document().characterAt(cursor.position()) != '(':
            return False
        context = self._get_context(cursor)
        if not context:
            return False

        # Send the metadata request to the kernel
        name = '.'.join(context)
        msg_id = self.kernel_manager.shell_channel.object_info(name)
        pos = self._get_cursor().position()
        self._request_info['call_tip'] = self._CallTipRequest(msg_id, pos)
        return True

    def _complete(self):
        """ Performs completion at the current cursor location.
        """
        context = self._get_context()
        if context:
            # Send the completion request to the kernel
            msg_id = self.kernel_manager.shell_channel.complete(
                '.'.join(context),  # text
                self._get_input_buffer_cursor_line(),  # line
                self._get_input_buffer_cursor_column(),  # cursor_pos
                self.input_buffer)  # block
            pos = self._get_cursor().position()
            info = self._CompletionRequest(msg_id, pos)
            self._request_info['complete'] = info

    def _get_context(self, cursor=None):
        """ Gets the context for the specified cursor (or the current cursor
            if none is specified).

            The text selected between the cursor and the start of its block
            is passed to the completion lexer, whose result is returned.
        """
        target = self._get_cursor() if cursor is None else cursor
        target.movePosition(QtGui.QTextCursor.StartOfBlock,
                            QtGui.QTextCursor.KeepAnchor)
        selected_text = target.selection().toPlainText()
        return self._completion_lexer.get_context(selected_text)

    def _process_execute_abort(self, msg):
        """ Process a reply for an aborted execution request.
        """
        self._append_plain_text("ERROR: execution aborted\n")

    def _process_execute_error(self, msg):
        """ Process a reply for an execution request that resulted in an error.
        """
        content = msg['content']
        # If a SystemExit is passed along, this means exit() was called - also
        # all the ipython %exit magic syntax of '-k' to be used to keep
        # the kernel running
        if content['ename'] == 'SystemExit':
            keepkernel = content['evalue'] == '-k' or content[
                'evalue'] == 'True'
            self._keep_kernel_on_exit = keepkernel
            self.exit_requested.emit(self)
        else:
            traceback = ''.join(content['traceback'])
            self._append_plain_text(traceback)

    def _process_execute_ok(self, msg):
        """ Process a reply for a successful execution request.
        """
        payload = msg['content']['payload']
        for item in payload:
            if not self._process_execute_payload(item):
                warning = 'Warning: received unknown payload of type %s'
                print(warning % repr(item['source']))

    def _process_execute_payload(self, item):
        """ Process a single payload item from the list of payload items in an
            execution reply. Returns whether the payload was handled.
        """
        # The basic FrontendWidget doesn't handle payloads, as they are a
        # mechanism for going beyond the standard Python interpreter model.
        return False

    def _show_interpreter_prompt(self):
        """ Shows a prompt for the interpreter.
        """
        self._show_prompt('>>> ')

    def _show_interpreter_prompt_for_reply(self, msg):
        """ Shows a prompt for the interpreter given an 'execute_reply' message.
        """
        self._show_interpreter_prompt()

    #------ Signal handlers ----------------------------------------------------

    def _document_contents_change(self, position, removed, added):
        """ Called whenever the document's content changes. Display a call tip
            if appropriate.
        """
        # Calculate where the cursor should be *after* the change:
        position += added

        document = self._control.document()
        if position == self._get_cursor().position():
            self._call_tip()

    #------ Trait default initializers -----------------------------------------

    def _banner_default(self):
        """ Returns the standard Python banner.
        """
        banner = 'Python %s on %s\nType "help", "copyright", "credits" or ' \
            '"license" for more information.'
        return banner % (sys.version, sys.platform)
Esempio n. 10
0
class BuiltinTrap(Configurable):
    """Context manager that installs IPython's names in the builtin module.

    Entering the (re-entrant) context adds or hides the names listed in
    ``auto_builtins``; leaving the outermost context puts back whatever was
    there before. Additional names can be added while the trap is active via
    :meth:`add_builtin`.
    """

    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')

    def __init__(self, shell=None):
        super(BuiltinTrap, self).__init__(shell=shell, config=None)
        # Maps each builtin name touched by this trap to the value it had
        # before the trap touched it (BuiltinUndefined if it did not exist).
        self._orig_builtins = {}
        # We define this to track if a single BuiltinTrap is nested.
        # Only turn off the trap when the outermost call to __exit__ is made.
        self._nested_level = 0
        self.shell = shell
        # builtins we always add - if set to HideBuiltin, they will just
        # be removed instead of being replaced by something else
        self.auto_builtins = {
            'exit': HideBuiltin,
            'quit': HideBuiltin,
            'get_ipython': self.shell.get_ipython,
        }
        # Recursive reload function
        try:
            from IPython.lib import deepreload
            if self.shell.deep_reload:
                self.auto_builtins['reload'] = deepreload.reload
            else:
                self.auto_builtins['dreload'] = deepreload.reload
        except ImportError:
            pass

    def __enter__(self):
        """Activate the trap on the outermost entry and return ``self``."""
        if self._nested_level == 0:
            self.activate()
        self._nested_level += 1
        # I return self, so callers can use add_builtin in a with clause.
        return self

    def __exit__(self, type, value, traceback):
        """Deactivate the trap when the outermost context is left."""
        if self._nested_level == 1:
            self.deactivate()
        self._nested_level -= 1
        # Returning False will cause exceptions to propagate
        return False

    def add_builtin(self, key, value):
        """Add a builtin and save the original.

        If ``value`` is ``HideBuiltin`` the name is removed from the builtin
        module (when present) instead of being replaced.

        The original is recorded only the first time a given ``key`` is
        touched. Recording it again would overwrite the true original with a
        value this trap installed itself (or with ``BuiltinUndefined`` after
        a hide), and ``deactivate`` would then restore the wrong thing.
        """
        bdict = builtin_mod.__dict__
        orig = bdict.get(key, BuiltinUndefined)
        if value is HideBuiltin:
            if orig is not BuiltinUndefined:  #same as 'key in bdict'
                self._orig_builtins.setdefault(key, orig)
                del bdict[key]
        else:
            self._orig_builtins.setdefault(key, orig)
            bdict[key] = value

    def remove_builtin(self, key, orig):
        """Remove an added builtin and re-set the original.

        ``orig`` is the value to put back, or ``BuiltinUndefined`` when the
        name did not exist before and has to be deleted. Deleting a name that
        is already gone is not an error, so that one missing name cannot
        abort the restoration of the remaining ones.
        """
        if orig is BuiltinUndefined:
            builtin_mod.__dict__.pop(key, None)
        else:
            builtin_mod.__dict__[key] = orig

    def activate(self):
        """Store ipython references in the __builtin__ namespace."""

        add_builtin = self.add_builtin
        for name, func in iteritems(self.auto_builtins):
            add_builtin(name, func)

    def deactivate(self):
        """Remove any builtins which might have been added by add_builtins, or
        restore overwritten ones to their previous values."""
        remove_builtin = self.remove_builtin
        for key, val in iteritems(self._orig_builtins):
            remove_builtin(key, val)
        self._orig_builtins.clear()
        # NOTE(review): this flag is not read anywhere in this class; it is
        # kept only in case code elsewhere inspects it.
        self._builtins_added = False
Esempio n. 11
0
class Widget(LoggingConfigurable):
    """Kernel-side widget whose synced traits are mirrored to a front-end
    model over a Comm.

    Traits declared with ``sync=True`` are sent to the front-end when they
    change, and state received from the front-end is applied back to them.
    """
    #-------------------------------------------------------------------------
    # Class attributes
    #-------------------------------------------------------------------------
    _widget_construction_callback = None
    # Registry of open widgets, keyed by model id.
    widgets = {}

    @staticmethod
    def on_widget_constructed(callback):
        """Registers a callback to be called when a widget is constructed.

        The callback must have the following signature:
        callback(widget)"""
        Widget._widget_construction_callback = callback

    @staticmethod
    def _call_widget_constructed(widget):
        """Static method, called when a widget is constructed.

        Invokes the registered construction callback, if there is a callable
        one, with ``widget``."""
        hook = Widget._widget_construction_callback
        if hook is not None and callable(hook):
            hook(widget)

    #-------------------------------------------------------------------------
    # Traits
    #-------------------------------------------------------------------------
    _model_name = Unicode(
        'WidgetModel',
        help=("Name of the backbone model \n"
              "        registered in the front-end to create and sync this "
              "widget with."))
    _view_module = Unicode(
        help=("A requirejs module in which to find _view_name.\n"
              "        If empty, look in the global registry."),
        sync=True)
    _view_name = Unicode(
        None,
        allow_none=True,
        help=("Default view registered in the front-end\n"
              "        to use to represent the widget."),
        sync=True)
    comm = Instance('IPython.kernel.comm.Comm')

    msg_throttle = Int(
        3,
        sync=True,
        help=("Maximum number of msgs the \n"
              "        front-end can send before receiving an idle msg from "
              "the back-end."))

    version = Int(0, sync=True, help="Widget's version")
    keys = List()

    def _keys_default(self):
        """Default for ``keys``: the names of all traits tagged ``sync=True``."""
        return list(self.traits(sync=True))

    _property_lock = Tuple((None, None))
    _send_state_lock = Int(0)
    _states_to_send = Set(allow_none=False)
    _display_callbacks = Instance(CallbackDispatcher, ())
    _msg_callbacks = Instance(CallbackDispatcher, ())

    #-------------------------------------------------------------------------
    # (Con/de)structor
    #-------------------------------------------------------------------------
    def __init__(self, **kwargs):
        """Public constructor

        An optional ``model_id`` keyword is consumed here and later used as
        the comm id; all other keywords go to the parent constructor. The
        construction callback is then fired and the comm opened."""
        self._model_id = kwargs.pop('model_id', None)
        super(Widget, self).__init__(**kwargs)

        Widget._call_widget_constructed(self)
        self.open()

    def __del__(self):
        """Object disposal: closes the widget."""
        self.close()

    #-------------------------------------------------------------------------
    # Properties
    #-------------------------------------------------------------------------

    def open(self):
        """Open a comm to the frontend if one isn't already open.

        The new comm is hooked up to ``_handle_msg``, the widget is added to
        the ``Widget.widgets`` registry and the full state is sent once."""
        if self.comm is not None:
            return

        comm_args = {
            'target_name': 'ipython.widget',
            'data': {'model_name': self._model_name},
        }
        if self._model_id is not None:
            comm_args['comm_id'] = self._model_id
        self.comm = Comm(**comm_args)
        self._model_id = self.model_id

        self.comm.on_msg(self._handle_msg)
        Widget.widgets[self.model_id] = self

        # first update
        self.send_state()

    @property
    def model_id(self):
        """Gets the model id of this widget, i.e. the id of its comm.

        The comm must already be open; no comm is created by this property."""
        return self.comm.comm_id

    #-------------------------------------------------------------------------
    # Methods
    #-------------------------------------------------------------------------

    def close(self):
        """Close method.

        Unregisters the widget and closes the underlying comm; does nothing
        when there is no comm.
        When the comm is closed, all of the widget views are automatically
        removed from the front-end."""
        if self.comm is None:
            return
        Widget.widgets.pop(self.model_id, None)
        self.comm.close()
        self.comm = None

    def send_state(self, key=None):
        """Sends the widget state, or a piece of it, to the front-end.

        Parameters
        ----------
        key : unicode, or iterable (optional)
            A single property's name or iterable of property names to sync with the front-end.
        """
        state = self.get_state(key=key)
        self._send({"method": "update", "state": state})

    def get_state(self, key=None):
        """Gets the widget state, or a piece of it.

        Parameters
        ----------
        key : unicode or iterable (optional)
            A single property's name or iterable of property names to get.

        Returns a dict mapping each requested name to its serialized value.
        Raises ValueError if ``key`` is none of the accepted kinds.
        """
        if key is None:
            names = self.keys
        elif isinstance(key, string_types):
            names = [key]
        elif isinstance(key, collections.Iterable):
            names = key
        else:
            raise ValueError(
                "key must be a string, an iterable of keys, or None")

        def serialized(name):
            # A trait may provide its own serializer through 'to_json'
            # metadata; otherwise the generic one is used.
            to_json = self.trait_metadata(name, 'to_json', self._trait_to_json)
            return to_json(getattr(self, name))

        return {name: serialized(name) for name in names}

    def set_state(self, sync_data):
        """Called when a state is received from the front-end.

        Only synced keys present in ``sync_data`` are applied; each one is
        set under the property lock so the same value is not echoed back."""
        for name in self.keys:
            if name not in sync_data:
                continue
            received = sync_data[name]
            from_json = self.trait_metadata(name, 'from_json',
                                            self._trait_from_json)
            with self._lock_property(name, received):
                setattr(self, name, from_json(received))

    def send(self, content):
        """Sends a custom msg to the widget model in the front-end.

        Parameters
        ----------
        content : dict
            Content of the message to send.
        """
        self._send({"method": "custom", "content": content})

    def on_msg(self, callback, remove=False):
        """(Un)Register a custom msg receive callback.

        Parameters
        ----------
        callback: callable
            callback will be passed two arguments when a message arrives::

                callback(widget, content)

        remove: bool
            True if the callback should be unregistered."""
        self._msg_callbacks.register_callback(callback, remove=remove)

    def on_displayed(self, callback, remove=False):
        """(Un)Register a widget displayed callback.

        Parameters
        ----------
        callback: method handler
            Must have a signature of::

                callback(widget, **kwargs)

            kwargs from display are passed through without modification.
        remove: bool
            True if the callback should be unregistered."""
        self._display_callbacks.register_callback(callback, remove=remove)

    #-------------------------------------------------------------------------
    # Support methods
    #-------------------------------------------------------------------------
    @contextmanager
    def _lock_property(self, key, value):
        """Lock a property-value pair.

        The value should be the JSON state of the property.

        NOTE: This, in addition to the single lock for all state changes, is
        flawed.  In the future we may want to look into buffering state changes
        back to the front-end."""
        self._property_lock = (key, value)
        try:
            yield
        finally:
            # Always release the lock, even if the body raised.
            self._property_lock = (None, None)

    @contextmanager
    def hold_sync(self):
        """Hold syncing any state until the context manager is released"""
        # A counter is used so that this can be nested.  Syncing happens once
        # all levels have been released.
        self._send_state_lock += 1
        try:
            yield
        finally:
            self._send_state_lock -= 1
            if self._send_state_lock == 0:
                self.send_state(self._states_to_send)
                self._states_to_send.clear()

    def _should_send_property(self, key, value):
        """Check the property lock (property_lock)

        Returns False when ``key``/``value`` is the pair currently locked, or
        when syncing is being held (the key is then queued for later);
        returns True otherwise."""
        to_json = self.trait_metadata(key, 'to_json', self._trait_to_json)
        locked_key, locked_json = self._property_lock
        if key == locked_key and to_json(value) == locked_json:
            return False
        if self._send_state_lock > 0:
            self._states_to_send.add(key)
            return False
        return True

    # Event handlers
    @_show_traceback
    def _handle_msg(self, msg):
        """Called when a msg is received from the front-end"""
        data = msg['content']['data']
        method = data['method']
        if method not in ('backbone', 'custom'):
            self.log.error(
                'Unknown front-end to back-end widget msg with method "%s"' %
                method)

        if method == 'backbone':
            # Handle backbone sync methods CREATE, PATCH, and UPDATE all in
            # one.
            if 'sync_data' in data:
                self.set_state(data['sync_data'])  # handles all methods
        elif method == 'custom' and 'content' in data:
            # Handle a custom msg from the front-end
            self._handle_custom_msg(data['content'])

    def _handle_custom_msg(self, content):
        """Called when a custom msg is received."""
        self._msg_callbacks(self, content)

    def _notify_trait(self, name, old_value, new_value):
        """Called when a property has been changed."""
        # The default traitlet callback machinery runs first, so that any
        # user registered validation is processed before the widget machinery
        # handles the state.
        LoggingConfigurable._notify_trait(self, name, old_value, new_value)

        # The state is sent only after the user registered callbacks for
        # trait changes have all fired, and only if this is not information
        # the front-end just sent us.
        if (self.comm is not None and name in self.keys
                and self._should_send_property(name, new_value)):
            self.send_state(key=name)

    def _handle_displayed(self, **kwargs):
        """Called when a view has been displayed for this widget instance"""
        self._display_callbacks(self, **kwargs)

    def _trait_to_json(self, x):
        """Convert a trait value to json

        Traverse lists/tuples and dicts and serialize their values as well.
        Replace any widgets with their model_id
        """
        convert = self._trait_to_json
        if isinstance(x, dict):
            return {k: convert(v) for k, v in x.items()}
        if isinstance(x, (list, tuple)):
            return [convert(v) for v in x]
        if isinstance(x, Widget):
            return "IPY_MODEL_" + x.model_id
        return x  # Value must be JSON-able

    def _trait_from_json(self, x):
        """Convert json values to objects

        Replace any strings representing valid model id values to Widget references.
        """
        convert = self._trait_from_json
        if isinstance(x, dict):
            return {k: convert(v) for k, v in x.items()}
        if isinstance(x, (list, tuple)):
            return [convert(v) for v in x]
        if (isinstance(x, string_types) and x.startswith('IPY_MODEL_')
                and x[10:] in Widget.widgets):
            # we want to support having child widgets at any level in a
            # hierarchy, trusting that a widget UUID will not appear out in
            # the wild
            return Widget.widgets[x[10:]]
        return x

    def _ipython_display_(self, **kwargs):
        """Called when `IPython.display.display` is called on the widget."""
        # Show view.
        if self._view_name is None:
            return
        self._send({"method": "display"})
        self._handle_displayed(**kwargs)

    def _send(self, msg):
        """Sends a message to the model in the front-end."""
        self.comm.send(msg)
Esempio n. 12
0
class InteractiveShellApp(Configurable):
    """A Mixin for applications that start InteractiveShell instances.

    Provides configurables for loading extensions and executing files
    as part of configuring a Shell environment.

    The following methods should be called by the :meth:`initialize` method
    of the subclass:

      - :meth:`init_path`
      - :meth:`init_shell` (to be implemented by the subclass)
      - :meth:`init_gui_pylab`
      - :meth:`init_extensions`
      - :meth:`init_code`
    """
    extensions = List(Unicode, config=True,
        help="A list of dotted module names of IPython extensions to load."
    )
    extra_extension = Unicode('', config=True,
        help="dotted module name of an IPython extension to load."
    )
    def _extra_extension_changed(self, name, old, new):
        """Append a newly set, non-empty ``extra_extension`` to ``extensions``."""
        if new:
            # add to self.extensions
            self.extensions.append(new)

    # Extensions that are always loaded (not configurable)
    default_extensions = List(Unicode, [u'storemagic'], config=False)

    hide_initial_ns = Bool(True, config=True,
        help="""Should variables loaded at startup (by startup files, exec_lines, etc.)
        be hidden from tools like %who?"""
    )

    exec_files = List(Unicode, config=True,
        help="""List of files to run at IPython startup."""
    )
    exec_PYTHONSTARTUP = Bool(True, config=True,
        help="""Run the file referenced by the PYTHONSTARTUP environment
        variable at IPython startup."""
    )
    file_to_run = Unicode('', config=True,
        help="""A file to be run""")

    exec_lines = List(Unicode, config=True,
        help="""lines of code to run at IPython startup."""
    )
    code_to_run = Unicode('', config=True,
        help="Execute the given command string."
    )
    module_to_run = Unicode('', config=True,
        help="Run the module as a script."
    )
    gui = CaselessStrEnum(gui_keys, config=True,
        help="Enable GUI event loop integration with any of {0}.".format(gui_keys)
    )
    matplotlib = CaselessStrEnum(backend_keys,
        config=True,
        help="""Configure matplotlib for interactive use with
        the default matplotlib backend."""
    )
    pylab = CaselessStrEnum(backend_keys,
        config=True,
        help="""Pre-load matplotlib and numpy for interactive use,
        selecting a particular matplotlib backend and loop integration.
        """
    )
    pylab_import_all = Bool(True, config=True,
        help="""If true, IPython will populate the user namespace with numpy, pylab, etc.
        and an ``import *`` is done from numpy and pylab, when using pylab mode.
        
        When False, pylab mode should not import any names into the user namespace.
        """
    )
    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')

    user_ns = Instance(dict, args=None, allow_none=True)
    def _user_ns_changed(self, name, old, new):
        """Push a replaced ``user_ns`` into the shell, if one exists yet."""
        if self.shell is not None:
            self.shell.user_ns = new
            self.shell.init_user_ns()

    def init_path(self):
        """Add current working directory, '', to sys.path"""
        if sys.path[0] != '':
            sys.path.insert(0, '')

    def init_shell(self):
        """Create ``self.shell``; must be provided by the subclass."""
        raise NotImplementedError("Override in subclasses")

    def init_gui_pylab(self):
        """Enable GUI event loop integration, taking pylab into account.

        The first of ``pylab``, ``matplotlib`` and ``gui`` that is set decides
        which shell method is called; nothing happens when none is set.
        Failures are logged and displayed through the shell's traceback
        machinery instead of being raised.
        """
        enable = False
        shell = self.shell
        if self.pylab:
            enable = lambda key: shell.enable_pylab(key, import_all=self.pylab_import_all)
            key = self.pylab
        elif self.matplotlib:
            enable = shell.enable_matplotlib
            key = self.matplotlib
        elif self.gui:
            enable = shell.enable_gui
            key = self.gui

        if not enable:
            return

        try:
            r = enable(key)
        except ImportError:
            self.log.warning("Eventloop or matplotlib integration failed. Is matplotlib installed?")
            self.shell.showtraceback()
            return
        except Exception:
            self.log.warning("GUI event loop or pylab initialization failed")
            self.shell.showtraceback()
            return

        if isinstance(r, tuple):
            # a tuple result starts with (gui, backend)
            gui, backend = r[:2]
            self.log.info("Enabling GUI event loop integration, "
                      "eventloop=%s, matplotlib=%s", gui, backend)
            if key == "auto":
                print("Using matplotlib backend: %s" % backend)
        else:
            gui = r
            self.log.info("Enabling GUI event loop integration, "
                      "eventloop=%s", gui)

    def init_extensions(self):
        """Load all IPython extensions in IPythonApp.extensions.

        The always-on ``default_extensions`` are loaded first, followed by
        the configured ``extensions``. A failing extension is logged and its
        traceback shown; loading then continues with the next one.
        """
        # NOTE(review): the bare excepts below are kept so that no failure in
        # an extension can abort startup -- confirm this is still intended.
        try:
            self.log.debug("Loading IPython extensions...")
            extensions = self.default_extensions + self.extensions
            for ext in extensions:
                try:
                    self.log.info("Loading IPython extension: %s" % ext)
                    self.shell.extension_manager.load_extension(ext)
                except:
                    self.log.warning("Error in loading extension: %s" % ext +
                        "\nCheck your config files in %s" % self.profile_dir.location
                    )
                    self.shell.showtraceback()
        except:
            self.log.warning("Unknown error in loading extensions:")
            self.shell.showtraceback()

    def init_code(self):
        """Run all configured startup code.

        Startup files, ``exec_lines`` and ``exec_files`` run first; names they
        define are then hidden when ``hide_initial_ns`` is true. Command-line
        code, file and module are run afterwards so they stay visible.
        """
        self._run_startup_files()
        self._run_exec_lines()
        self._run_exec_files()

        # Hide variables defined here from %who etc.
        if self.hide_initial_ns:
            self.shell.user_ns_hidden.update(self.shell.user_ns)

        # command-line execution (ipython -i script.py, ipython -m module)
        # should *not* be excluded from %whos
        self._run_cmd_line_code()
        self._run_module()

        # flush output, so it won't be attached to the first cell
        sys.stdout.flush()
        sys.stderr.flush()

    def _run_exec_lines(self):
        """Run lines of code in IPythonApp.exec_lines in the user's namespace."""
        if not self.exec_lines:
            return
        try:
            self.log.debug("Running code from IPythonApp.exec_lines...")
            for line in self.exec_lines:
                try:
                    self.log.info("Running code in user namespace: %s" %
                                  line)
                    self.shell.run_cell(line, store_history=False)
                except:
                    self.log.warning("Error in executing line in user "
                                     "namespace: %s" % line)
                    self.shell.showtraceback()
        except:
            self.log.warning("Unknown error in handling IPythonApp.exec_lines:")
            self.shell.showtraceback()

    def _exec_file(self, fname):
        """Run one file in the user namespace.

        ``fname`` is looked up in the current directory and in
        ``self.ipython_dir``; a missing file is logged and skipped. While the
        file runs, ``sys.argv`` is replaced by the resolved filename followed
        by ``self.extra_args[1:]`` and ``__file__`` is set in the user
        namespace. Files ending in ``.ipy`` are run as IPython scripts,
        everything else as plain Python.
        """
        try:
            full_filename = filefind(fname, [u'.', self.ipython_dir])
        except IOError:
            self.log.warning("File not found: %r"%fname)
            return
        # Make sure that the running script gets a proper sys.argv as if it
        # were run from a system shell.
        save_argv = sys.argv
        sys.argv = [full_filename] + self.extra_args[1:]
        # protect sys.argv from potential unicode strings on Python 2:
        if not py3compat.PY3:
            sys.argv = [ py3compat.cast_bytes(a) for a in sys.argv ]
        try:
            if os.path.isfile(full_filename):
                self.log.info("Running file in user namespace: %s" %
                              full_filename)
                # Ensure that __file__ is always defined to match Python
                # behavior.
                with preserve_keys(self.shell.user_ns, '__file__'):
                    self.shell.user_ns['__file__'] = fname
                    if full_filename.endswith('.ipy'):
                        self.shell.safe_execfile_ipy(full_filename)
                    else:
                        # default to python, even without extension
                        self.shell.safe_execfile(full_filename,
                                                 self.shell.user_ns)
        finally:
            sys.argv = save_argv

    def _run_startup_files(self):
        """Run PYTHONSTARTUP (when enabled) and the profile's startup files.

        The PYTHONSTARTUP file is only considered when ``exec_PYTHONSTARTUP``
        is true and no file, code or module was given to run. It always runs
        first; the ``*.py`` and ``*.ipy`` files of the profile startup
        directory follow in sorted order.
        """
        startup_dir = self.profile_dir.startup_dir
        startup_files = []
        if self.exec_PYTHONSTARTUP and not (self.file_to_run or self.code_to_run or self.module_to_run):
            python_startup = os.environ.get('PYTHONSTARTUP')
            if python_startup:
                startup_files.append(python_startup)
        profile_files = glob.glob(os.path.join(startup_dir, '*.py'))
        profile_files += glob.glob(os.path.join(startup_dir, '*.ipy'))
        # Sort only the profile files: sorting the combined list would move
        # the PYTHONSTARTUP file to wherever its path happens to collate.
        startup_files += sorted(profile_files)
        if not startup_files:
            return

        self.log.debug("Running startup files from %s...", startup_dir)
        try:
            for fname in startup_files:
                self._exec_file(fname)
        except:
            self.log.warning("Unknown error in handling startup files:")
            self.shell.showtraceback()

    def _run_exec_files(self):
        """Run files from IPythonApp.exec_files"""
        if not self.exec_files:
            return

        self.log.debug("Running files in IPythonApp.exec_files...")
        try:
            for fname in self.exec_files:
                self._exec_file(fname)
        except:
            self.log.warning("Unknown error in handling IPythonApp.exec_files:")
            self.shell.showtraceback()

    def _run_cmd_line_code(self):
        """Run code or file specified at the command-line"""
        if self.code_to_run:
            line = self.code_to_run
            try:
                self.log.info("Running code given at command line (c=): %s" %
                              line)
                self.shell.run_cell(line, store_history=False)
            except:
                self.log.warning("Error in executing line in user namespace: %s" %
                                 line)
                self.shell.showtraceback()

        # Like Python itself, ignore the second if the first of these is present
        elif self.file_to_run:
            fname = self.file_to_run
            try:
                self._exec_file(fname)
            except:
                self.log.warning("Error in executing file in user namespace: %s" %
                                 fname)
                self.shell.showtraceback()

    def _run_module(self):
        """Run module specified at the command-line."""
        if self.module_to_run:
            # Make sure that the module gets a proper sys.argv as if it were
            # run using `python -m`.
            save_argv = sys.argv
            sys.argv = [sys.executable] + self.extra_args
            try:
                self.shell.safe_run_module(self.module_to_run,
                                           self.shell.user_ns)
            finally:
                sys.argv = save_argv
Esempio n. 13
0
class BaseParallelApplication(BaseIPythonApplication):
    """The base Application for IPython.parallel apps

    Principal extensions to BaseIPythonApplication:

    * work_dir
    * remote logging via pyzmq
    * IOLoop instance
    """

    crash_handler_class = ParallelCrashHandler

    def _log_level_default(self):
        # temporarily override default_log_level to INFO
        return logging.INFO

    work_dir = Unicode(os.getcwdu(),
                       config=True,
                       help='Set the working dir for the process.')

    def _work_dir_changed(self, name, old, new):
        """Store ``work_dir`` in expanded form whenever it is set."""
        self.work_dir = unicode(expand_path(new))

    log_to_file = Bool(config=True, help="whether to log to a file")

    clean_logs = Bool(False,
                      config=True,
                      help="whether to cleanup old logfiles before starting")

    log_url = Unicode('',
                      config=True,
                      help="The ZMQ URL of the iplogger to aggregate logging.")

    cluster_id = Unicode(
        '',
        config=True,
        help=
        """String id to add to runtime files, to prevent name collisions when
        using multiple clusters with a single profile simultaneously.
        
        When set, files will be named like: 'ipcontroller-<cluster_id>-engine.json'
        
        Since this is text inserted into filenames, typical recommendations apply:
        Simple character strings are ideal, and spaces are not recommended (but should
        generally work).
        """)

    def _cluster_id_changed(self, name, old, new):
        """Rebuild ``self.name`` as the class-level name plus ``-<cluster_id>``."""
        self.name = self.__class__.name
        if new:
            self.name += '-%s' % new

    def _config_files_default(self):
        return [
            'ipcontroller_config.py', 'ipengine_config.py',
            'ipcluster_config.py'
        ]

    loop = Instance('zmq.eventloop.ioloop.IOLoop')

    def _loop_default(self):
        from zmq.eventloop.ioloop import IOLoop
        return IOLoop.instance()

    aliases = Dict(base_aliases)
    flags = Dict(base_flags)

    @catch_config_error
    def initialize(self, argv=None):
        """initialize the app"""
        super(BaseParallelApplication, self).initialize(argv)
        self.to_work_dir()
        self.reinit_logging()

    def to_work_dir(self):
        """Change into ``work_dir`` if needed and put '' first on sys.path."""
        wd = self.work_dir
        if unicode(wd) != os.getcwdu():
            os.chdir(wd)
            self.log.info("Changing to working dir: %s" % wd)
        # This is the working dir by now.
        sys.path.insert(0, '')

    def reinit_logging(self):
        """Set up logging for this process.

        With ``clean_logs`` set, files in the profile log directory named
        ``<name>-<digits>.log``, ``.err`` or ``.out`` are removed first. With
        ``log_to_file`` set, the current log handler is replaced by one that
        writes to ``<name>-<pid>.log`` in that directory. Propagation to the
        root logger is always switched off.
        """
        # Remove old log files
        log_dir = self.profile_dir.log_dir
        if self.clean_logs:
            # self.name can embed the user-supplied cluster_id, so it must be
            # escaped before being used inside a regular expression.
            stale_log = re.compile(r'%s-\d+\.(log|err|out)' % re.escape(self.name))
            for f in os.listdir(log_dir):
                if stale_log.match(f):
                    os.remove(os.path.join(log_dir, f))
        if self.log_to_file:
            # Start logging to the new log file
            log_filename = self.name + u'-' + str(os.getpid()) + u'.log'
            logfile = os.path.join(log_dir, log_filename)
            # The file is left open on purpose: the handler below writes to it.
            open_log_file = open(logfile, 'w')
            self.log.removeHandler(self._log_handler)
            self._log_handler = logging.StreamHandler(open_log_file)
            self._log_formatter = logging.Formatter("[%(name)s] %(message)s")
            self._log_handler.setFormatter(self._log_formatter)
            self.log.addHandler(self._log_handler)
        # do not propagate log messages to root logger
        # ipcluster app will sometimes print duplicate messages during shutdown
        # if this is 1 (default):
        self.log.propagate = False

    def write_pid_file(self, overwrite=False):
        """Create a .pid file in the profile's pid_dir with my pid.

        Parameters
        ----------
        overwrite : bool
            When true an existing pid file is replaced without being read.

        Raises
        ------
        PIDFileError
            If the pid file already exists and ``overwrite`` is false.
        """
        pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
        if os.path.isfile(pid_file) and not overwrite:
            # Only parse the old file when its pid is needed for the message;
            # a corrupt file must not prevent an explicit overwrite.
            pid = self.get_pid_from_file()
            raise PIDFileError(
                'The pid file [%s] already exists. \nThis could mean that this '
                'server is already running with [pid=%s].' %
                (pid_file, pid))
        with open(pid_file, 'w') as f:
            self.log.info("Creating pid file: %s" % pid_file)
            f.write(repr(os.getpid()) + '\n')

    def remove_pid_file(self):
        """Remove the pid file, if there is one.

        Meant to be called at shutdown. A failure to remove the file is
        logged, not raised, and the method always returns ``None``.
        """
        pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
        if os.path.isfile(pid_file):
            try:
                self.log.info("Removing pid file: %s" % pid_file)
                os.remove(pid_file)
            except Exception:
                self.log.warning("Error removing the pid file: %s" % pid_file)

    def get_pid_from_file(self):
        """Get the pid from the pid file.

        Raises :exc:`PIDFileError` if the pid file doesn't exist or its
        content is not an integer.
        """
        pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
        if not os.path.isfile(pid_file):
            raise PIDFileError('pid file not found: %s' % pid_file)
        with open(pid_file, 'r') as f:
            s = f.read().strip()
        try:
            return int(s)
        except ValueError:
            raise PIDFileError("invalid pid file: %s (contents: %r)" %
                               (pid_file, s))

    def check_pid(self, pid):
        """Return whether a process with the given pid appears to be running.

        On Windows the process is probed with ``OpenProcess``; elsewhere the
        output of ``ps x`` is scanned for the pid. When the probe itself
        fails, a warning is logged and ``True`` is returned.
        """
        if os.name == 'nt':
            try:
                import ctypes
                kernel32 = ctypes.windll.kernel32
                # returns 0 if no such process (of ours) exists
                # positive int otherwise
                handle = kernel32.OpenProcess(1, 0, pid)
            except Exception:
                self.log.warning(
                    "Could not determine whether pid %i is running via `OpenProcess`. "
                    " Making the likely assumption that it is." % pid)
                return True
            if not handle:
                return False
            # The handle was only needed as an existence probe; release it so
            # repeated checks do not leak process handles.
            kernel32.CloseHandle(handle)
            return True
        else:
            try:
                p = Popen(['ps', 'x'], stdout=PIPE, stderr=PIPE)
                output, _ = p.communicate()
            except OSError:
                self.log.warning(
                    "Could not determine whether pid %i is running via `ps x`. "
                    " Making the likely assumption that it is." % pid)
                return True
            # the leading number of every output line is taken as a pid
            pids = [int(m) for m in re.findall(r'^\W*\d+', output, re.MULTILINE)]
            return pid in pids
Esempio n. 14
0
class CommManager(LoggingConfigurable):
    """Manager for Comms in the Kernel"""

    # If this is instantiated by a non-IPython kernel, shell will be None
    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC',
                     allow_none=True)
    kernel = Instance('IPython.kernel.zmq.kernelbase.Kernel')

    iopub_socket = Any()

    def _iopub_socket_default(self):
        return self.kernel.iopub_socket

    session = Instance('IPython.kernel.zmq.session.Session')

    def _session_default(self):
        return self.kernel.session

    # registered comms, keyed by comm_id
    comms = Dict()
    # comm_open callbacks, keyed by target name
    targets = Dict()

    # Public APIs

    def register_target(self, target_name, f):
        """Register a callable f for a given target name

        f will be called with two arguments when a comm_open message is received with `target`:

        - the Comm instance
        - the `comm_open` message itself.

        f can be a Python callable or an import string for one.
        """
        if isinstance(f, string_types):
            f = import_item(f)

        self.targets[target_name] = f

    def unregister_target(self, target_name, f):
        """Unregister a callable registered with register_target

        Returns the callable that was registered under ``target_name`` and
        raises ``KeyError`` when there is none. ``f`` is accepted for
        symmetry with :meth:`register_target` but is not consulted.
        """
        return self.targets.pop(target_name)

    def register_comm(self, comm):
        """Register a new comm

        The comm is given this manager's shell, kernel and iopub socket and
        stored under its ``comm_id``, which is returned.
        """
        comm_id = comm.comm_id
        comm.shell = self.shell
        comm.kernel = self.kernel
        comm.iopub_socket = self.iopub_socket
        self.comms[comm_id] = comm
        return comm_id

    def unregister_comm(self, comm):
        """Forget a registered comm.

        Only the registry entry is removed; the comm itself is not touched.
        Unlike :meth:`get_comm`, this raises ``KeyError`` for an unknown comm.
        """
        self.comms.pop(comm.comm_id)

    def get_comm(self, comm_id):
        """Get a comm with a particular id

        Returns the comm if found, otherwise None.

        This will not raise an error,
        it will log messages if the comm cannot be found.
        """
        try:
            return self.comms[comm_id]
        except KeyError:
            self.log.error("No such comm: %s", comm_id)
            self.log.debug("Current comms: %s", lazy_keys(self.comms))
            return None

    # Message handlers
    def comm_open(self, stream, ident, msg):
        """Handler for comm_open messages

        A secondary ``Comm`` is created and registered for the message's
        ``comm_id`` and handed, together with the message, to the callable
        registered for ``target_name``. When no target is registered or the
        callable raises, the failure is logged and the comm is closed again.
        """
        content = msg['content']
        comm_id = content['comm_id']
        target_name = content['target_name']
        f = self.targets.get(target_name, None)
        comm = Comm(
            comm_id=comm_id,
            shell=self.shell,
            kernel=self.kernel,
            iopub_socket=self.iopub_socket,
            primary=False,
        )
        self.register_comm(comm)
        if f is None:
            self.log.error("No such comm target registered: %s", target_name)
        else:
            try:
                f(comm, msg)
                return
            except Exception:
                self.log.error("Exception opening comm with target: %s",
                               target_name,
                               exc_info=True)

        # Failure.
        try:
            comm.close()
        except Exception:
            # Exception rather than a bare except, so KeyboardInterrupt and
            # SystemExit are not swallowed by the clean-up.
            self.log.error("""Could not close comm during `comm_open` failure 
                clean-up.  The comm may not have been opened yet.""",
                           exc_info=True)

    def comm_msg(self, stream, ident, msg):
        """Handler for comm_msg messages"""
        content = msg['content']
        comm_id = content['comm_id']
        comm = self.get_comm(comm_id)
        if comm is None:
            # no such comm
            return
        try:
            comm.handle_msg(msg)
        except Exception:
            self.log.error("Exception in comm_msg for %s",
                           comm_id,
                           exc_info=True)

    def comm_close(self, stream, ident, msg):
        """Handler for comm_close messages

        The comm is removed from the registry before its ``handle_close`` is
        called; an exception raised by that callback is logged.
        """
        content = msg['content']
        comm_id = content['comm_id']
        comm = self.get_comm(comm_id)
        if comm is None:
            # no such comm
            self.log.debug("No such comm to close: %s", comm_id)
            return
        del self.comms[comm_id]

        try:
            comm.handle_close(msg)
        except Exception:
            self.log.error("Exception handling comm_close for %s",
                           comm_id,
                           exc_info=True)
Esempio n. 15
0
class KernelManager(HasTraits):
    """ Manages a kernel for a frontend.

    The SUB channel is for the frontend to receive messages published by the
    kernel.

    The REQ channel is for the frontend to make requests of the kernel.

    The REP channel is for the kernel to request stdin (raw_input) from the
    frontend.
    """
    # config object for passing to child configurables
    config = Instance(Config)

    # The PyZMQ Context to use for communication with the kernel.
    context = Instance(zmq.Context)
    def _context_default(self):
        """Default for ``context``: the result of ``zmq.Context.instance()``."""
        default_context = zmq.Context.instance()
        return default_context

    # The Session to use for communication with the kernel.
    session = Instance(Session)

    # The kernel process with which the KernelManager is communicating.
    kernel = Instance(Popen)

    # The addresses for the communication channels.
    connection_file = Unicode('')
    ip = Unicode(LOCALHOST)
    def _ip_changed(self, name, old, new):
        if new == '*':
            self.ip = '0.0.0.0'
    shell_port = Int(0)
    iopub_port = Int(0)
    stdin_port = Int(0)
    hb_port = Int(0)

    # The classes to use for the various channels.
    shell_channel_class = Type(ShellSocketChannel)
    sub_channel_class = Type(SubSocketChannel)
    stdin_channel_class = Type(StdInSocketChannel)
    hb_channel_class = Type(HBSocketChannel)

    # Protected traits.
    _launch_args = Any
    _shell_channel = Any
    _sub_channel = Any
    _stdin_channel = Any
    _hb_channel = Any
    _connection_file_written=Bool(False)

    def __init__(self, **kwargs):
        """Initialise the traits from ``kwargs`` and make sure a session exists.

        When no ``session`` was supplied, a new ``Session`` is created from
        this manager's ``config``.
        """
        super(KernelManager, self).__init__(**kwargs)
        if self.session is not None:
            return
        self.session = Session(config=self.config)
    
    def __del__(self):
        if self._connection_file_written:
            # cleanup connection files on full shutdown of kernel we started
            self._connection_file_written = False
            try:
                os.remove(self.connection_file)
            except IOError:
                pass


    #--------------------------------------------------------------------------
    # Channel management methods:
    #--------------------------------------------------------------------------

    def start_channels(self, shell=True, sub=True, stdin=True, hb=True):
        """Start the selected channels of this kernel.

        Each flag decides whether the matching channel (``shell_channel``,
        ``sub_channel``, ``stdin_channel``, ``hb_channel``) is started.
        ``shell_channel.allow_stdin`` is set to whether the stdin channel was
        requested.

        If port numbers of 0 are being used (random ports) then you must
        first call :meth:`start_kernel`. If the channels have been stopped
        and you call this, :class:`RuntimeError` will be raised.
        """
        if shell:
            self.shell_channel.start()
        if sub:
            self.sub_channel.start()
        if stdin:
            self.stdin_channel.start()
        # stdin requests are only allowed when the stdin channel was started
        self.shell_channel.allow_stdin = True if stdin else False
        if hb:
            self.hb_channel.start()

    def stop_channels(self):
        """Stop every channel of this kernel that reports being alive.

        The shell, sub, stdin and heartbeat channels are visited in that
        order.
        """
        channel_names = ('shell_channel', 'sub_channel',
                         'stdin_channel', 'hb_channel')
        for channel_name in channel_names:
            if getattr(self, channel_name).is_alive():
                getattr(self, channel_name).stop()

    @property
    def channels_running(self):
        """Whether at least one of the four channels reports being alive.

        The shell, sub, stdin and heartbeat channels are asked in that order
        and the first truthy answer is returned; when none is alive, the
        answer of the heartbeat channel is returned.
        """
        alive = False
        for channel_name in ('shell_channel', 'sub_channel',
                             'stdin_channel', 'hb_channel'):
            alive = getattr(self, channel_name).is_alive()
            if alive:
                break
        return alive

    #--------------------------------------------------------------------------
    # Kernel process management methods:
    #--------------------------------------------------------------------------
    
    def load_connection_file(self):
        """Read ``self.connection_file`` (JSON) and adopt its connection info.

        The ``ip`` and the four port entries are copied onto this manager and
        the ``key`` entry, converted to bytes, becomes the session key.
        """
        with open(self.connection_file) as f:
            cfg = json.load(f)

        self.ip = cfg['ip']
        for port_name in ('shell_port', 'stdin_port', 'iopub_port', 'hb_port'):
            setattr(self, port_name, cfg[port_name])
        self.session.key = str_to_bytes(cfg['key'])
    
    def write_connection_file(self):
        """Write the connection info to ``self.connection_file`` as JSON, once.

        Does nothing when the file was already written by this manager.
        Otherwise the module-level ``write_connection_file`` is called with
        the current ip, session key and ports; the file name and the ports it
        reports back are stored on this manager.
        """
        if self._connection_file_written:
            return
        connection_info = dict(
            ip=self.ip, key=self.session.key,
            stdin_port=self.stdin_port, iopub_port=self.iopub_port,
            shell_port=self.shell_port, hb_port=self.hb_port,
        )
        fname, cfg = write_connection_file(self.connection_file,
                                           **connection_info)
        self.connection_file = fname
        # write_connection_file also sets default ports:
        for port_name in ('shell_port', 'stdin_port', 'iopub_port', 'hb_port'):
            setattr(self, port_name, cfg[port_name])

        self._connection_file_written = True
    
    def start_kernel(self, **kw):
        """Launch a kernel process and make this manager use it.

        If random ports (port=0) are being used, this method must be called
        before the channels are created.

        Parameters
        ----------
        ipython : bool, optional (default True)
            Whether to use an IPython kernel instead of a plain Python kernel.
            Only consulted when no ``launcher`` is given.

        launcher : callable, optional (default None)
            A custom function for launching the kernel process (generally a
            wrapper around ``entry_point.base_launch_kernel``). In most cases,
            it should not be necessary to use this parameter.

        **kw : optional
            Remaining keywords are passed on to the launcher.

        Raises
        ------
        RuntimeError
            If ``self.ip`` is not one of ``LOCAL_IPS``.
        """
        if self.ip not in LOCAL_IPS:
            raise RuntimeError("Can only launch a kernel on a local interface. "
                               "Make sure that the '*_address' attributes are "
                               "configured properly. "
                               "Currently valid addresses are: %s"%LOCAL_IPS
                               )

        # write connection file / get default ports
        self.write_connection_file()

        # remember the complete arguments before the pops below consume them
        self._launch_args = kw.copy()
        launcher = kw.pop('launcher', None)
        if launcher is None:
            use_ipython = kw.pop('ipython', True)
            if use_ipython:
                from ipkernel import launch_kernel as launcher
            else:
                from pykernel import launch_kernel as launcher
        self.kernel = launcher(fname=self.connection_file, **kw)

    def shutdown_kernel(self, restart=False):
        """Attempt to stop the kernel process cleanly.

        A shutdown request is sent on the shell channel and the kernel gets
        up to one second to exit; if it is still alive after that it is
        killed. On Windows the kernel is killed right away. Unless
        ``restart`` is true, a connection file written by this manager is
        removed afterwards (best effort).

        Parameters
        ----------
        restart : bool, optional
            Passed to the shell channel's shutdown request; when true the
            connection file is kept.
        """
        # FIXME: Shutdown does not work on Windows due to ZMQ errors!
        if sys.platform == 'win32':
            self.kill_kernel()
            return

        # Pause the heart beat channel if it exists.
        if self._hb_channel is not None:
            self._hb_channel.pause()

        # Don't send any additional kernel kill messages immediately, to give
        # the kernel a chance to properly execute shutdown actions. Wait for at
        # most 1s, checking every 0.1s.
        self.shell_channel.shutdown(restart=restart)
        for i in range(10):
            if self.is_alive:
                time.sleep(0.1)
            else:
                break
        else:
            # OK, we've waited long enough.
            if self.has_kernel:
                self.kill_kernel()

        if not restart and self._connection_file_written:
            # cleanup connection files on full shutdown of kernel we started
            self._connection_file_written = False
            try:
                os.remove(self.connection_file)
            except (IOError, OSError):
                # os.remove reports failures as OSError, which is not an
                # IOError on Python 2; a missing file must not break shutdown.
                pass

    def restart_kernel(self, now=False, **kw):
        """Restart the kernel with the arguments of the previous launch.

        If the old kernel was launched with random ports, the same ports will be
        used for the new kernel.

        Parameters
        ----------
        now : bool, optional
            If True, the kernel is forcefully restarted *immediately*, without
            having a chance to do any cleanup action.  Otherwise the kernel is
            given 1s to clean up before a forceful restart is issued.

            In all cases the kernel is restarted, the only difference is whether
            it is given a chance to perform a clean shutdown or not.

        **kw : optional
            Any options specified here will replace those used to launch the
            kernel.

        Raises
        ------
        RuntimeError
            If :meth:`start_kernel` was never called on this manager.
        """
        if self._launch_args is None:
            raise RuntimeError("Cannot restart the kernel. "
                               "No previous call to 'start_kernel'.")

        # Stop currently running kernel.
        if self.has_kernel:
            if not now:
                self.shutdown_kernel(restart=True)
            else:
                self.kill_kernel()

        # Start new kernel.
        self._launch_args.update(kw)
        self.start_kernel(**self._launch_args)

        # FIXME: Messages get dropped in Windows due to probable ZMQ bug
        # unless there is some delay here.
        if sys.platform == 'win32':
            time.sleep(0.2)

    @property
    def has_kernel(self):
        """True when ``self.kernel`` is set, i.e. is not ``None``."""
        kernel_process = self.kernel
        return kernel_process is not None

    def kill_kernel(self):
        """ Kill the running kernel. """
        if self.has_kernel:
            # Pause the heart beat channel if it exists.
            if self._hb_channel is not None:
                self._hb_channel.pause()

            # Attempt to kill the kernel.
            try:
                self.kernel.kill()
            except OSError, e:
                # In Windows, we will get an Access Denied error if the process
                # has already terminated. Ignore it.
                if sys.platform == 'win32':
                    if e.winerror != 5:
                        raise
                # On Unix, we may get an ESRCH error if the process has already
                # terminated. Ignore it.
                else:
                    from errno import ESRCH
                    if e.errno != ESRCH:
                        raise
            self.kernel = None
        else:
Esempio n. 16
0
class NbConvertApp(BaseIPythonApplication):
    """Application used to convert to and from notebook file type (*.ipynb)"""

    name = 'ipython-nbconvert'
    aliases = nbconvert_aliases
    flags = nbconvert_flags

    def _log_level_default(self):
        """Default the application's log level to ``logging.INFO``."""
        return logging.INFO

    def _classes_default(self):
        """Build the default ``classes`` list.

        Starts with ``NbConvertBase`` and appends every attribute of the
        ``exporters``, ``transformers`` and ``writers`` packages that is a
        class deriving from ``Configurable``.
        """
        classes = [NbConvertBase]
        for pkg in (exporters, transformers, writers):
            for name in dir(pkg):
                cls = getattr(pkg, name)
                if isinstance(cls, type) and issubclass(cls, Configurable):
                    classes.append(cls)
        return classes

    description = Unicode(
        u"""This application is used to convert notebook files (*.ipynb)
        to various other formats.

        WARNING: THE COMMANDLINE INTERFACE MAY CHANGE IN FUTURE RELEASES.""")

    output_base = Unicode('', config=True, help='''overwrite base name use for output files.
            can only  be use when converting one notebook at a time.
            ''')

    examples = Unicode(u"""
        The simplest way to use nbconvert is
        
        > ipython nbconvert mynotebook.ipynb
        
        which will convert mynotebook.ipynb to the default format (probably HTML).
        
        You can specify the export format with `--to`.
        Options include {0}
        
        > ipython nbconvert --to latex mynotebook.ipynb

        Both HTML and LaTeX support multiple output templates. LaTeX includes
        'basic', 'book', and 'article'.  HTML includes 'basic' and 'full'.  You 
        can specify the flavor of the format used.

        > ipython nbconvert --to html --template basic mynotebook.ipynb
        
        You can also pipe the output to stdout, rather than a file
        
        > ipython nbconvert mynotebook.ipynb --stdout

        A post-processor can be used to compile a PDF

        > ipython nbconvert mynotebook.ipynb --to latex --post PDF
        
        You can get (and serve) a Reveal.js-powered slideshow
        
        > ipython nbconvert myslides.ipynb --to slides --post serve
        
        Multiple notebooks can be given at the command line in a couple of 
        different ways:
  
        > ipython nbconvert notebook*.ipynb
        > ipython nbconvert notebook1.ipynb notebook2.ipynb
        
        or you can specify the notebooks list in a config file, containing::
        
            c.NbConvertApp.notebooks = ["my_notebook.ipynb"]
        
        > ipython nbconvert --config mycfg.py
        """.format(get_export_names()))

    # Writer specific variables
    writer = Instance('IPython.nbconvert.writers.base.WriterBase',
                      help="""Instance of the writer class used to write the 
                      results of the conversion.""")
    writer_class = DottedObjectName('FilesWriter', config=True,
                                    help="""Writer class used to write the 
                                    results of the conversion""")
    writer_aliases = {'FilesWriter': 'IPython.nbconvert.writers.files.FilesWriter',
                      'DebugWriter': 'IPython.nbconvert.writers.debug.DebugWriter',
                      'StdoutWriter': 'IPython.nbconvert.writers.stdout.StdoutWriter'}
    writer_factory = Type()

    def _writer_class_changed(self, name, old, new):
        """Resolve ``new`` (alias or dotted name) and store the imported class
        in ``writer_factory``."""
        if new in self.writer_aliases:
            new = self.writer_aliases[new]
        self.writer_factory = import_item(new)

    # Post-processor specific variables
    post_processor = Instance('IPython.nbconvert.post_processors.base.PostProcessorBase',
                      help="""Instance of the PostProcessor class used to write the 
                      results of the conversion.""")

    post_processor_class = DottedOrNone(config=True,
                                    help="""PostProcessor class used to write the 
                                    results of the conversion""")
    post_processor_aliases = {'PDF': 'IPython.nbconvert.post_processors.pdf.PDFPostProcessor',
                              'serve': 'IPython.nbconvert.post_processors.serve.ServePostProcessor'}
    post_processor_factory = Type()

    def _post_processor_class_changed(self, name, old, new):
        """Resolve ``new`` (alias or dotted name) and, when it is non-empty,
        store the imported class in ``post_processor_factory``."""
        if new in self.post_processor_aliases:
            new = self.post_processor_aliases[new]
        if new:
            self.post_processor_factory = import_item(new)

    # Other configurable variables
    export_format = CaselessStrEnum(get_export_names(),
        default_value="html",
        config=True,
        help="""The export format to be used."""
    )

    notebooks = List([], config=True, help="""List of notebooks to convert.
                     Wildcards are supported.
                     Filenames passed positionally will be added to the list.
                     """)

    @catch_config_error
    def initialize(self, argv=None):
        """Run the base initialization, then set up sys.path, the notebook
        list, the writer and the post-processor, in that order."""
        super(NbConvertApp, self).initialize(argv)
        self.init_syspath()
        self.init_notebooks()
        self.init_writer()
        self.init_post_processor()

    def init_syspath(self):
        """
        Add the cwd to the sys.path ($PYTHONPATH)
        """
        sys.path.insert(0, os.getcwd())

    def init_notebooks(self):
        """Construct the list of notebooks.

        If notebooks are passed on the command-line, they override notebooks
        specified in config files.  Each pattern is globbed both as given and
        with ``.ipynb`` appended, so the extension may be omitted.  The
        resulting filenames replace ``self.notebooks``, de-duplicated while
        keeping first-seen order.
        """
        # Specifying notebooks on the command-line overrides (rather than adds)
        # the notebook list
        patterns = self.extra_args if self.extra_args else self.notebooks

        filenames = []
        seen = set()  # O(1) membership test; `filenames` keeps the order
        for pattern in patterns:
            # Allow the user to convert notebooks without having to type the
            # extension.
            globbed_files = glob.glob(pattern)
            globbed_files.extend(glob.glob(pattern + '.ipynb'))
            if not globbed_files:
                # Logger.warn is a deprecated alias of Logger.warning.
                self.log.warning("pattern %r matched no files", pattern)

            for filename in globbed_files:
                if filename not in seen:
                    seen.add(filename)
                    filenames.append(filename)
        self.notebooks = filenames

    def init_writer(self):
        """
        Initialize the writer (which is stateless)
        """
        # Invoke the change handler by hand so writer_factory is resolved even
        # when writer_class still holds its default value.
        self._writer_class_changed(None, self.writer_class, self.writer_class)
        self.writer = self.writer_factory(parent=self)

    def init_post_processor(self):
        """
        Initialize the post_processor (which is stateless)
        """
        self._post_processor_class_changed(None, self.post_processor_class,
            self.post_processor_class)
        # The factory stays unset when no post-processor class was requested.
        if self.post_processor_factory:
            self.post_processor = self.post_processor_factory(parent=self)

    def start(self):
        """
        Ran after initialization completed
        """
        super(NbConvertApp, self).start()
        self.convert_notebooks()

    def convert_notebooks(self):
        """Convert every file listed in ``self.notebooks``.

        Each notebook is exported with the exporter selected by
        ``export_format``, written through ``self.writer`` and, if one is
        configured, handed to ``self.post_processor``.  The application exits
        with status 1 on a ``ConversionException`` or when ``output_base`` is
        combined with more than one notebook; if nothing was converted the
        help is printed and the process exits with status -1.
        """
        conversion_success = 0

        if self.output_base != '' and len(self.notebooks) > 1:
            self.log.error(
            """UsageError: --output flag or `NbConvertApp.output_base` config option
            cannot be used when converting multiple notebooks.
            """)
            self.exit(1)

        exporter = exporter_map[self.export_format](config=self.config)

        for notebook_filename in self.notebooks:
            self.log.info("Converting notebook %s to %s", notebook_filename, self.export_format)

            # Get a unique key for the notebook and set it in the resources object.
            # splitext leaves extension-less names intact; slicing at
            # rfind('.') would drop their last character (rfind returns -1).
            basename = os.path.basename(notebook_filename)
            notebook_name = os.path.splitext(basename)[0]
            if self.output_base:
                notebook_name = self.output_base
            resources = {}
            resources['unique_key'] = notebook_name
            resources['output_files_dir'] = '%s_files' % notebook_name
            self.log.info("Support files will be in %s", os.path.join(resources['output_files_dir'], ''))

            # Try to export
            try:
                output, resources = exporter.from_filename(notebook_filename, resources=resources)
            except ConversionException:
                self.log.error("Error while converting '%s'", notebook_filename,
                      exc_info=True)
                self.exit(1)
            else:
                write_results = self.writer.write(output, resources, notebook_name=notebook_name)

                # Post-process if post processor has been defined.
                if hasattr(self, 'post_processor') and self.post_processor:
                    self.post_processor(write_results)
                conversion_success += 1

        # If nothing was converted successfully, help the user.
        if conversion_success == 0:
            self.print_help()
            sys.exit(-1)
Esempio n. 17
0
class BaseLauncher(LoggingFactory):
    """An abstraction for starting, stopping and signaling a process."""

    # In all of the launchers, the work_dir is where child processes will be
    # run. This will usually be the cluster_dir, but may not be. any work_dir
    # passed into the __init__ method will override the config value.
    # This should not be used to set the work_dir for the actual engine
    # and controller. Instead, use their own config files or the
    # controller_args, engine_args attributes of the launchers to add
    # the --work-dir option.
    work_dir = Unicode(u'.')
    loop = Instance('zmq.eventloop.ioloop.IOLoop')

    start_data = Any()
    stop_data = Any()

    def _loop_default(self):
        """Default ``loop`` to ``ioloop.IOLoop.instance()``."""
        return ioloop.IOLoop.instance()

    def __init__(self, work_dir=u'.', config=None, **kwargs):
        """Create a launcher in the 'before' state with no stop callbacks.

        ``work_dir``, ``config`` and any extra keyword arguments are forwarded
        to the parent constructor.
        """
        super(BaseLauncher, self).__init__(
            work_dir=work_dir, config=config, **kwargs)
        # Lifecycle marker: one of 'before', 'running' or 'after'.
        self.state = 'before'
        self.stop_callbacks = []
        self.start_data = self.stop_data = None

    @property
    def args(self):
        """The command and argument list, as produced by :meth:`find_args`."""
        return self.find_args()

    def find_args(self):
        """Return the args list exposed through the ``.args`` property.

        Subclasses must override this; the base implementation raises
        ``NotImplementedError``.
        """
        raise NotImplementedError(
            'find_args must be implemented in a subclass')

    @property
    def arg_str(self):
        """``self.args`` joined with single spaces."""
        return ' '.join(self.args)

    @property
    def running(self):
        """True exactly when ``self.state`` is ``'running'``."""
        return self.state == 'running'

    def start(self):
        """Start the process.

        Subclasses must override this; the base implementation raises
        ``NotImplementedError``.
        """
        raise NotImplementedError('start must be implemented in a subclass')

    def stop(self):
        """Stop the process.

        Subclasses must override this; the base implementation raises
        ``NotImplementedError``.  Use :meth:`on_stop` to be told when the
        process has actually stopped.
        """
        raise NotImplementedError('stop must be implemented in a subclass')

    def on_stop(self, f):
        """Register ``f`` to be called with the stop data.

        If the launcher is already in the 'after' state, ``f`` is called
        immediately with ``self.stop_data`` and its result is returned.
        Otherwise ``f`` is queued for :meth:`notify_stop` and ``None`` is
        returned.
        """
        if self.state != 'after':
            self.stop_callbacks.append(f)
            return None
        return f(self.stop_data)

    def notify_start(self, data):
        """Record that the process has started.

        Logs the start, stores ``data`` as ``start_data``, switches the state
        to 'running' and returns ``data`` unchanged so the method can be used
        as a pass-through callback.
        """
        message = 'Process %r started: %r' % (self.args[0], data)
        self.log.info(message)
        self.start_data = data
        self.state = 'running'
        return data

    def notify_stop(self, data):
        """Record that the process has stopped and fire the stop callbacks.

        Logs the stop, stores ``data`` as ``stop_data``, switches the state to
        'after', then pops and calls each callback registered through
        :meth:`on_stop` (most recently added first).  Returns ``data``.
        """
        message = 'Process %r stopped: %r' % (self.args[0], data)
        self.log.info(message)
        self.stop_data = data
        self.state = 'after'
        # The number of callbacks to run is fixed before the first one runs.
        remaining = len(self.stop_callbacks)
        while remaining > 0:
            callback = self.stop_callbacks.pop()
            callback(data)
            remaining -= 1
        return data

    def signal(self, sig):
        """Signal the process.

        Subclasses must override this; the base implementation raises
        ``NotImplementedError``.

        Parameters
        ----------
        sig : str or int
            'KILL', 'INT', etc., or any signal number
        """
        raise NotImplementedError('signal must be implemented in a subclass')
Esempio n. 18
0
class Comm(LoggingConfigurable):
    """Class for communicating between a Frontend and a Kernel"""
    # If this is instantiated by a non-IPython kernel, shell will be None
    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC',
                     allow_none=True)
    kernel = Instance('ipython_kernel.kernelbase.Kernel')

    def _kernel_default(self):
        """Use the existing ``Kernel`` instance if one has been initialized,
        otherwise ``None``."""
        return Kernel.instance() if Kernel.initialized() else None

    iopub_socket = Any()

    def _iopub_socket_default(self):
        """Default to the kernel's ``iopub_socket``."""
        kernel = self.kernel
        return kernel.iopub_socket

    session = Instance('ipython_kernel.session.Session')

    def _session_default(self):
        """Default to the kernel's session, or ``None`` without a kernel."""
        kernel = self.kernel
        if kernel is None:
            return None
        return kernel.session

    target_name = Unicode('comm')
    target_module = Unicode(None,
                            allow_none=True,
                            help="""requirejs module from
        which to load comm target.""")

    topic = Bytes()

    def _topic_default(self):
        """Derive the topic from ``comm_id``, ASCII-encoded."""
        label = 'comm-%s' % self.comm_id
        return label.encode('ascii')

    _open_data = Dict(help="data dict, if any, to be included in comm_open")
    _close_data = Dict(help="data dict, if any, to be included in comm_close")

    _msg_callback = Any()
    _close_callback = Any()

    _closed = Bool(True)
    comm_id = Unicode()

    def _comm_id_default(self):
        """Generate a fresh random hex identifier."""
        return uuid.uuid4().hex

    primary = Bool(True, help="Am I the primary or secondary Comm?")

    def __init__(self, target_name='', data=None, **kwargs):
        """Create the comm; a primary comm opens its peer right away.

        A non-empty ``target_name`` is forwarded to the parent constructor
        together with ``kwargs``.  A secondary comm is only marked as not
        closed.
        """
        if target_name:
            kwargs['target_name'] = target_name
        super(Comm, self).__init__(**kwargs)
        if not self.primary:
            self._closed = False
            return
        # I am primary, open my peer.
        self.open(data)

    def _publish_msg(self,
                     msg_type,
                     data=None,
                     metadata=None,
                     buffers=None,
                     **keys):
        """Helper for sending a comm message on IOPub"""
        off_main_thread = threading.current_thread().name != 'MainThread'
        if off_main_thread and IOLoop.initialized():
            # make sure we never send on a zmq socket outside the main IOLoop thread
            def resend():
                return self._publish_msg(
                    msg_type, data, metadata, buffers, **keys)

            IOLoop.instance().add_callback(resend)
            return
        if data is None:
            data = {}
        if metadata is None:
            metadata = {}
        content = json_clean(dict(data=data, comm_id=self.comm_id, **keys))
        self.session.send(
            self.iopub_socket,
            msg_type,
            content,
            metadata=json_clean(metadata),
            parent=self.kernel._parent_header,
            ident=self.topic,
            buffers=buffers,
        )

    def __del__(self):
        """trigger close on gc"""
        self.close()

    # publishing messages

    def open(self, data=None, metadata=None, buffers=None):
        """Open the frontend-side version of this comm

        Registers the comm with the kernel's ``comm_manager`` and publishes a
        ``comm_open`` message; if publishing fails the comm is unregistered
        again and the error is re-raised.  Raises ``RuntimeError`` when no
        ``comm_manager`` is reachable through ``self.kernel``.
        """
        if data is None:
            data = self._open_data
        manager = getattr(self.kernel, 'comm_manager', None)
        if manager is None:
            raise RuntimeError("Comms cannot be opened without a kernel "
                               "and a comm_manager attached to that kernel.")

        manager.register_comm(self)
        try:
            self._publish_msg(
                'comm_open',
                data=data,
                metadata=metadata,
                buffers=buffers,
                target_name=self.target_name,
                target_module=self.target_module,
            )
            self._closed = False
        except:
            # Roll back the registration, then let the error propagate.
            manager.unregister_comm(self)
            raise

    def close(self, data=None, metadata=None, buffers=None):
        """Close the frontend-side version of this comm

        Does nothing when the comm is already closed.
        """
        if self._closed:
            # only close once
            return
        self._closed = True
        payload = self._close_data if data is None else data
        self._publish_msg(
            'comm_close',
            data=payload,
            metadata=metadata,
            buffers=buffers,
        )
        self.kernel.comm_manager.unregister_comm(self)

    def send(self, data=None, metadata=None, buffers=None):
        """Send a message to the frontend-side version of this comm"""
        self._publish_msg(
            'comm_msg', data=data, metadata=metadata, buffers=buffers)

    # registering callbacks

    def on_close(self, callback):
        """Register a callback for comm_close

        The callback is invoked by ``handle_close`` with the message it
        received.

        Call `on_close(None)` to disable an existing callback.
        """
        self._close_callback = callback

    def on_msg(self, callback):
        """Register a callback for comm_msg

        The callback is invoked by ``handle_msg`` with the message it
        received.

        Call `on_msg(None)` to disable an existing callback.
        """
        self._msg_callback = callback

    # handling of incoming messages

    def handle_close(self, msg):
        """Handle a comm_close message"""
        self.log.debug("handle_close[%s](%s)", self.comm_id, msg)
        if not self._close_callback:
            return
        self._close_callback(msg)

    def handle_msg(self, msg):
        """Handle a comm_msg message

        When a message callback is registered it is called with ``msg``,
        bracketed by the shell's 'pre_execute' and 'post_execute' events if a
        shell is attached.
        """
        self.log.debug("handle_msg[%s](%s)", self.comm_id, msg)
        if not self._msg_callback:
            return
        if self.shell:
            self.shell.events.trigger('pre_execute')
        self._msg_callback(msg)
        if self.shell:
            self.shell.events.trigger('post_execute')
Esempio n. 19
0
 class A(HasTraitlets):
     # `inst` is declared as an Instance trait of class Foo.
     inst = Instance(Foo)
Esempio n. 20
0
class Application(SingletonConfigurable):
    """A singleton application with full configuration support."""

    # The name of the application, will usually match the name of the command
    # line application
    name = Unicode(u'application')

    # The description of the application that is printed at the beginning
    # of the help.
    description = Unicode(u'This is an application.')
    # default section descriptions
    option_description = Unicode(option_description)
    keyvalue_description = Unicode(keyvalue_description)
    subcommand_description = Unicode(subcommand_description)

    # The usage and example string that goes at the end of the help string.
    examples = Unicode()

    # A sequence of Configurable subclasses whose config=True attributes will
    # be exposed at the command line.
    classes = []

    @property
    def _help_classes(self):
        """Define `App.help_classes` if CLI classes should differ from config file classes"""
        return getattr(self, 'help_classes', self.classes)

    @property
    def _config_classes(self):
        """Define `App.config_classes` if config file classes should differ from CLI classes."""
        return getattr(self, 'config_classes', self.classes)

    # The version string of this application.
    version = Unicode(u'0.0')

    # the argv used to initialize the application
    argv = List()

    # The log level for the application
    log_level = Enum(
        (0, 10, 20, 30, 40, 50, 'DEBUG', 'INFO', 'WARN', 'ERROR', 'CRITICAL'),
        default_value=logging.WARN,
        config=True,
        help="Set the log level by value or name.")

    def _log_level_changed(self, name, old, new):
        """Apply a new ``log_level`` to the logger.

        A level given by name is looked up on the ``logging`` module and
        written back to ``log_level`` before being applied.
        """
        level = new
        if isinstance(level, string_types):
            level = getattr(logging, level)
            self.log_level = level
        self.log.setLevel(level)

    _log_formatter_cls = LevelFormatter

    log_datefmt = Unicode(
        "%Y-%m-%d %H:%M:%S",
        config=True,
        help="The date format used by logging formatters for %(asctime)s")

    def _log_datefmt_changed(self, name, old, new):
        self._log_format_changed('log_format', self.log_format,
                                 self.log_format)

    log_format = Unicode(
        "[%(name)s]%(highlevel)s %(message)s",
        config=True,
        help="The Logging format template",
    )

    def _log_format_changed(self, name, old, new):
        """Change the log formatter when log_format is set."""
        _log_handler = self.log.handlers[0]
        _log_formatter = self._log_formatter_cls(fmt=new,
                                                 datefmt=self.log_datefmt)
        _log_handler.setFormatter(_log_formatter)

    log = Instance(logging.Logger)

    def _log_default(self):
        """Start logging for this application.

        The default is to log to stderr using a StreamHandler, if no default
        handler already exists.  The log level starts at logging.WARN, but this
        can be adjusted by setting the ``log_level`` attribute.

        Returns
        -------
        logging.Logger
            The logger named after this application's class.
        """
        # The logger is named after the concrete application class.
        log = logging.getLogger(self.__class__.__name__)
        log.setLevel(self.log_level)
        log.propagate = False
        _log = log  # copied from Logger.hasHandlers() (new in Python 3.2)
        # Walk up the logger hierarchy looking for an existing handler.
        # Because propagate was just disabled on `log`, the walk ends after
        # inspecting `log` itself.
        while _log:
            if _log.handlers:
                # A handler is already attached: reuse the logger untouched.
                return log
            if not _log.propagate:
                break
            else:
                _log = _log.parent
        if sys.executable.endswith('pythonw.exe'):
            # this should really go to a file, but file-logging is only
            # hooked up in parallel applications
            # Under pythonw.exe the handler writes to os.devnull instead.
            _log_handler = logging.StreamHandler(open(os.devnull, 'w'))
        else:
            _log_handler = logging.StreamHandler()
        # Format records with the configured log_format / log_datefmt.
        _log_formatter = self._log_formatter_cls(fmt=self.log_format,
                                                 datefmt=self.log_datefmt)
        _log_handler.setFormatter(_log_formatter)
        log.addHandler(_log_handler)
        return log

    # the alias map for configurables
    aliases = Dict({'log-level': 'Application.log_level'})

    # flags for loading Configurables or store_const style flags
    # flags are loaded from this dict by '--key' flags
    # this must be a dict of two-tuples, the first element being the Config/dict
    # and the second being the help string for the flag
    flags = Dict()

    def _flags_changed(self, name, old, new):
        """Check that every entry of the new flags dict is well formed.

        Each value must have length 2, with a ``dict`` or ``Config`` first and
        a string second; a violation trips an assertion.
        """
        for flag, entry in iteritems(new):
            assert len(entry) == 2, "Bad flag: %r:%s" % (flag, entry)
            assert isinstance(
                entry[0], (dict, Config)), "Bad flag: %r:%s" % (flag, entry)
            assert isinstance(
                entry[1], string_types), "Bad flag: %r:%s" % (flag, entry)

    # subcommands for launching other applications
    # if this is not empty, this will be a parent Application
    # this must be a dict of two-tuples,
    # the first element being the application class/import string
    # and the second being the help string for the subcommand
    subcommands = Dict()
    # parse_command_line will initialize a subapp, if requested
    subapp = Instance('IPython.config.application.Application',
                      allow_none=True)

    # extra command-line arguments that don't set config values
    extra_args = List(Unicode)

    def __init__(self, **kwargs):
        """Initialize the application and make sure its own class is listed.

        After the parent constructor runs, the application's class is placed
        at the front of ``classes`` if it is missing, so that its attributes
        appear in command line options and config files.
        """
        SingletonConfigurable.__init__(self, **kwargs)
        cls = self.__class__
        if cls not in self.classes:
            if self.classes is cls.classes:
                # `classes` is still the class-level list, which may be
                # inherited and shared with other Application subclasses.
                # Rebind on the instance instead of mutating it in place, so
                # one subclass cannot leak into another's class list.
                self.classes = [cls] + self.classes
            else:
                self.classes.insert(0, cls)

    def _config_changed(self, name, old, new):
        """Run the parent config-change handler, then log the new config at
        debug level."""
        SingletonConfigurable._config_changed(self, name, old, new)
        self.log.debug('Config changed:')
        self.log.debug(repr(new))

    @catch_config_error
    def initialize(self, argv=None):
        """Do the basic steps to configure me.

        Override in subclasses.

        Parameters
        ----------
        argv : list, optional
            Passed unchanged to ``parse_command_line``.
        """
        self.parse_command_line(argv)

    def start(self):
        """Start the app mainloop.

        When a subapp is set, its ``start()`` is called and its result
        returned; otherwise nothing happens and ``None`` is returned.

        Override in subclasses.
        """
        if self.subapp is None:
            return None
        return self.subapp.start()

    def print_alias_help(self):
        """Print the alias part of the help.

        Prints nothing when no aliases are defined.  For each alias, the help
        of the trait it points to is printed with the long name replaced by
        the alias on the first line.
        """
        if not self.aliases:
            return

        # include all parents (up to, but excluding Configurable) in available names
        by_name = {}
        for klass in self._help_classes:
            for ancestor in klass.mro()[:-3]:
                by_name[ancestor.__name__] = ancestor

        output = []
        for alias, longname in iteritems(self.aliases):
            classname, traitname = longname.split('.', 1)
            owner = by_name[classname]

            trait = owner.class_traits(config=True)[traitname]
            help_lines = owner.class_get_trait_help(trait).splitlines()
            # reformat first line
            first = help_lines[0].replace(longname, alias) + ' (%s)' % longname
            if len(alias) == 1:
                # single-letter aliases are shown in short-option form
                first = first.replace('--%s=' % alias, '-%s ' % alias)
            help_lines[0] = first
            output.extend(help_lines)
        print(os.linesep.join(output))

    def print_flag_help(self):
        """Print the flag part of the help.

        Prints nothing when no flags are defined.  Single-character flags get
        a ``-`` prefix, longer ones ``--``.
        """
        if not self.flags:
            return

        output = []
        for flag, (_cfg, help_text) in iteritems(self.flags):
            dashes = '-' if len(flag) <= 1 else '--'
            output.append(dashes + flag)
            output.append(indent(dedent(help_text.strip())))
        print(os.linesep.join(output))

    def print_options(self):
        """Print the "Options" section: heading, description, flags, aliases.

        Prints nothing when there are neither flags nor aliases.
        """
        if not (self.flags or self.aliases):
            return
        title = 'Options'
        output = [title, '-' * len(title), '']
        for paragraph in wrap_paragraphs(self.option_description):
            output.extend((paragraph, ''))
        print(os.linesep.join(output))
        self.print_flag_help()
        self.print_alias_help()
        print()

    def print_subcommands(self):
        """Print the subcommand part of the help.

        Prints nothing when no subcommands are defined.
        """
        if not self.subcommands:
            return

        title = "Subcommands"
        output = [title, '-' * len(title), '']
        description = self.subcommand_description.format(app=self.name)
        for paragraph in wrap_paragraphs(description):
            output.extend((paragraph, ''))
        for subc, (_cls, help_text) in iteritems(self.subcommands):
            output.append(subc)
            if help_text:
                output.append(indent(dedent(help_text.strip())))
        output.append('')
        print(os.linesep.join(output))

    def print_help(self, classes=False):
        """Print the application help.

        The description, subcommands and options are always printed, and the
        examples come last.  If classes=False (the default), a hint about
        `--help-all` is shown; otherwise the help of each class in
        ``_help_classes`` is printed as well.
        """
        self.print_description()
        self.print_subcommands()
        self.print_options()

        if not classes:
            print("To see all available configurables, use `--help-all`")
            print()
        else:
            help_classes = self._help_classes
            if help_classes:
                print("Class parameters")
                print("----------------")
                print()
                for paragraph in wrap_paragraphs(self.keyvalue_description):
                    print(paragraph)
                    print()

            for klass in help_classes:
                klass.class_print_help()
                print()

        self.print_examples()

    def print_description(self):
        """Print the application description, one wrapped paragraph at a
        time, each followed by a blank line."""
        paragraphs = wrap_paragraphs(self.description)
        for paragraph in paragraphs:
            print(paragraph)
            print()

    def print_examples(self):
        """Print usage and examples.

        This usage string goes at the end of the command line help string
        and should contain examples of the application's usage.  Nothing is
        printed when ``examples`` is empty.
        """
        if not self.examples:
            return
        print("Examples")
        print("--------")
        print()
        body = indent(dedent(self.examples.strip()))
        print(body)
        print()

    def print_version(self):
        """Print ``self.version`` on its own line."""
        version = self.version
        print(version)

    def update_config(self, config):
        """Fire the traits events when the config is updated.

        The current config is deep-copied, ``config`` is merged into the copy,
        and the copy is assigned back to ``self.config``; the assignment is
        what triggers the traits events.
        """
        merged = deepcopy(self.config)
        merged.merge(config)
        self.config = merged

    @catch_config_error
    def initialize_subcommand(self, subc, argv=None):
        """Initialize a subcommand with argv.

        The entry registered under ``subc`` is looked up in ``subcommands``;
        an import string is resolved to the application class.  The current
        singleton instance is cleared, the subapp is instantiated with this
        application's config, stored in ``self.subapp`` and initialized with
        ``argv``.
        """
        subapp, _help_text = self.subcommands.get(subc)

        if isinstance(subapp, string_types):
            subapp = import_item(subapp)

        # clear existing instances
        self.__class__.clear_instance()
        # instantiate, then initialize the subapp
        self.subapp = subapp.instance(config=self.config)
        self.subapp.initialize(argv)

    def flatten_flags(self):
        """Flatten flags and aliases, so cl-args override as expected.

        This prevents issues such as an alias pointing to InteractiveShell,
        but a config file setting the same trait in TerminalInteractiveShell
        getting inappropriate priority over the command-line arg.

        Only aliases and flag sections whose class has exactly one descendant
        in the class list are promoted to that descendant.

        Returns
        -------
        tuple
            ``(flags, aliases)`` with the promoted class names substituted.
        """
        # Map each parent class name to the names of the classes in our list
        # that descend from it.  The slice drops the class itself and the
        # last three entries of the MRO (Configurable, HasTraits, object).
        descendants = defaultdict(list)
        for klass in self._help_classes:
            for ancestor in klass.mro()[1:-3]:
                descendants[ancestor.__name__].append(klass.__name__)

        def promote(clsname):
            """Return the single descendant of clsname, or clsname itself."""
            children = descendants[clsname]
            return children[0] if len(children) == 1 else clsname

        # aliases have the form: { 'alias' : 'Class.trait' }
        aliases = {}
        for alias, target in iteritems(self.aliases):
            owner, trait = target.split('.', 1)
            aliases[alias] = '.'.join([promote(owner), trait])

        # flags have the form: { 'key' : ({'Cls' : {'trait' : value}}, 'help')}
        flags = {}
        for key, (flagdict, help_text) in iteritems(self.flags):
            promoted = {}
            for owner, subdict in iteritems(flagdict):
                promoted[promote(owner)] = subdict
            flags[key] = (promoted, help_text)
        return flags, aliases

    @catch_config_error
    def parse_command_line(self, argv=None):
        """Parse the command line arguments.

        Dispatches to a subcommand when the first argument names one,
        handles the help and version switches, and otherwise loads the
        arguments into the application's config.  Arguments the loader
        does not consume end up in ``self.extra_args``.
        """
        if argv is None:
            argv = sys.argv[1:]
        self.argv = [py3compat.cast_unicode(arg) for arg in argv]

        # turn `ipython help notebook` into `ipython notebook -h`
        if argv and argv[0] == 'help':
            argv = argv[1:] + ['-h']

        if self.subcommands and len(argv) > 0:
            # we have subcommands, and one may have been specified
            candidate, remainder = argv[0], argv[1:]
            looks_like_name = re.match(r'^\w(\-?\w)*$', candidate)
            if looks_like_name and candidate in self.subcommands:
                # it's a subcommand, and *not* a flag or class parameter
                return self.initialize_subcommand(candidate, remainder)

        # Arguments after a '--' argument are for the script IPython may be
        # about to run, not IPython itself.  Help and version are only
        # searched for in the arguments before the first '--'.
        if '--' in argv:
            interpreted_argv = argv[:argv.index('--')]
        else:
            interpreted_argv = argv

        help_switches = ('-h', '--help-all', '--help')
        if any(switch in interpreted_argv for switch in help_switches):
            self.print_help('--help-all' in interpreted_argv)
            self.exit(0)

        if '--version' in interpreted_argv or '-V' in interpreted_argv:
            self.print_version()
            self.exit(0)

        # flatten flags&aliases, so cl-args get appropriate priority:
        flags, aliases = self.flatten_flags()
        loader = KVArgParseConfigLoader(
            argv=argv, aliases=aliases, flags=flags, log=self.log)
        self.update_config(loader.load_config())
        # store unparsed args in extra_args
        self.extra_args = loader.extra_args

    @classmethod
    def _load_config_files(cls, basefilename, path=None, log=None):
        """Load config files (py, json) by filename and path.

        Parameters
        ----------
        basefilename : str
            File name without extension; ``.py`` and ``.json`` are appended.
        path : str or list of str, optional
            Directory (or directories, in descending priority order) to
            search.  A non-list value is treated as a single entry.
        log : logger, optional
            When given, receives debug/error messages about each file.

        Yields
        ------
        config
            Each non-empty config object that loaded successfully, lowest
            priority directory first and, within a directory, the ``.py``
            file before the ``.json`` file.
        """

        if not isinstance(path, list):
            path = [path]
        # path list is in descending priority order, so load files backwards:
        for directory in path[::-1]:
            pyloader = PyFileConfigLoader(basefilename + '.py',
                                          path=directory,
                                          log=log)
            jsonloader = JSONFileConfigLoader(basefilename + '.json',
                                              path=directory,
                                              log=log)
            for loader in [pyloader, jsonloader]:
                # Reset for every loader: otherwise a missing .json file
                # would re-yield the config just loaded from the .py file.
                config = None
                try:
                    config = loader.load_config()
                except ConfigFileNotFound:
                    pass
                except Exception:
                    # try to get the full filename, but it will be empty in the
                    # unlikely event that the error raised before filefind finished
                    filename = loader.full_filename or basefilename
                    # problem while running the file
                    if log:
                        log.error("Exception while loading config file %s",
                                  filename,
                                  exc_info=True)
                else:
                    if log:
                        log.debug("Loaded config file: %s",
                                  loader.full_filename)
                if config:
                    yield config
        # No explicit ``raise StopIteration``: falling off the end finishes
        # the generator, and raising it here is a RuntimeError under PEP 479.

    @catch_config_error
    def load_config_file(self, filename, path=None):
        """Load config files by filename and path.

        Parameters
        ----------
        filename : str
            Config file name; its extension is dropped, and
            ``_load_config_files`` looks for the ``.py`` and ``.json``
            variants of the remaining base name.
        path : str or list of str, optional
            Search location(s), passed through to ``_load_config_files``.

        Every config yielded is applied with ``update_config`` in order.
        When more than one was loaded, the first two are checked for
        collisions and a warning is logged if any are found.
        """
        filename, ext = os.path.splitext(filename)
        loaded = []
        for config in self._load_config_files(filename,
                                              path=path,
                                              log=self.log):
            loaded.append(config)
            self.update_config(config)
        if len(loaded) > 1:
            collisions = loaded[0].collisions(loaded[1])
            if collisions:
                # Logger.warn is a deprecated alias of Logger.warning
                self.log.warning(
                    "Collisions detected in {0}.py and {0}.json config files."
                    " {0}.json has higher priority: {1}".format(
                        filename,
                        json.dumps(collisions, indent=2),
                    ))

    def generate_config_file(self):
        """Return the text of a default config file built from Configurables.

        The text starts with a header comment naming the application and a
        ``c = get_config()`` line, followed by the ``class_config_section()``
        of every class in ``self._config_classes``, all joined by newlines.
        """
        preamble = [
            "# Configuration file for %s." % self.name,
            '',
            'c = get_config()',
            '',
        ]
        sections = [cls.class_config_section() for cls in self._config_classes]
        return '\n'.join(preamble + sections)

    def exit(self, exit_status=0):
        """Log a debug message naming the application, then exit.

        Parameters
        ----------
        exit_status : int, optional
            Status passed to ``sys.exit`` (default 0).
        """
        farewell = "Exiting application: %s" % self.name
        self.log.debug(farewell)
        sys.exit(exit_status)

    @classmethod
    def launch_instance(cls, argv=None, **kwargs):
        """Launch a global instance of this Application.

        Obtains the instance through ``cls.instance(**kwargs)``, initializes
        it with `argv` and starts it.  If a global instance already exists,
        this reinitializes and starts it.
        """
        application = cls.instance(**kwargs)
        application.initialize(argv)
        application.start()
Esempio n. 21
0
class EngineFactory(RegistrationFactory):
    """IPython engine"""

    # --- configurable traits ---
    out_stream_factory = Type(
        'IPython.zmq.iostream.OutStream',
        config=True,
        help="""The OutStream for handling stdout/err.
        Typically 'IPython.zmq.iostream.OutStream'""")
    display_hook_factory = Type(
        'IPython.zmq.displayhook.ZMQDisplayHook',
        config=True,
        help="""The class for handling displayhook.
        Typically 'IPython.zmq.displayhook.ZMQDisplayHook'""")
    location = Unicode(
        config=True,
        help="""The location (an IP address) of the controller.  This is
        used for disambiguating URLs, to determine whether
        loopback should be used to connect or the public address.""")
    timeout = CFloat(
        2,
        config=True,
        help="""The time (in seconds) to wait for the Controller to respond
        to registration requests before giving up.""")

    # --- internal state (not configurable) ---
    user_ns = Dict()
    id = Int(allow_none=True)
    registrar = Instance('zmq.eventloop.zmqstream.ZMQStream')
    kernel = Instance(Kernel)

    # ``bident`` is kept as the bytes form of ``ident`` (see _ident_changed)
    bident = CBytes()
    ident = Unicode()

    def _ident_changed(self, name, old, new):
        """Mirror a new ``ident`` into ``bident`` as bytes."""
        self.bident = asbytes(new)

    def __init__(self, **kwargs):
        """Set the identity from the session and connect the registrar."""
        super(EngineFactory, self).__init__(**kwargs)
        self.ident = self.session.session

        # registration socket, identified by the engine's byte identity
        sock = self.context.socket(zmq.XREQ)
        sock.setsockopt(zmq.IDENTITY, self.bident)
        sock.connect(self.url)
        self.registrar = zmqstream.ZMQStream(sock, self.loop)

    def register(self):
        """Send the registration_request to the controller.

        The reply is routed to ``complete_registration``.
        """
        self.log.info("Registering with controller at %s" % self.url)
        # the same identity is announced for queue, heartbeat and control
        request = dict(queue=self.ident,
                       heartbeat=self.ident,
                       control=self.ident)
        self.registrar.on_recv(self.complete_registration)
        self.session.send(self.registrar,
                          "registration_request",
                          content=request)

    def complete_registration(self, msg):
        """Handle the controller's reply to the registration request.

        On an 'ok' status this connects the shell, control and iopub
        streams, optionally redirects stdout/stderr/displayhook, starts the
        Kernel and the heartbeat.  Any other status is logged as fatal and
        raised as an Exception.
        """
        # a reply arrived, so stop the pending registration-timeout callback
        self._abort_dc.stop()
        context = self.context
        loop = self.loop
        identity = self.bident
        _idents, frames = self.session.feed_identities(msg)
        msg = Message(self.session.unpack_message(frames))
        reply = msg.content

        if reply.status != 'ok':
            self.log.fatal("Registration Failed: %s" % msg)
            raise Exception("Registration Failed: %s" % msg)

        self.id = int(reply.id)

        def resolve(addr):
            # pick the concrete URL for `addr` given the controller location
            return disambiguate_url(addr, self.location)

        def identified_stream(socket_type):
            # new stream on this engine's loop carrying the engine identity
            new_stream = zmqstream.ZMQStream(context.socket(socket_type), loop)
            new_stream.setsockopt(zmq.IDENTITY, identity)
            return new_stream

        # Shell addresses (MUX, and Task when the controller provides one).
        queue_addr = reply.mux
        shell_addrs = [str(queue_addr)]
        task_addr = reply.task
        if task_addr:
            shell_addrs.append(str(task_addr))

        # A single shell stream serves both mux and tasks: one socket is
        # connected to every shell address.
        shell_stream = identified_stream(zmq.XREP)
        shell_streams = [shell_stream]
        for addr in shell_addrs:
            shell_stream.connect(resolve(addr))

        # control stream:
        control_addr = str(reply.control)
        control_stream = identified_stream(zmq.XREP)
        control_stream.connect(resolve(control_addr))

        # iopub stream:
        iopub_addr = reply.iopub
        iopub_stream = identified_stream(zmq.PUB)
        iopub_stream.connect(resolve(iopub_addr))

        # heartbeat addresses; the heart itself starts after the kernel
        hb_addrs = reply.heartbeat

        # Redirect output streams and set a display hook.
        if self.out_stream_factory:
            sys.stdout = self.out_stream_factory(self.session,
                                                 iopub_stream, u'stdout')
            sys.stdout.topic = 'engine.%i.stdout' % self.id
            sys.stderr = self.out_stream_factory(self.session,
                                                 iopub_stream, u'stderr')
            sys.stderr.topic = 'engine.%i.stderr' % self.id
        if self.display_hook_factory:
            sys.displayhook = self.display_hook_factory(
                self.session, iopub_stream)
            sys.displayhook.topic = 'engine.%i.pyout' % self.id

        self.kernel = Kernel(config=self.config,
                             int_id=self.id,
                             ident=self.ident,
                             session=self.session,
                             control_stream=control_stream,
                             shell_streams=shell_streams,
                             iopub_stream=iopub_stream,
                             loop=loop,
                             user_ns=self.user_ns,
                             log=self.log)
        self.kernel.start()

        heart_urls = [str(resolve(addr)) for addr in hb_addrs]
        heart = Heart(*heart_urls, heart_id=identity)
        heart.start()

        self.log.info("Completed registration with id %i" % self.id)

    def abort(self):
        """Give up on registration: log, send an unregistration, exit 255."""
        self.log.fatal("Registration timed out after %.1f seconds" %
                       self.timeout)
        self.session.send(self.registrar,
                          "unregistration_request",
                          content={'id': self.id})
        # NOTE(review): presumably the pause lets the message leave before
        # the process exits -- confirm.
        time.sleep(1)
        sys.exit(255)

    def start(self):
        """Schedule registration now and the abort after ``timeout`` seconds."""
        register_dc = ioloop.DelayedCallback(self.register, 0, self.loop)
        register_dc.start()
        # DelayedCallback receives the delay as timeout * 1000
        timeout_ms = self.timeout * 1000
        self._abort_dc = ioloop.DelayedCallback(self.abort, timeout_ms,
                                                self.loop)
        self._abort_dc.start()
Esempio n. 22
0
class NotebookManager(LoggingConfigurable):

    # Base notebook manager: notebook signing/trust helpers, model creation
    # and copying are implemented here; storage operations raise
    # NotImplementedError and are left to subclasses.

    # Todo:
    # The notebook_dir attribute is used to mean a couple of different things:
    # 1. Where the notebooks are stored if FileNotebookManager is used.
    # 2. The cwd of the kernel for a project.
    # Right now we use this attribute in a number of different places and
    # we are going to have to disentangle all of this.
    notebook_dir = Unicode(py3compat.getcwd(), config=True, help="""
            The directory to use for notebooks.
            """)

    filename_ext = Unicode(u'.ipynb')

    notary = Instance(sign.NotebookNotary)
    def _notary_default(self):
        """Build the default notary, parented to this manager."""
        return sign.NotebookNotary(parent=self)

    def check_and_sign(self, nb, path, name):
        """Check for trusted cells, and sign the notebook.

        Called as a part of saving notebooks.  The notebook is signed only
        when the notary's cell check passes; otherwise a warning is logged
        and the notebook is left unsigned.
        """
        if self.notary.check_cells(nb):
            self.notary.sign(nb)
        else:
            self.log.warning("Saving untrusted notebook %s/%s", path, name)

    def mark_trusted_cells(self, nb, path, name):
        """Mark cells as trusted if the notebook signature matches.

        Called as a part of loading notebooks.  A warning is logged when the
        signature check fails; the result is passed to the notary's
        ``mark_cells`` either way.
        """
        trusted = self.notary.check_signature(nb)
        if not trusted:
            self.log.warning("Notebook %s/%s is not trusted", path, name)
        self.notary.mark_cells(nb, trusted)

    def path_exists(self, path):
        """Does the API-style path (directory) actually exist?

        Override this method in subclasses.

        Parameters
        ----------
        path : string
            The API-style path (directory) to check.

        Returns
        -------
        exists : bool
            Whether the path does indeed exist.
        """
        raise NotImplementedError

    def _notebook_dir_changed(self, name, old, new):
        """Do a bit of validation of the notebook dir.

        A relative path is replaced by its absolute form.  An absolute path
        must be a directory if it exists, and is created if it does not.

        Raises
        ------
        TraitError
            If the path exists but is not a directory, or cannot be created.
        """
        if not os.path.isabs(new):
            # If we receive a non-absolute path, make it absolute.
            # NOTE(review): presumably this assignment re-triggers this
            # handler with the absolute path, which then validates it.
            self.notebook_dir = os.path.abspath(new)
            return
        if os.path.exists(new) and not os.path.isdir(new):
            raise TraitError("notebook dir %r is not a directory" % new)
        if not os.path.exists(new):
            self.log.info("Creating notebook dir %s", new)
            try:
                os.mkdir(new)
            except Exception:
                # Not a bare ``except``: KeyboardInterrupt and SystemExit
                # must propagate instead of being reported as a TraitError.
                raise TraitError("Couldn't create notebook dir %r" % new)

    # Main notebook API

    def increment_filename(self, basename, path=''):
        """Increment a notebook filename without the .ipynb to make it unique.

        This base implementation returns `basename` unchanged.

        Parameters
        ----------
        basename : unicode
            The name of a notebook without the ``.ipynb`` file extension.
        path : unicode
            The URL path of the notebooks directory
        """
        return basename

    def list_notebooks(self, path=''):
        """Return a list of notebook dicts without content.

        This returns a list of dicts, each of the form::

            dict(notebook_id=notebook,name=name)

        This list of dicts should be sorted by name::

            data = sorted(data, key=lambda item: item['name'])
        """
        raise NotImplementedError('must be implemented in a subclass')

    def get_notebook_model(self, name, path='', content=True):
        """Get the notebook model with or without content."""
        raise NotImplementedError('must be implemented in a subclass')

    def save_notebook_model(self, model, name, path=''):
        """Save the notebook model and return the model with no content."""
        raise NotImplementedError('must be implemented in a subclass')

    def update_notebook_model(self, model, name, path=''):
        """Update the notebook model and return the model with no content."""
        raise NotImplementedError('must be implemented in a subclass')

    def delete_notebook_model(self, name, path=''):
        """Delete notebook by name and path."""
        raise NotImplementedError('must be implemented in a subclass')

    def create_notebook_model(self, model=None, path=''):
        """Create a new notebook and return its model with no content.

        Missing pieces of `model` are filled in: an empty notebook as
        'content' and a name derived from 'Untitled'.  The model is then
        saved through ``save_notebook_model``.
        """
        path = path.strip('/')
        if model is None:
            model = {}
        if 'content' not in model:
            metadata = current.new_metadata(name=u'')
            model['content'] = current.new_notebook(metadata=metadata)
        if 'name' not in model:
            model['name'] = self.increment_filename('Untitled', path)

        model['path'] = path
        model = self.save_notebook_model(model, model['name'], model['path'])
        return model

    def copy_notebook(self, from_name, to_name=None, path=''):
        """Copy an existing notebook and return its new model.

        If to_name not specified, increment `from_name-Copy#.ipynb`.
        """
        path = path.strip('/')
        model = self.get_notebook_model(from_name, path)
        if not to_name:
            # drop the extension, then let increment_filename pick the name
            base = os.path.splitext(from_name)[0] + '-Copy'
            to_name = self.increment_filename(base, path)
        model['name'] = to_name
        model = self.save_notebook_model(model, to_name, path)
        return model

    # Checkpoint-related

    def create_checkpoint(self, name, path=''):
        """Create a checkpoint of the current state of a notebook

        Returns a checkpoint_id for the new checkpoint.
        """
        raise NotImplementedError("must be implemented in a subclass")

    def list_checkpoints(self, name, path=''):
        """Return a list of checkpoints for a given notebook.

        This base implementation returns an empty list.
        """
        return []

    def restore_checkpoint(self, checkpoint_id, name, path=''):
        """Restore a notebook from one of its checkpoints"""
        raise NotImplementedError("must be implemented in a subclass")

    def delete_checkpoint(self, checkpoint_id, name, path=''):
        """delete a checkpoint for a notebook"""
        raise NotImplementedError("must be implemented in a subclass")

    def log_info(self):
        """Log ``info_string()`` at info level."""
        self.log.info(self.info_string())

    def info_string(self):
        """Return the status line logged by ``log_info``."""
        return "Serving notebooks"
Esempio n. 23
0
 class A(HasTraits):
     # Instance trait of type Foo, declared with empty positional and
     # keyword arguments (presumably used to build the default Foo).
     inst = Instance(Foo, args=(), kw={})
Esempio n. 24
0
class ZMQTerminalInteractiveShell(TerminalInteractiveShell):
    """A subclass of TerminalInteractiveShell that uses the 0MQ kernel"""
    _executing = False
    _execution_state = Unicode('')
    kernel_timeout = Float(
        60,
        config=True,
        help="""Timeout for giving up on a kernel (in seconds).
        
        On first connect and restart, the console tests whether the
        kernel is running and responsive by sending kernel_info_requests.
        This sets the timeout in seconds for how long the kernel can take
        before being presumed dead.
        """)

    image_handler = Enum(('PIL', 'stream', 'tempfile', 'callable'),
                         config=True,
                         help="""
        Handler for image type output.  This is useful, for example,
        when connecting to the kernel in which pylab inline backend is
        activated.  There are four handlers defined.  'PIL': Use
        Python Imaging Library to popup image; 'stream': Use an
        external program to show the image.  Image will be fed into
        the STDIN of the program.  You will need to configure
        `stream_image_handler`; 'tempfile': Use an external program to
        show the image.  Image will be saved in a temporally file and
        the program is called with the temporally file.  You will need
        to configure `tempfile_image_handler`; 'callable': You can set
        any Python callable which is called with the image data.  You
        will need to configure `callable_image_handler`.
        """)

    stream_image_handler = List(config=True,
                                help="""
        Command to invoke an image viewer program when you are using
        'stream' image handler.  This option is a list of string where
        the first element is the command itself and reminders are the
        options for the command.  Raw image data is given as STDIN to
        the program.
        """)

    tempfile_image_handler = List(config=True,
                                  help="""
        Command to invoke an image viewer program when you are using
        'tempfile' image handler.  This option is a list of string
        where the first element is the command itself and reminders
        are the options for the command.  You can use {file} and
        {format} in the string to represent the location of the
        generated image file and image format.
        """)

    callable_image_handler = Any(config=True,
                                 help="""
        Callable object called via 'callable' image handler with one
        argument, `data`, which is `msg["content"]["data"]` where
        `msg` is the message from iopub channel.  For exmaple, you can
        find base64 encoded PNG data as `data['image/png']`.
        """)

    mime_preference = List(
        default_value=['image/png', 'image/jpeg', 'image/svg+xml'],
        config=True,
        allow_none=False,
        help="""
        Preferred object representation MIME type in order.  First
        matched MIME type will be used.
        """)

    manager = Instance('IPython.kernel.KernelManager')
    client = Instance('IPython.kernel.KernelClient')

    def _client_changed(self, name, old, new):
        """Record the new client's session id on ``self.session_id``."""
        client_session = new.session
        self.session_id = client_session.session

    session_id = Unicode()

    def init_completer(self):
        """Initialize the completion machinery.

        Creates a ZMQCompleter bound to this shell's client and registers
        the ``complete_command`` hooks for ``import``, ``from``, ``%run``
        and ``%cd``.  The readline completer is only set when readline is
        actually in use.
        """
        from IPython.core.completerlib import (cd_completer,
                                               magic_run_completer,
                                               module_completer)

        self.Completer = ZMQCompleter(self, self.client, config=self.config)

        # (trigger string, completer) pairs, registered in this order
        command_hooks = (
            ('import', module_completer),
            ('from', module_completer),
            ('%run', magic_run_completer),
            ('%cd', cd_completer),
        )
        for trigger, completer in command_hooks:
            self.set_hook('complete_command', completer, str_key=trigger)

        # Only configure readline if we truly are using readline.  IPython can
        # do tab-completion over the network, in GUIs, etc, where readline
        # itself may be absent
        if self.has_readline:
            self.set_readline_completer()

    def run_cell(self, cell, store_history=True):
        """Run a complete IPython cell on the kernel and wait for it.

        Parameters
        ----------
        cell : str
          The code (including IPython code such as %magic functions) to run.
          Empty or whitespace-only cells are ignored; a cell that is just
          ``exit`` is handled locally through ``ask_exit``.
        store_history : bool
          If True, the raw and translated cell will be stored in IPython's
          history. For user code calling back into IPython's machinery, this
          should be set to False.
        """
        if not cell or cell.isspace():
            return

        # explicitly handle 'exit' command
        if cell.strip() == 'exit':
            return self.ask_exit()

        shell = self.client.shell_channel
        # flush stale replies, which could have been ignored, due to missed heartbeats
        while shell.msg_ready():
            shell.get_msg()
        # shell_channel.execute takes 'hidden', which is the inverse of store_hist
        msg_id = shell.execute(cell, not store_history)

        # Phase 1: service side effects (output, stdin, etc.) until the
        # kernel reports idle or goes away.
        self._executing = True
        self._execution_state = "busy"
        while self._execution_state != 'idle' and self.client.is_alive():
            try:
                self.handle_stdin_request(msg_id, timeout=0.05)
            except Empty:
                # no stdin request: display intermediate print statements, etc.
                self.handle_iopub(msg_id)

        # Phase 2: wait for the execute reply itself.
        while self.client.is_alive():
            try:
                self.handle_execute_reply(msg_id, timeout=0.05)
            except Empty:
                continue
            break
        self._executing = False

    #-----------------
    # message handlers
    #-----------------

    def handle_execute_reply(self, msg_id, timeout=None):
        """Fetch one shell message and process it if it answers `msg_id`.

        For a matching reply, pending iopub output is handled first.  Then,
        depending on the status: 'aborted' writes a notice and returns,
        'ok' pages any payload text, and 'error' prints the traceback
        frames.  Except for 'aborted', ``execution_count`` is set to the
        reply's count plus one.  Messages for other requests are dropped.
        """
        msg = self.client.shell_channel.get_msg(block=False, timeout=timeout)
        if msg["parent_header"].get("msg_id", None) != msg_id:
            return

        self.handle_iopub(msg_id)

        content = msg["content"]
        status = content['status']

        if status == 'aborted':
            self.write('Aborted\n')
            return

        if status == 'ok':
            # print execution payloads as well:
            for item in content["payload"]:
                text = item.get('text', None)
                if text:
                    page.page(text)
        elif status == 'error':
            for frame in content["traceback"]:
                print(frame, file=io.stderr)

        self.execution_count = int(content["execution_count"] + 1)

    def handle_iopub(self, msg_id):
        """Process the messages waiting on the IOPub (subscribe) channel.

        This method consumes and processes messages on the IOPub channel,
        such as stdout, stderr, pyout and status.

        It only displays output that is caused by the given msg_id, plus
        messages that carry no parent header at all.  Consumption stops
        early once a matching 'idle' status message is seen.
        """
        while self.client.iopub_channel.msg_ready():
            sub_msg = self.client.iopub_channel.get_msg()
            msg_type = sub_msg['header']['msg_type']
            parent = sub_msg["parent_header"]
            # Use .get(), like the shell and stdin handlers: a non-empty
            # parent header without a msg_id is not ours and must not raise.
            if parent and parent.get('msg_id') != msg_id:
                continue

            if msg_type == 'status':
                state = self._execution_state = sub_msg["content"][
                    "execution_state"]
                # idle messages mean an individual sequence is complete,
                # so break out of consumption to allow other things to take over.
                if state == 'idle':
                    break

            elif msg_type == 'stream':
                if sub_msg["content"]["name"] == "stdout":
                    print(sub_msg["content"]["data"],
                          file=io.stdout,
                          end="")
                    io.stdout.flush()
                elif sub_msg["content"]["name"] == "stderr":
                    print(sub_msg["content"]["data"],
                          file=io.stderr,
                          end="")
                    io.stderr.flush()

            elif msg_type == 'pyout':
                self.execution_count = int(
                    sub_msg["content"]["execution_count"])
                format_dict = sub_msg["content"]["data"]
                self.handle_rich_data(format_dict)
                # taken from DisplayHook.__call__:
                hook = self.displayhook
                hook.start_displayhook()
                hook.write_output_prompt()
                hook.write_format_data(format_dict)
                hook.log_output(format_dict)
                hook.finish_displayhook()

            elif msg_type == 'display_data':
                self.handle_rich_data(sub_msg["content"]["data"])

    # Image MIME types this class can hand to an image handler, mapped to
    # the short format name the handlers use (the ``format`` value for the
    # command templates and the extension of the temporary file).
    _imagemime = {
        'image/png': 'png',
        'image/jpeg': 'jpeg',
        'image/svg+xml': 'svg',
    }

    def handle_rich_data(self, data):
        """Display the most preferred image representation found in `data`.

        Walks ``self.mime_preference`` in order and passes the first MIME
        type that is both present in `data` and listed in ``_imagemime``
        to ``handle_image``.  Does nothing when no type qualifies.
        """
        for mime in self.mime_preference:
            if mime not in data or mime not in self._imagemime:
                continue
            self.handle_image(data, mime)
            return

    def handle_image(self, data, mime):
        """Dispatch to the ``handle_image_<image_handler>`` method, if any.

        The method name is built from the ``image_handler`` setting; when
        no such attribute exists (or it is falsy) nothing happens.
        """
        method_name = 'handle_image_{0}'.format(self.image_handler)
        handler = getattr(self, method_name, None)
        if not handler:
            return
        handler(data, mime)

    def handle_image_PIL(self, data, mime):
        """Show a PNG or JPEG image with the Python Imaging Library.

        Other MIME types are ignored.  ``data[mime]`` is expected to hold
        the base64-encoded image as text.
        """
        if mime not in ('image/png', 'image/jpeg'):
            return
        # imported lazily so PIL is only required when this handler runs
        import PIL.Image
        # b64decode instead of decodestring, which was removed in Python 3.9
        raw = base64.b64decode(data[mime].encode('ascii'))
        img = PIL.Image.open(BytesIO(raw))
        img.show()

    def handle_image_stream(self, data, mime):
        """Pipe the decoded image into the configured external program.

        Each element of ``stream_image_handler`` is formatted with
        ``format=<image format>`` to build the command line.  The raw image
        bytes are written to the program's stdin; its stdout and stderr are
        discarded.  The call returns once the program has finished.
        """
        # b64decode instead of decodestring, which was removed in Python 3.9
        raw = base64.b64decode(data[mime].encode('ascii'))
        imageformat = self._imagemime[mime]
        fmt = dict(format=imageformat)
        args = [s.format(**fmt) for s in self.stream_image_handler]
        with open(os.devnull, 'w') as devnull:
            proc = subprocess.Popen(args,
                                    stdin=subprocess.PIPE,
                                    stdout=devnull,
                                    stderr=devnull)
            proc.communicate(raw)

    def handle_image_tempfile(self, data, mime):
        """Save the decoded image to a temporary file and run a viewer on it.

        The file is named ``tmp.<image format>``.  Each element of
        ``tempfile_image_handler`` is formatted with ``file=<path>`` and
        ``format=<image format>`` to build the command line; the program's
        stdout and stderr are discarded.
        """
        # b64decode instead of decodestring, which was removed in Python 3.9
        raw = base64.b64decode(data[mime].encode('ascii'))
        imageformat = self._imagemime[mime]
        filename = 'tmp.{0}'.format(imageformat)
        # multi-item ``with`` instead of contextlib.nested (Python 2 only)
        with NamedFileInTemporaryDirectory(filename) as f, \
                open(os.devnull, 'w') as devnull:
            f.write(raw)
            # make sure the bytes are on disk before the viewer opens the file
            f.flush()
            fmt = dict(file=f.name, format=imageformat)
            args = [s.format(**fmt) for s in self.tempfile_image_handler]
            subprocess.call(args, stdout=devnull, stderr=devnull)

    def handle_image_callable(self, data, mime):
        """Pass `data` to the configured ``callable_image_handler``.

        `mime` is accepted for signature parity with the other image
        handlers and is not forwarded.
        """
        callback = self.callable_image_handler
        callback(data)

    def handle_stdin_request(self, msg_id, timeout=0.1):
        """Answer a raw_input request from the kernel, if one is waiting.

        Fetches one message from the stdin channel (waiting up to `timeout`),
        processes pending iopub output, and, when the message belongs to
        `msg_id`, prompts the user locally and sends the typed text back.

        An EOF at the prompt is sent as the ``'\\x04'`` character.  A
        KeyboardInterrupt at the prompt writes a newline and sends nothing.

        NOTE(review): ``run_cell`` catches ``Empty`` around this call, so
        ``get_msg`` presumably raises it when no request arrives in time.
        """
        msg_rep = self.client.stdin_channel.get_msg(timeout=timeout)
        # in case any iopub came while we were waiting:
        self.handle_iopub(msg_id)
        if msg_id == msg_rep["parent_header"].get("msg_id"):
            # wrap SIGINT handler
            real_handler = signal.getsignal(signal.SIGINT)

            def double_int(sig, frame):
                # call real handler (forwards sigint to kernel),
                # then raise local interrupt, stopping local raw_input
                real_handler(sig, frame)
                raise KeyboardInterrupt

            signal.signal(signal.SIGINT, double_int)

            try:
                raw_data = raw_input(msg_rep["content"]["prompt"])
            except EOFError:
                # turn EOFError into EOF character
                raw_data = '\x04'
            except KeyboardInterrupt:
                sys.stdout.write('\n')
                return
            finally:
                # restore SIGINT handler (runs on every path, including the
                # early return above)
                signal.signal(signal.SIGINT, real_handler)

            # only send stdin reply if there *was not* another request
            # or execution finished while we were reading.
            if not (self.client.stdin_channel.msg_ready()
                    or self.client.shell_channel.msg_ready()):
                self.client.stdin_channel.input(raw_data)

    def mainloop(self, display_banner=False):
        """Run ``interact`` until it returns without a KeyboardInterrupt.

        A KeyboardInterrupt escaping ``interact`` is reported with
        ``self.write`` and ``interact`` is entered again.
        """
        finished = False
        while not finished:
            try:
                self.interact(display_banner=display_banner)
            except KeyboardInterrupt:
                # this should not be necessary, but KeyboardInterrupt
                # handling seems rather unpredictable...
                self.write("\nKeyboardInterrupt in interact()\n")
            else:
                finished = True

    def wait_for_kernel(self, timeout=None):
        """Block until the kernel answers a ``kernel_info`` request.

        The heartbeat channel is unpaused, then ``kernel_info`` requests are
        sent repeatedly until a shell reply whose parent is the most recent
        request shows up.

        Parameters
        ----------
        timeout : float or None
            Number of seconds after which the kernel may be declared dead.
            With ``None`` this method never gives up.

        Returns
        -------
        bool
            ``True`` as soon as the matching reply is received.  ``False``
            when ``timeout`` is set, has elapsed, and the heartbeat channel
            reports that it is not beating.  While the heartbeat is still
            beating the method keeps waiting, even past ``timeout``.
        """
        tic = time.time()
        self.client.hb_channel.unpause()
        while True:
            msg_id = self.client.kernel_info()
            # Drain the shell channel until our reply shows up or the
            # channel stays silent for one second.
            while True:
                try:
                    reply = self.client.get_shell_msg(timeout=1)
                except Empty:
                    break
                if reply['parent_header'].get('msg_id') == msg_id:
                    return True
            # Give up only when the deadline has passed AND the heartbeat
            # has stopped; otherwise send a fresh request and keep waiting.
            if (timeout is not None
                    and (time.time() - tic) > timeout
                    and not self.client.hb_channel.is_beating()):
                # heart failed
                return False

    def interact(self, display_banner=None):
        """Closely emulate the interactive Python console.

        Optionally shows the banner, waits for the kernel to respond, then
        loops: read a line with ``self.raw_input``, push it to
        ``self.input_splitter`` and, once the splitter accepts no more
        input, execute the accumulated source with ``self.run_cell``.

        The loop runs until ``self.exit_now`` becomes true; the flag is
        cleared again before returning.  Nothing is done at all if
        ``self.exit_now`` is already set on entry, or if the kernel does not
        respond.

        Parameters
        ----------
        display_banner : None, bool or string
            ``None`` falls back to ``self.display_banner``.  A string is
            passed to ``self.show_banner``; any other true value shows the
            default banner.
        """

        # batch run -> do not interact
        if self.exit_now:
            return

        if display_banner is None:
            display_banner = self.display_banner

        if isinstance(display_banner, basestring):
            # a custom banner text was supplied
            self.show_banner(display_banner)
        elif display_banner:
            self.show_banner()

        # True while the input splitter is waiting for continuation lines
        more = False

        # block until the kernel has answered, so that we don't get a
        # prompt until we know the kernel is ready. This keeps the
        # connection message above the first prompt.
        if not self.wait_for_kernel(self.kernel_timeout):
            error("Kernel did not respond\n")
            return

        if self.has_readline:
            self.readline_startup_hook(self.pre_readline)
            # readline history length before the current cell was started
            hlen_b4_cell = self.readline.get_current_history_length()
        else:
            hlen_b4_cell = 0
        # exit_now is set by a call to %Exit or %Quit, through the
        # ask_exit callback.

        while not self.exit_now:
            if not self.client.is_alive():
                # kernel died, prompt for action or exit

                # we can only restart the kernel ourselves if we have a
                # manager; otherwise we can merely wait for a restart
                action = "restart" if self.manager else "wait for restart"
                ans = self.ask_yes_no("kernel died, %s ([y]/n)?" % action,
                                      default='y')
                if ans:
                    if self.manager:
                        self.manager.restart_kernel(True)
                    # the result is ignored: the liveness check at the top
                    # of the loop runs again on the next iteration
                    self.wait_for_kernel(self.kernel_timeout)
                else:
                    self.exit_now = True
                continue
            try:
                # protect prompt block from KeyboardInterrupt
                # when sitting on ctrl-C
                self.hooks.pre_prompt_hook()
                # NOTE(review): if rendering a prompt raises, the traceback
                # is shown but `prompt` is left unbound (or stale from the
                # previous iteration) when raw_input is called below.
                if more:
                    # continuation prompt
                    try:
                        prompt = self.prompt_manager.render('in2')
                    except Exception:
                        self.showtraceback()
                    if self.autoindent:
                        self.rl_do_indent = True

                else:
                    # primary prompt, preceded by the input separator
                    try:
                        prompt = self.separate_in + self.prompt_manager.render(
                            'in')
                    except Exception:
                        self.showtraceback()

                line = self.raw_input(prompt)
                if self.exit_now:
                    # quick exit on sys.std[in|out] close
                    break
                if self.autoindent:
                    self.rl_do_indent = False

            except KeyboardInterrupt:
                #double-guard against keyboardinterrupts during kbdint handling
                try:
                    self.write('\nKeyboardInterrupt\n')
                    # discard the partially entered cell and fold what was
                    # typed so far into the readline history
                    source_raw = self.input_splitter.source_raw_reset()[1]
                    hlen_b4_cell = self._replace_rlhist_multiline(
                        source_raw, hlen_b4_cell)
                    more = False
                except KeyboardInterrupt:
                    pass
            except EOFError:
                # end of input: undo the autoindent/readline setup, then exit
                if self.autoindent:
                    self.rl_do_indent = False
                    if self.has_readline:
                        self.readline_startup_hook(None)
                self.write('\n')
                self.exit()
            except bdb.BdbQuit:
                warn(
                    'The Python debugger has exited with a BdbQuit exception.\n'
                    'Because of how pdb handles the stack, it is impossible\n'
                    'for IPython to properly format this particular exception.\n'
                    'IPython will resume normal operation.')
            except:
                # exceptions here are VERY RARE, but they can be triggered
                # asynchronously by signal handlers, for example.
                self.showtraceback()
            else:
                # a line was read successfully: feed it to the splitter
                self.input_splitter.push(line)
                more = self.input_splitter.push_accepts_more()
                if (self.SyntaxTB.last_syntax_error and self.autoedit_syntax):
                    self.edit_syntax_error()
                if not more:
                    # the cell is complete: reset the splitter, update the
                    # readline history bookkeeping and execute the source
                    source_raw = self.input_splitter.source_raw_reset()[1]
                    hlen_b4_cell = self._replace_rlhist_multiline(
                        source_raw, hlen_b4_cell)
                    self.run_cell(source_raw)

        # Turn off the exit flag, so the mainloop can be restarted if desired
        self.exit_now = False
Esempio n. 25
0
 class B(HasTraits):
     # `inst` is an Instance trait for class Bah, declared with the
     # positional arguments (10,) and the keyword arguments {'d': 20}.
     # NOTE(review): presumably these are the arguments used to build the
     # default Bah instance -- confirm against the Instance trait docs.
     inst = Instance(Bah, args=(10,), kw=dict(d=20))
Esempio n. 26
0
class PrefilterManager(Configurable):
    """Main prefilter component.

    Every line of user input goes through the IPython prefilter before it
    is run: the prefilter consumes lines of input and produces transformed
    lines of input.

    The implementation consists of two phases:

    1. Transformers
    2. Checkers and handlers

    Over time, we plan on deprecating the checkers and handlers and doing
    everything in the transformers.

    Transformers are instances of :class:`PrefilterTransformer`.  They have
    a single method, :meth:`transform`, which takes a line and returns the
    transformed line.  Any tool can be used to do the transformation; the
    current ones use regular expressions for speed.  :mod:`pyparsing` is
    also shipped in :mod:`IPython.external` for use in transformers.

    Once all transformers have run, the line is handed to the checkers,
    which are instances of :class:`PrefilterChecker`.  Each checker's
    :meth:`check` method returns either `None` or a
    :class:`PrefilterHandler` instance.  On `None` the next checker is
    tried; otherwise the line is passed to the :meth:`handle` method of the
    returned handler and no further checkers are consulted.

    Transformers and checkers both carry a `priority` attribute that fixes
    the order in which they are called (smaller priorities first), and an
    `enabled` boolean attribute that decides whether the instance is used
    at all.

    Users or developers may change the priority or enabled attribute of
    transformers or checkers, but after changing a priority they must call
    :meth:`sort_checkers` or :meth:`sort_transformers`.
    """

    multi_line_specials = CBool(True, config=True)
    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')

    def __init__(self, shell=None, config=None):
        super(PrefilterManager, self).__init__(shell=shell, config=config)
        self.shell = shell
        # handlers are created before the checkers, as in the original
        # ordering of these calls
        self.init_transformers()
        self.init_handlers()
        self.init_checkers()

    def _instantiate_defaults(self, classes):
        """Instantiate every class in `classes` for this manager.

        The instances are not kept here.
        NOTE(review): presumably each one registers itself with the manager
        it is given when constructed -- confirm in the component classes.
        """
        for component_cls in classes:
            component_cls(shell=self.shell,
                          prefilter_manager=self,
                          config=self.config)

    #-------------------------------------------------------------------------
    # API for managing transformers
    #-------------------------------------------------------------------------

    def init_transformers(self):
        """Reset the transformer list and create the default transformers."""
        self._transformers = []
        self._instantiate_defaults(_default_transformers)

    def sort_transformers(self):
        """Re-order the registered transformers by ascending priority.

        Call this whenever the priority of a transformer has been changed.
        :meth:`register_transformer` does so automatically.
        """
        self._transformers.sort(key=lambda transformer: transformer.priority)

    @property
    def transformers(self):
        """Return the list of registered transformers."""
        return self._transformers

    def register_transformer(self, transformer):
        """Add a transformer instance, unless it is already registered."""
        if transformer in self._transformers:
            return
        self._transformers.append(transformer)
        self.sort_transformers()

    def unregister_transformer(self, transformer):
        """Remove a transformer instance; unknown instances are ignored."""
        try:
            self._transformers.remove(transformer)
        except ValueError:
            # it was not registered
            pass

    #-------------------------------------------------------------------------
    # API for managing checkers
    #-------------------------------------------------------------------------

    def init_checkers(self):
        """Reset the checker list and create the default checkers."""
        self._checkers = []
        self._instantiate_defaults(_default_checkers)

    def sort_checkers(self):
        """Re-order the registered checkers by ascending priority.

        Call this whenever the priority of a checker has been changed.
        :meth:`register_checker` does so automatically.
        """
        self._checkers.sort(key=lambda checker: checker.priority)

    @property
    def checkers(self):
        """Return the list of registered checkers."""
        return self._checkers

    def register_checker(self, checker):
        """Add a checker instance, unless it is already registered."""
        if checker in self._checkers:
            return
        self._checkers.append(checker)
        self.sort_checkers()

    def unregister_checker(self, checker):
        """Remove a checker instance; unknown instances are ignored."""
        try:
            self._checkers.remove(checker)
        except ValueError:
            # it was not registered
            pass

    #-------------------------------------------------------------------------
    # API for managing handlers
    #-------------------------------------------------------------------------

    def init_handlers(self):
        """Reset both handler tables and create the default handlers."""
        self._handlers = {}
        self._esc_handlers = {}
        self._instantiate_defaults(_default_handlers)

    @property
    def handlers(self):
        """Return the dict of all handlers, keyed by name."""
        return self._handlers

    def register_handler(self, name, handler, esc_strings):
        """Register `handler` under `name` and under each escape string."""
        self._handlers[name] = handler
        self._esc_handlers.update(
            (esc_str, handler) for esc_str in esc_strings)

    def unregister_handler(self, name, handler, esc_strings):
        """Forget the handler stored under `name` and its escape strings.

        An escape string is only removed if it currently maps to `handler`
        itself.  A missing `name` is ignored.
        """
        self._handlers.pop(name, None)
        for esc_str in esc_strings:
            if self._esc_handlers.get(esc_str) is handler:
                del self._esc_handlers[esc_str]

    def get_handler_by_name(self, name):
        """Return the handler registered as `name`, or None."""
        return self._handlers.get(name)

    def get_handler_by_esc(self, esc_str):
        """Return the handler registered for `esc_str`, or None."""
        return self._esc_handlers.get(esc_str)

    #-------------------------------------------------------------------------
    # Main prefiltering API
    #-------------------------------------------------------------------------

    def prefilter_line_info(self, line_info):
        """Run the checker/handler phase on a LineInfo object.

        The handler selected by :meth:`find_handler` handles the line and
        its result is returned.
        """
        return self.find_handler(line_info).handle(line_info)

    def find_handler(self, line_info):
        """Return the handler chosen by the first enabled checker that
        yields one, falling back to the handler named 'normal'."""
        for checker in self.checkers:
            if not checker.enabled:
                continue
            chosen = checker.check(line_info)
            if chosen:
                return chosen
        return self.get_handler_by_name('normal')

    def transform_line(self, line, continue_prompt):
        """Feed `line` through every enabled transformer, in list order,
        and return the final result."""
        for transformer in self.transformers:
            if not transformer.enabled:
                continue
            line = transformer.transform(line, continue_prompt)
        return line

    def prefilter_line(self, line, continue_prompt=False):
        """Prefilter a single input line given as text.

        The line first goes through the transformers and is then handed to
        the checkers/handlers.  All handlers *must* return a value, even if
        it's blank ('').
        """
        # remember the line in case we crash, so the post-mortem handler
        # can record it
        self.shell._last_input_line = line

        if not line:
            # A purely empty line returns immediately: if the user typed
            # some whitespace that started a continuation prompt, a bare
            # empty line gets them out of it, just like at the default
            # python prompt.
            return ''

        # Transformers run on first lines always, and on continuation
        # lines only when multi-line specials are enabled.
        if not continue_prompt or self.multi_line_specials:
            line = self.transform_line(line, continue_prompt)

        # line_info is what checkers and handlers work with
        line_info = LineInfo(line, continue_prompt)

        # the input history needs to track even empty lines
        is_blank = not line.strip()
        normal_handler = self.get_handler_by_name('normal')

        if is_blank and not continue_prompt:
            self.shell.displayhook.prompt_count -= 1

        # Whitespace-only lines always go to the normal handler, and so do
        # continuation lines when special handlers are restricted to
        # single line statements.
        if is_blank or (continue_prompt and not self.multi_line_specials):
            return normal_handler.handle(line_info)

        return self.prefilter_line_info(line_info)

    def prefilter_lines(self, lines, continue_prompt=False):
        """Prefilter a block of input that may span several lines.

        This is the main entry point for prefiltering multiple lines of
        input; :meth:`prefilter_line` is called once per line.

        Several lines arrive at once when, for example, the user recalls a
        multiline history entry and presses enter.
        """
        pieces = lines.rstrip('\n').split('\n')
        if len(pieces) == 1:
            return self.prefilter_line(pieces[0], continue_prompt)

        # Multiline input can 'blend' into one string (e.g. when recalled
        # from the readline history buffer), so we tell downstream which
        # line is the first one and which are continuation lines.
        filtered = []
        for index, piece in enumerate(pieces):
            filtered.append(self.prefilter_line(piece, index > 0))
        return '\n'.join(filtered)
Esempio n. 27
0
 class A(HasTraits):
     # `inst` is an Instance trait for class Foo, declared with
     # allow_none=False.
     # NOTE(review): presumably this makes the trait reject None as a
     # value -- confirm against the Instance trait docs.
     inst = Instance(Foo, allow_none=False)
Esempio n. 28
0
class SQLiteDB(BaseDB):
    """SQLite3 TaskRecord backend.

    Task records are stored as rows of a single table whose columns are
    listed, in order, in ``_keys``.  Queries use mongodb-style dicts that
    are translated to SQL by :meth:`_render_expression`.

    Writes are not committed immediately: a periodic callback registered in
    ``__init__`` commits the connection every two seconds.
    """

    filename = Unicode(
        'tasks.db',
        config=True,
        help=
        """The filename of the sqlite task database. [default: 'tasks.db']""")
    location = Unicode(
        '',
        config=True,
        help="""The directory containing the sqlite task database.  The default
        is to use the cluster_dir location.""")
    table = Unicode(
        "",
        config=True,
        help=
        """The SQLite Table to use for storing tasks for this session. If unspecified,
        a new table will be created with the Hub's IDENT.  Specifying the table will result
        in tasks from previous sessions being available via Clients' db_query and
        get_result methods.""")

    # the connection; only declared as a trait when sqlite3 is available
    if sqlite3 is not None:
        _db = Instance('sqlite3.Connection')
    else:
        _db = None
    # the ordered list of column names
    _keys = List([
        'msg_id',
        'header',
        'content',
        'buffers',
        'submitted',
        'client_uuid',
        'engine_uuid',
        'started',
        'completed',
        'resubmitted',
        'received',
        'result_header',
        'result_content',
        'result_buffers',
        'queue',
        'pyin',
        'pyout',
        'pyerr',
        'stdout',
        'stderr',
    ])
    # sqlite datatypes for checking that db is current format
    _types = Dict({
        'msg_id': 'text',
        'header': 'dict text',
        'content': 'dict text',
        'buffers': 'bufs blob',
        'submitted': 'timestamp',
        'client_uuid': 'text',
        'engine_uuid': 'text',
        'started': 'timestamp',
        'completed': 'timestamp',
        'resubmitted': 'text',
        'received': 'timestamp',
        'result_header': 'dict text',
        'result_content': 'dict text',
        'result_buffers': 'bufs blob',
        'queue': 'text',
        'pyin': 'text',
        'pyout': 'text',
        'pyerr': 'text',
        'stdout': 'text',
        'stderr': 'text',
    })

    def __init__(self, **kwargs):
        """Resolve table name and location, open the db, schedule commits.

        Raises
        ------
        ImportError
            if the ``sqlite3`` module is not available.
        """
        super(SQLiteDB, self).__init__(**kwargs)
        if sqlite3 is None:
            raise ImportError("SQLiteDB requires sqlite3")
        if not self.table:
            # use session, and prefix _, since starting with # is illegal
            self.table = '_' + self.session.replace('-', '_')
        if not self.location:
            # get current profile
            from IPython.core.application import BaseIPythonApplication
            if BaseIPythonApplication.initialized():
                app = BaseIPythonApplication.instance()
                if app.profile_dir is not None:
                    self.location = app.profile_dir.location
                else:
                    self.location = u'.'
            else:
                self.location = u'.'
        self._init_db()

        # register db commit as 2s periodic callback
        # to prevent clogging pipes
        # assumes we are being run in a zmq ioloop app
        loop = ioloop.IOLoop.instance()
        pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
        pc.start()

    def _defaults(self, keys=None):
        """Create an empty record: a dict mapping each key to None.

        `keys` defaults to all column names.
        """
        keys = self._keys if keys is None else keys
        return dict.fromkeys(keys)

    def _check_table(self):
        """Ensure that an incorrect table doesn't exist.

        Returns True if the table named ``self.table`` does not exist yet or
        has exactly the expected columns and declared types.  If a bad (old)
        table does exist, a warning is logged and False is returned.
        """
        cursor = self._db.execute("PRAGMA table_info(%s)" % self.table)
        lines = cursor.fetchall()
        if not lines:
            # table does not exist
            return True
        # fields 1 and 2 of each row are the column name and declared type
        keys = [line[1] for line in lines]
        types = dict((line[1], line[2]) for line in lines)
        if self._keys != keys:
            # key mismatch
            self.log.warn('keys mismatch')
            return False
        for key in self._keys:
            if types[key] != self._types[key]:
                self.log.warn('type mismatch: %s: %s != %s' %
                              (key, types[key], self._types[key]))
                return False
        return True

    def _init_db(self):
        """Connect to the database and make sure a usable table exists.

        If the configured table exists with an unexpected layout, numbered
        variants of the name (``<table>_1``, ``<table>_2``, ...) are tried
        until one is missing or matches; ``self.table`` is updated
        accordingly.  The table is then created if necessary.
        """
        # register adapters
        sqlite3.register_adapter(dict, _adapt_dict)
        sqlite3.register_converter('dict', _convert_dict)
        sqlite3.register_adapter(list, _adapt_bufs)
        sqlite3.register_converter('bufs', _convert_bufs)
        # connect to the db
        dbfile = os.path.join(self.location, self.filename)
        self._db = sqlite3.connect(
            dbfile,
            detect_types=sqlite3.PARSE_DECLTYPES,
            cached_statements=64)
        first_table = previous_table = self.table
        i = 0
        while not self._check_table():
            i += 1
            self.table = first_table + '_%i' % i
            self.log.warn(
                "Table %s exists and doesn't match db format, trying %s" %
                (previous_table, self.table))
            previous_table = self.table

        self._db.execute("""CREATE TABLE IF NOT EXISTS %s
                (msg_id text PRIMARY KEY,
                header dict text,
                content dict text,
                buffers bufs blob,
                submitted timestamp,
                client_uuid text,
                engine_uuid text,
                started timestamp,
                completed timestamp,
                resubmitted text,
                received timestamp,
                result_header dict text,
                result_content dict text,
                result_buffers bufs blob,
                queue text,
                pyin text,
                pyout text,
                pyerr text,
                stdout text,
                stderr text)
                """ % self.table)
        self._db.commit()

    def _dict_to_list(self, d):
        """Turn a mongodb-style record dict into a list in column order.

        Raises KeyError if `d` lacks one of the column names.
        """
        return [d[key] for key in self._keys]

    def _list_to_dict(self, line, keys=None):
        """Inverse of _dict_to_list.

        `keys` names the values in `line`, defaulting to all columns.  Keys
        without a matching value in `line` are left as None.
        """
        keys = self._keys if keys is None else keys
        d = self._defaults(keys)
        for key, value in zip(keys, line):
            d[key] = value

        return d

    def _render_expression(self, check):
        """Turn a mongodb-style search dict into an SQL query.

        Parameters
        ----------
        check : dict
            Maps column names either to a value (equality test, with None
            meaning ``IS NULL``) or to a dict of ``{operator: value}``
            tests.  All tests are combined with ``AND``.

        Returns
        -------
        (expr, args) : the ``WHERE`` expression with ``?`` placeholders and
        the list of values to bind to them.

        Raises
        ------
        KeyError
            for column names that are not in ``_keys`` and for unsupported
            operators.
        ValueError
            when a sequence of values contains None for an operator that
            has a NULL form, or when a sequence is given to an operator
            that compares against a single value.
        """
        expressions = []
        args = []

        skeys = set(check.keys())
        skeys.difference_update(set(self._keys))
        skeys.difference_update(set(['buffers', 'result_buffers']))
        if skeys:
            raise KeyError("Illegal testing key(s): %s" % skeys)

        for name, sub_check in check.iteritems():
            if isinstance(sub_check, dict):
                for test, value in sub_check.iteritems():
                    try:
                        op = operators[test]
                    except KeyError:
                        raise KeyError("Unsupported operator: %r" % test)
                    # Reset for every test, so that the joiner of a
                    # previously handled operator can never leak into
                    # this one.
                    join = None
                    if isinstance(op, tuple):
                        # multi-value operator: (sql operator, joiner)
                        op, join = op

                    if value is None and op in null_operators:
                        expr = "%s %s" % (name, null_operators[op])
                    else:
                        expr = "%s %s ?" % (name, op)
                        if isinstance(value, (tuple, list)):
                            if op in null_operators and any(
                                    v is None for v in value):
                                # equality tests don't work with NULL
                                raise ValueError(
                                    "Cannot use %r test with NULL values on SQLite backend"
                                    % test)
                            if join is None:
                                # single-value operator given a sequence
                                raise ValueError(
                                    "Operator %r does not accept a sequence of values"
                                    % test)
                            # one placeholder expression per value
                            expr = '( %s )' % (join.join([expr] * len(value)))
                            args.extend(value)
                        else:
                            args.append(value)
                    expressions.append(expr)
            else:
                # it's an equality check
                if sub_check is None:
                    expressions.append("%s IS NULL" % name)
                else:
                    expressions.append("%s = ?" % name)
                    args.append(sub_check)

        expr = " AND ".join(expressions)
        return expr, args

    def add_record(self, msg_id, rec):
        """Add a new Task Record, by msg_id.

        Columns missing from `rec` are stored as NULL; ``rec['msg_id']`` is
        overridden by `msg_id`.  The insert is committed by the periodic
        callback, not here.
        """
        d = self._defaults()
        d.update(rec)
        d['msg_id'] = msg_id
        line = self._dict_to_list(d)
        tups = '(%s)' % (','.join(['?'] * len(line)))
        self._db.execute("INSERT INTO %s VALUES %s" % (self.table, tups), line)

    def get_record(self, msg_id):
        """Get a specific Task Record, by msg_id.

        Raises KeyError if there is no record with that msg_id.
        """
        cursor = self._db.execute(
            """SELECT * FROM %s WHERE msg_id==?""" % self.table, (msg_id, ))
        line = cursor.fetchone()
        if line is None:
            raise KeyError("No such msg: %r" % msg_id)
        return self._list_to_dict(line)

    def update_record(self, msg_id, rec):
        """Update the data in an existing record.

        Every key of `rec` is written as a column of the row whose msg_id
        is `msg_id`.
        """
        keys = sorted(rec.keys())
        sets = ['%s = ?' % key for key in keys]
        values = [rec[key] for key in keys]
        values.append(msg_id)
        query = "UPDATE %s SET " % self.table
        query += ', '.join(sets)
        query += ' WHERE msg_id == ?'
        self._db.execute(query, values)

    def drop_record(self, msg_id):
        """Remove a record from the DB."""
        self._db.execute("""DELETE FROM %s WHERE msg_id==?""" % self.table,
                         (msg_id, ))

    def drop_matching_records(self, check):
        """Remove all records matching a mongodb-style query dict."""
        expr, args = self._render_expression(check)
        query = "DELETE FROM %s WHERE %s" % (self.table, expr)
        self._db.execute(query, args)

    def find_records(self, check, keys=None):
        """Find records matching a query dict, optionally extracting subset of keys.

        Returns list of matching records.

        Parameters
        ----------

        check: dict
            mongodb-style query argument
        keys: list of strs [optional]
            if specified, the subset of keys to extract.  msg_id will *always* be
            included.  The list passed in is not modified.

        Raises
        ------
        KeyError
            if `keys` contains names that are not record keys.
        """
        if keys:
            bad_keys = [key for key in keys if key not in self._keys]
            if bad_keys:
                raise KeyError("Bad record key(s): %s" % bad_keys)
            # ensure msg_id is present and first; build a new list so the
            # caller's argument is left untouched
            keys = ['msg_id'] + [key for key in keys if key != 'msg_id']
            req = ', '.join(keys)
        else:
            # no subset requested (None or empty): select every column and
            # name the values after all columns
            keys = None
            req = '*'
        expr, args = self._render_expression(check)
        query = """SELECT %s FROM %s WHERE %s""" % (req, self.table, expr)
        cursor = self._db.execute(query, args)
        return [self._list_to_dict(line, keys) for line in cursor.fetchall()]

    def get_history(self):
        """get all msg_ids, ordered by time submitted."""
        query = """SELECT msg_id FROM %s ORDER by submitted ASC""" % self.table
        cursor = self._db.execute(query)
        # will be a list of length 1 tuples
        return [tup[0] for tup in cursor.fetchall()]
Esempio n. 29
0
class NoneInstanceListTrait(HasTraits):

    # `value` is a List trait whose element trait is an Instance of Foo
    # declared with allow_none=False.
    # NOTE(review): presumably None is therefore not accepted as a list
    # element -- confirm against the List/Instance trait docs.
    value = List(Instance(Foo, allow_none=False))
Esempio n. 30
0
class InstanceListTrait(HasTraits):

    # `value` is a List trait whose element trait is an Instance given by
    # name: the string is this module's __name__ followed by '.Foo'.
    # NOTE(review): presumably Instance resolves the dotted name to the
    # class later on -- confirm against the Instance trait docs.
    value = List(Instance(__name__ + '.Foo'))