Example #1
class Bar(Configurable):

    b = Integer(0, config=True, help="The integer b.")
    enabled = Bool(True, config=True, help="Enable bar.")
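
A minimal usage sketch (an assumption, not part of the example): traits declared with config=True can be driven from a Config object, using the IPython.config.loader API of the same era.

from IPython.config.loader import Config

c = Config()
c.Bar.b = 10           # targets Bar.b because it is declared config=True
c.Bar.enabled = False

bar = Bar(config=c)
print(bar.b, bar.enabled)  # -> 10 False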
Example #2
class HistoryManager(HistoryAccessor):
    """A class to organize all history-related functionality in one place.
    """
    # Public interface

    # An instance of the IPython shell we are attached to
    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
    # Lists to hold processed and raw history. These start with a blank entry
    # so that we can index them starting from 1
    input_hist_parsed = List([""])
    input_hist_raw = List([""])
    # A list of directories visited during the session
    dir_hist = List()

    def _dir_hist_default(self):
        try:
            return [os.getcwd()]
        except OSError:
            return []

    # A dict of output history, keyed with ints from the shell's
    # execution count.
    output_hist = Dict()
    # The text/plain repr of outputs.
    output_hist_reprs = Dict()

    # The number of the current session in the history database
    session_number = Integer()
    # Should we log output to the database? (default no)
    db_log_output = Bool(False, config=True)
    # Write to database every x commands (higher values save disk access & power)
    #  Values of 1 or less effectively disable caching.
    db_cache_size = Integer(0, config=True)
    # The input and output caches
    db_input_cache = List()
    db_output_cache = List()

    # History saving in separate thread
    save_thread = Instance('IPython.core.history.HistorySavingThread')
    try:  # Event is a function returning an instance of _Event...
        save_flag = Instance(threading._Event)
    except AttributeError:  # ...until Python 3.3, when it's a class.
        save_flag = Instance(threading.Event)

    # Private interface
    # Variables used to store the last three inputs from the user.  On each new
    # history update, we populate the user's namespace with these, shifted as
    # necessary.
    _i00 = Unicode('')
    _i = Unicode('')
    _ii = Unicode('')
    _iii = Unicode('')

    # A regex matching all forms of the exit command, so that we don't store
    # them in the history (it's annoying to rewind the first entry and land on
    # an exit call).
    _exit_re = re.compile(r"(exit|quit)(\s*\(.*\))?$")

    def __init__(self, shell=None, config=None, **traits):
        """Create a new history manager associated with a shell instance.
        """
        # We need a pointer back to the shell for various tasks.
        super(HistoryManager, self).__init__(shell=shell,
                                             config=config,
                                             **traits)
        self.save_flag = threading.Event()
        self.db_input_cache_lock = threading.Lock()
        self.db_output_cache_lock = threading.Lock()
        if self.hist_file != ':memory:':
            self.save_thread = HistorySavingThread(self)
            self.save_thread.start()

        self.new_session()

    def _get_hist_file_name(self, profile=None):
        """Get default history file name based on the Shell's profile.
        
        The profile parameter is ignored, but must exist for compatibility with
        the parent class."""
        profile_dir = self.shell.profile_dir.location
        return os.path.join(profile_dir, 'history.sqlite')

    @needs_sqlite
    def new_session(self, conn=None):
        """Get a new session number."""
        if conn is None:
            conn = self.db

        with conn:
            cur = conn.execute(
                """INSERT INTO sessions VALUES (NULL, ?, NULL,
                            NULL, "") """, (datetime.datetime.now(), ))
            self.session_number = cur.lastrowid

    def end_session(self):
        """Close the database session, filling in the end time and line count."""
        self.writeout_cache()
        with self.db:
            self.db.execute(
                """UPDATE sessions SET end=?, num_cmds=? WHERE
                            session==?""",
                (datetime.datetime.now(), len(self.input_hist_parsed) - 1,
                 self.session_number))
        self.session_number = 0

    def name_session(self, name):
        """Give the current session a name in the history database."""
        with self.db:
            self.db.execute("UPDATE sessions SET remark=? WHERE session==?",
                            (name, self.session_number))

    def reset(self, new_session=True):
        """Clear the session history, releasing all object references, and
        optionally open a new session."""
        self.output_hist.clear()
        # The directory history can't be completely empty
        self.dir_hist[:] = [os.getcwd()]

        if new_session:
            if self.session_number:
                self.end_session()
            self.input_hist_parsed[:] = [""]
            self.input_hist_raw[:] = [""]
            self.new_session()

    # ------------------------------
    # Methods for retrieving history
    # ------------------------------
    def _get_range_session(self, start=1, stop=None, raw=True, output=False):
        """Get input and output history from the current session. Called by
        get_range, and takes similar parameters."""
        input_hist = self.input_hist_raw if raw else self.input_hist_parsed

        n = len(input_hist)
        if start < 0:
            start += n
        if not stop or (stop > n):
            stop = n
        elif stop < 0:
            stop += n

        for i in range(start, stop):
            if output:
                line = (input_hist[i], self.output_hist_reprs.get(i))
            else:
                line = input_hist[i]
            yield (0, i, line)

    def get_range(self, session=0, start=1, stop=None, raw=True, output=False):
        """Retrieve input by session.
        
        Parameters
        ----------
        session : int
            Session number to retrieve. The current session is 0, and negative
            numbers count back from the current session, so -1 is the previous
            session.
        start : int
            First line to retrieve.
        stop : int
            End of line range (excluded from output itself). If None, retrieve
            to the end of the session.
        raw : bool
            If True, return untranslated input.
        output : bool
            If True, attempt to include output. This will be 'real' Python
            objects for the current session, or text reprs from previous
            sessions if db_log_output was enabled at the time. Where no output
            is found, None is used.
            
        Returns
        -------
        An iterator over the desired lines. Each line is a 3-tuple, either
        (session, line, input) if output is False, or
        (session, line, (input, output)) if output is True.
        """
        if session <= 0:
            session += self.session_number
        if session == self.session_number:  # Current session
            return self._get_range_session(start, stop, raw, output)
        return super(HistoryManager, self).get_range(session, start, stop, raw,
                                                     output)

    # ---------------------------
    # Methods for storing history
    # ---------------------------
    def store_inputs(self, line_num, source, source_raw=None):
        """Store source and raw input in history and create input cache
        variables _i*.

        Parameters
        ----------
        line_num : int
          The prompt number of this input.

        source : str
          Python input.

        source_raw : str, optional
          If given, this is the raw input without any IPython transformations
          applied to it.  If not given, ``source`` is used.
        """
        if source_raw is None:
            source_raw = source
        source = source.rstrip('\n')
        source_raw = source_raw.rstrip('\n')

        # do not store exit/quit commands
        if self._exit_re.match(source_raw.strip()):
            return

        self.input_hist_parsed.append(source)
        self.input_hist_raw.append(source_raw)

        with self.db_input_cache_lock:
            self.db_input_cache.append((line_num, source, source_raw))
            # Trigger to flush cache and write to DB.
            if len(self.db_input_cache) >= self.db_cache_size:
                self.save_flag.set()

        # update the auto _i variables
        self._iii = self._ii
        self._ii = self._i
        self._i = self._i00
        self._i00 = source_raw

        # hackish access to user namespace to create _i1,_i2... dynamically
        new_i = '_i%s' % line_num
        to_main = {
            '_i': self._i,
            '_ii': self._ii,
            '_iii': self._iii,
            new_i: self._i00
        }

        self.shell.push(to_main, interactive=False)

    def store_output(self, line_num):
        """If database output logging is enabled, this saves all the
        outputs from the indicated prompt number to the database. It's
        called by run_cell after code has been executed.

        Parameters
        ----------
        line_num : int
          The line number from which to save outputs
        """
        if (not self.db_log_output) or (line_num not in self.output_hist_reprs):
            return
        output = self.output_hist_reprs[line_num]

        with self.db_output_cache_lock:
            self.db_output_cache.append((line_num, output))
        if self.db_cache_size <= 1:
            self.save_flag.set()

    def _writeout_input_cache(self, conn):
        with conn:
            for line in self.db_input_cache:
                conn.execute("INSERT INTO history VALUES (?, ?, ?, ?)",
                             (self.session_number, ) + line)

    def _writeout_output_cache(self, conn):
        with conn:
            for line in self.db_output_cache:
                conn.execute("INSERT INTO output_history VALUES (?, ?, ?)",
                             (self.session_number, ) + line)

    @needs_sqlite
    def writeout_cache(self, conn=None):
        """Write any entries in the cache to the database."""
        if conn is None:
            conn = self.db

        with self.db_input_cache_lock:
            try:
                self._writeout_input_cache(conn)
            except sqlite3.IntegrityError:
                self.new_session(conn)
                print("ERROR! Session/line number was not unique in",
                      "database. History logging moved to new session",
                      self.session_number)
                try:
                    # Try writing to the new session. If this fails, don't
                    # recurse
                    self._writeout_input_cache(conn)
                except sqlite3.IntegrityError:
                    pass
            finally:
                self.db_input_cache = []

        with self.db_output_cache_lock:
            try:
                self._writeout_output_cache(conn)
            except sqlite3.IntegrityError:
                print("!! Session/line number for output was not unique",
                      "in database. Output will not be stored.")
            finally:
                self.db_output_cache = []
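
A hedged retrieval sketch built on the get_range API above; hm is an assumed name for the shell's HistoryManager instance.

# With output=True, each item is (session, line, (input, output)).
for session, line, (inp, out) in hm.get_range(session=0, start=1, stop=None,
                                              raw=True, output=True):
    print(session, line, inp, out)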
Example #3
class MyConfigurable(Configurable):
    a = Integer(1, config=True, help="The integer a.")
    b = Float(1.0, config=True, help="The number b.")
    c = Unicode('no config')
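
What this declaration buys, as a hedged sketch: a and b can be set from configuration, while c cannot because it lacks config=True (the Config import is the same assumed API as above).

from IPython.config.loader import Config

cfg = Config()
cfg.MyConfigurable.a = 2
cfg.MyConfigurable.b = 4.0

obj = MyConfigurable(config=cfg)
print(obj.a, obj.b, obj.c)  # -> 2 4.0 no config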
Example #4
class IPKernelApp(BaseIPythonApplication, InteractiveShellApp):
    name = 'ipkernel'
    aliases = Dict(kernel_aliases)
    flags = Dict(kernel_flags)
    classes = [Kernel, ZMQInteractiveShell, ProfileDir, Session]
    # the kernel class, as an importstring
    kernel_class = DottedObjectName('IPython.kernel.zmq.ipkernel.Kernel',
        config=True,
        help="""The Kernel subclass to be used.

        This should allow easy re-use of the IPKernelApp entry point
        to configure and launch kernels other than IPython's own.
        """)
    kernel = Any()
    poller = Any() # don't restrict this even though current pollers are all Threads
    heartbeat = Instance(Heartbeat)
    session = Instance('IPython.kernel.zmq.session.Session')
    ports = Dict()
    
    # ipkernel doesn't get its own config file
    def _config_file_name_default(self):
        return 'ipython_config.py'
    
    # inherit config file name from parent:
    parent_appname = Unicode(config=True)
    def _parent_appname_changed(self, name, old, new):
        if self.config_file_specified:
            # it was manually specified, ignore
            return
        self.config_file_name = new.replace('-','_') + u'_config.py'
        # don't let this count as specifying the config file
        self.config_file_specified.remove(self.config_file_name)
        
    # connection info:
    transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp', config=True)
    ip = Unicode(config=True,
        help="Set the IP or interface on which the kernel will listen.")
    def _ip_default(self):
        if self.transport == 'ipc':
            if self.connection_file:
                return os.path.splitext(self.abs_connection_file)[0] + '-ipc'
            else:
                return 'kernel-ipc'
        else:
            return localhost()
    
    hb_port = Integer(0, config=True, help="set the heartbeat port [default: random]")
    shell_port = Integer(0, config=True, help="set the shell (ROUTER) port [default: random]")
    iopub_port = Integer(0, config=True, help="set the iopub (PUB) port [default: random]")
    stdin_port = Integer(0, config=True, help="set the stdin (ROUTER) port [default: random]")
    control_port = Integer(0, config=True, help="set the control (ROUTER) port [default: random]")
    connection_file = Unicode('', config=True,
        help="""JSON file in which to store connection info [default: kernel-<pid>.json]

        This file will contain the IP, ports, and authentication key needed to connect
        clients to this kernel. By default, this file will be created in the security dir
        of the current profile, but can be specified by absolute path.
        """)
    @property
    def abs_connection_file(self):
        if os.path.basename(self.connection_file) == self.connection_file:
            return os.path.join(self.profile_dir.security_dir, self.connection_file)
        else:
            return self.connection_file
        

    # streams, etc.
    no_stdout = Bool(False, config=True, help="redirect stdout to the null device")
    no_stderr = Bool(False, config=True, help="redirect stderr to the null device")
    outstream_class = DottedObjectName('IPython.kernel.zmq.iostream.OutStream',
        config=True, help="The importstring for the OutStream factory")
    displayhook_class = DottedObjectName('IPython.kernel.zmq.displayhook.ZMQDisplayHook',
        config=True, help="The importstring for the DisplayHook factory")

    # polling
    parent_handle = Integer(0, config=True,
        help="""kill this process if its parent dies.  On Windows, the argument
        specifies the HANDLE of the parent process, otherwise it is simply boolean.
        """)
    interrupt = Integer(0, config=True,
        help="""ONLY USED ON WINDOWS
        Interrupt this process when the parent is signaled.
        """)

    def init_crash_handler(self):
        # Install minimal exception handling
        sys.excepthook = FormattedTB(mode='Verbose', color_scheme='NoColor',
                                     ostream=sys.__stdout__)

    def init_poller(self):
        if sys.platform == 'win32':
            if self.interrupt or self.parent_handle:
                self.poller = ParentPollerWindows(self.interrupt, self.parent_handle)
        elif self.parent_handle:
            self.poller = ParentPollerUnix()

    def _bind_socket(self, s, port):
        iface = '%s://%s' % (self.transport, self.ip)
        if self.transport == 'tcp':
            if port <= 0:
                port = s.bind_to_random_port(iface)
            else:
                s.bind("tcp://%s:%i" % (self.ip, port))
        elif self.transport == 'ipc':
            if port <= 0:
                port = 1
                path = "%s-%i" % (self.ip, port)
                while os.path.exists(path):
                    port = port + 1
                    path = "%s-%i" % (self.ip, port)
            else:
                path = "%s-%i" % (self.ip, port)
            s.bind("ipc://%s" % path)
        return port

    def load_connection_file(self):
        """load ip/port/hmac config from JSON connection file"""
        try:
            fname = filefind(self.connection_file, ['.', self.profile_dir.security_dir])
        except IOError:
            self.log.debug("Connection file not found: %s", self.connection_file)
            # This means I own it, so I will clean it up:
            atexit.register(self.cleanup_connection_file)
            return
        self.log.debug(u"Loading connection file %s", fname)
        with open(fname) as f:
            s = f.read()
        cfg = json.loads(s)
        self.transport = cfg.get('transport', self.transport)
        if self.ip == self._ip_default() and 'ip' in cfg:
            # not overridden by config or cl_args
            self.ip = cfg['ip']
        for channel in ('hb', 'shell', 'iopub', 'stdin', 'control'):
            name = channel + '_port'
            if getattr(self, name) == 0 and name in cfg:
                # not overridden by config or cl_args
                setattr(self, name, cfg[name])
        if 'key' in cfg:
            self.config.Session.key = str_to_bytes(cfg['key'])
    
    def write_connection_file(self):
        """write connection info to JSON file"""
        cf = self.abs_connection_file
        self.log.debug("Writing connection file: %s", cf)
        write_connection_file(cf, ip=self.ip, key=self.session.key,
                              transport=self.transport,
                              shell_port=self.shell_port,
                              stdin_port=self.stdin_port,
                              hb_port=self.hb_port,
                              iopub_port=self.iopub_port,
                              control_port=self.control_port)
    
    def cleanup_connection_file(self):
        cf = self.abs_connection_file
        self.log.debug("Cleaning up connection file: %s", cf)
        try:
            os.remove(cf)
        except (IOError, OSError):
            pass
        
        self.cleanup_ipc_files()
    
    def cleanup_ipc_files(self):
        """cleanup ipc files if we wrote them"""
        if self.transport != 'ipc':
            return
        for port in (self.shell_port, self.iopub_port, self.stdin_port, self.hb_port, self.control_port):
            ipcfile = "%s-%i" % (self.ip, port)
            try:
                os.remove(ipcfile)
            except (IOError, OSError):
                pass
    
    def init_connection_file(self):
        if not self.connection_file:
            self.connection_file = "kernel-%s.json"%os.getpid()
        try:
            self.load_connection_file()
        except Exception:
            self.log.error("Failed to load connection file: %r", self.connection_file, exc_info=True)
            self.exit(1)
    
    def init_sockets(self):
        # Create a context, a session, and the kernel sockets.
        self.log.info("Starting the kernel at pid: %i", os.getpid())
        context = zmq.Context.instance()
        # Uncomment this to try closing the context.
        # atexit.register(context.term)

        self.shell_socket = context.socket(zmq.ROUTER)
        self.shell_port = self._bind_socket(self.shell_socket, self.shell_port)
        self.log.debug("shell ROUTER Channel on port: %i" % self.shell_port)

        self.iopub_socket = context.socket(zmq.PUB)
        self.iopub_port = self._bind_socket(self.iopub_socket, self.iopub_port)
        self.log.debug("iopub PUB Channel on port: %i" % self.iopub_port)

        self.stdin_socket = context.socket(zmq.ROUTER)
        self.stdin_port = self._bind_socket(self.stdin_socket, self.stdin_port)
        self.log.debug("stdin ROUTER Channel on port: %i" % self.stdin_port)

        self.control_socket = context.socket(zmq.ROUTER)
        self.control_port = self._bind_socket(self.control_socket, self.control_port)
        self.log.debug("control ROUTER Channel on port: %i" % self.control_port)
    
    def init_heartbeat(self):
        """start the heart beating"""
        # heartbeat doesn't share context, because it mustn't be blocked
        # by the GIL, which is accessed by libzmq when freeing zero-copy messages
        hb_ctx = zmq.Context()
        self.heartbeat = Heartbeat(hb_ctx, (self.transport, self.ip, self.hb_port))
        self.hb_port = self.heartbeat.port
        self.log.debug("Heartbeat REP Channel on port: %i" % self.hb_port)
        self.heartbeat.start()
    
    def log_connection_info(self):
        """display connection info, and store ports"""
        basename = os.path.basename(self.connection_file)
        if basename == self.connection_file or \
            os.path.dirname(self.connection_file) == self.profile_dir.security_dir:
            # use shortname
            tail = basename
            if self.profile != 'default':
                tail += " --profile %s" % self.profile
        else:
            tail = self.connection_file
        lines = [
            "To connect another client to this kernel, use:",
            "    --existing %s" % tail,
        ]
        # log connection info
        # info-level, so often not shown.
        # frontends should use the %connect_info magic
        # to see the connection info
        for line in lines:
            self.log.info(line)
        # also raw print to the terminal if no parent_handle (`ipython kernel`)
        if not self.parent_handle:
            io.rprint(_ctrl_c_message)
            for line in lines:
                io.rprint(line)

        self.ports = dict(shell=self.shell_port, iopub=self.iopub_port,
                                stdin=self.stdin_port, hb=self.hb_port,
                                control=self.control_port)

    def init_session(self):
        """create our session object"""
        default_secure(self.config)
        self.session = Session(parent=self, username=u'kernel')

    def init_blackhole(self):
        """redirects stdout/stderr to devnull if necessary"""
        if self.no_stdout or self.no_stderr:
            blackhole = open(os.devnull, 'w')
            if self.no_stdout:
                sys.stdout = sys.__stdout__ = blackhole
            if self.no_stderr:
                sys.stderr = sys.__stderr__ = blackhole
    
    def init_io(self):
        """Redirect input streams and set a display hook."""
        if self.outstream_class:
            outstream_factory = import_item(str(self.outstream_class))
            sys.stdout = outstream_factory(self.session, self.iopub_socket, u'stdout')
            sys.stderr = outstream_factory(self.session, self.iopub_socket, u'stderr')
        if self.displayhook_class:
            displayhook_factory = import_item(str(self.displayhook_class))
            sys.displayhook = displayhook_factory(self.session, self.iopub_socket)

    def init_signal(self):
        signal.signal(signal.SIGINT, signal.SIG_IGN)

    def init_kernel(self):
        """Create the Kernel object itself"""
        shell_stream = ZMQStream(self.shell_socket)
        control_stream = ZMQStream(self.control_socket)
        
        kernel_factory = import_item(str(self.kernel_class))

        kernel = kernel_factory(parent=self, session=self.session,
                                shell_streams=[shell_stream, control_stream],
                                iopub_socket=self.iopub_socket,
                                stdin_socket=self.stdin_socket,
                                log=self.log,
                                profile_dir=self.profile_dir,
                                user_ns=self.user_ns,
        )
        kernel.record_ports(self.ports)
        self.kernel = kernel

    def init_gui_pylab(self):
        """Enable GUI event loop integration, taking pylab into account."""

        # Provide a wrapper for :meth:`InteractiveShellApp.init_gui_pylab`
        # to ensure that any exception is printed straight to stderr.
        # Normally _showtraceback associates the reply with an execution,
        # which means frontends will never draw it, as this exception
        # is not associated with any execute request.

        shell = self.shell
        _showtraceback = shell._showtraceback
        try:
            # replace pyerr-sending traceback with stderr
            def print_tb(etype, evalue, stb):
                print("GUI event loop or pylab initialization failed",
                      file=io.stderr)
                print(shell.InteractiveTB.stb2text(stb), file=io.stderr)
            shell._showtraceback = print_tb
            InteractiveShellApp.init_gui_pylab(self)
        finally:
            shell._showtraceback = _showtraceback

    def init_shell(self):
        self.shell = self.kernel.shell
        self.shell.configurables.append(self)

    @catch_config_error
    def initialize(self, argv=None):
        super(IPKernelApp, self).initialize(argv)
        self.init_blackhole()
        self.init_connection_file()
        self.init_session()
        self.init_poller()
        self.init_sockets()
        self.init_heartbeat()
        # writing/displaying connection info must be *after* init_sockets/heartbeat
        self.log_connection_info()
        self.write_connection_file()
        self.init_io()
        self.init_signal()
        self.init_kernel()
        # shell init steps
        self.init_path()
        self.init_shell()
        self.init_gui_pylab()
        self.init_extensions()
        self.init_code()
        # flush stdout/stderr, so that anything written to these streams during
        # initialization does not get associated with the first execution request
        sys.stdout.flush()
        sys.stderr.flush()

    def start(self):
        if self.poller is not None:
            self.poller.start()
        self.kernel.start()
        try:
            ioloop.IOLoop.instance().start()
        except KeyboardInterrupt:
            pass
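
A hedged launch sketch: like other IPython Application subclasses, IPKernelApp can be started with launch_instance, which creates the singleton, calls initialize(), and then start().

if __name__ == '__main__':
    # Roughly: app = IPKernelApp.instance(); app.initialize(); app.start()
    IPKernelApp.launch_instance()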
Example #5
File: sign.py Project: vbraun/ipython
class NotebookNotary(LoggingConfigurable):
    """A class for computing and verifying notebook signatures."""

    profile_dir = Instance("IPython.core.profiledir.ProfileDir",
                           allow_none=True)

    def _profile_dir_default(self):
        from IPython.core.application import BaseIPythonApplication
        app = None
        try:
            if BaseIPythonApplication.initialized():
                app = BaseIPythonApplication.instance()
        except MultipleInstanceError:
            pass
        if app is None:
            # create an app, without the global instance
            app = BaseIPythonApplication()
            app.initialize(argv=[])
        return app.profile_dir

    db_file = Unicode(
        config=True,
        help="""The sqlite file in which to store notebook signatures.
        By default, this will be in your IPython profile.
        You can set it to ':memory:' to disable sqlite writing to the filesystem.
        """)

    def _db_file_default(self):
        if self.profile_dir is None:
            return ':memory:'
        return os.path.join(self.profile_dir.security_dir, u'nbsignatures.db')

    # 64k entries ~ 12MB
    cache_size = Integer(65535,
                         config=True,
                         help="""The number of notebook signatures to cache.
        When the number of signatures exceeds this value,
        the oldest 25% of signatures will be culled.
        """)
    db = Any()

    def _db_default(self):
        if sqlite3 is None:
            self.log.warn("Missing SQLite3, all notebooks will be untrusted!")
            return
        kwargs = dict(detect_types=sqlite3.PARSE_DECLTYPES
                      | sqlite3.PARSE_COLNAMES)
        db = sqlite3.connect(self.db_file, **kwargs)
        self.init_db(db)
        return db

    def init_db(self, db):
        db.execute("""
        CREATE TABLE IF NOT EXISTS nbsignatures
        (
            id integer PRIMARY KEY AUTOINCREMENT,
            algorithm text,
            signature text,
            path text,
            last_seen timestamp
        )""")
        db.execute("""
        CREATE INDEX IF NOT EXISTS algosig ON nbsignatures(algorithm, signature)
        """)
        db.commit()

    algorithm = Enum(algorithms,
                     default_value='sha256',
                     config=True,
                     help="""The hashing algorithm used to sign notebooks.""")

    def _algorithm_changed(self, name, old, new):
        self.digestmod = getattr(hashlib, self.algorithm)

    digestmod = Any()

    def _digestmod_default(self):
        return getattr(hashlib, self.algorithm)

    secret_file = Unicode(config=True,
                          help="""The file where the secret key is stored.""")

    def _secret_file_default(self):
        if self.profile_dir is None:
            return ''
        return os.path.join(self.profile_dir.security_dir, 'notebook_secret')

    secret = Bytes(config=True,
                   help="""The secret key with which notebooks are signed.""")

    def _secret_default(self):
        # note : this assumes an Application is running
        if os.path.exists(self.secret_file):
            with io.open(self.secret_file, 'rb') as f:
                return f.read()
        else:
            secret = base64.encodestring(os.urandom(1024))
            self._write_secret_file(secret)
            return secret

    def _write_secret_file(self, secret):
        """write my secret to my secret_file"""
        self.log.info("Writing notebook-signing key to %s", self.secret_file)
        with io.open(self.secret_file, 'wb') as f:
            f.write(secret)
        try:
            os.chmod(self.secret_file, 0o600)
        except OSError:
            self.log.warn("Could not set permissions on %s", self.secret_file)
        return secret

    def compute_signature(self, nb):
        """Compute a notebook's signature
        
        by hashing the entire contents of the notebook via HMAC digest.
        """
        hmac = HMAC(self.secret, digestmod=self.digestmod)
        # don't include the previous hash in the content to hash
        with signature_removed(nb):
            # sign the whole thing
            for b in yield_everything(nb):
                hmac.update(b)

        return hmac.hexdigest()

    def check_signature(self, nb):
        """Check a notebook's stored signature
        
        If a signature is stored in the notebook's metadata,
        a new signature is computed and compared with the stored value.
        
        Returns True if the signature is found and matches, False otherwise.
        
        The following conditions must all be met for a notebook to be trusted:
        - a signature is stored in the form 'scheme:hexdigest'
        - the stored scheme matches the requested scheme
        - the requested scheme is available from hashlib
        - the computed hash from notebook_signature matches the stored hash
        """
        if nb.nbformat < 3:
            return False
        if self.db is None:
            return False
        signature = self.compute_signature(nb)
        r = self.db.execute(
            """SELECT id FROM nbsignatures WHERE
            algorithm = ? AND
            signature = ?;
            """, (self.algorithm, signature)).fetchone()
        if r is None:
            return False
        self.db.execute(
            """UPDATE nbsignatures SET last_seen = ? WHERE
            algorithm = ? AND
            signature = ?;
            """,
            (datetime.utcnow(), self.algorithm, signature),
        )
        self.db.commit()
        return True

    def sign(self, nb):
        """Sign a notebook, indicating that its output is trusted on this machine
        
        Stores hash algorithm and hmac digest in a local database of trusted notebooks.
        """
        if nb.nbformat < 3:
            return
        signature = self.compute_signature(nb)
        self.store_signature(signature, nb)

    def store_signature(self, signature, nb):
        if self.db is None:
            return
        self.db.execute(
            """INSERT OR IGNORE INTO nbsignatures
            (algorithm, signature, last_seen) VALUES (?, ?, ?)""",
            (self.algorithm, signature, datetime.utcnow()))
        self.db.execute(
            """UPDATE nbsignatures SET last_seen = ? WHERE
            algorithm = ? AND
            signature = ?;
            """,
            (datetime.utcnow(), self.algorithm, signature),
        )
        self.db.commit()
        n, = self.db.execute("SELECT Count(*) FROM nbsignatures").fetchone()
        if n > self.cache_size:
            self.cull_db()

    def unsign(self, nb):
        """Ensure that a notebook is untrusted
        
        by removing its signature from the trusted database, if present.
        """
        signature = self.compute_signature(nb)
        self.db.execute(
            """DELETE FROM nbsignatures WHERE
                algorithm = ? AND
                signature = ?;
            """, (self.algorithm, signature))
        self.db.commit()

    def cull_db(self):
        """Cull oldest 25% of the trusted signatures when the size limit is reached"""
        self.db.execute(
            """DELETE FROM nbsignatures WHERE id IN (
            SELECT id FROM nbsignatures ORDER BY last_seen DESC LIMIT -1 OFFSET ?
        );
        """, (max(int(0.75 * self.cache_size), 1), ))

    def mark_cells(self, nb, trusted):
        """Mark cells as trusted if the notebook's signature can be verified
        
        Sets ``cell.metadata.trusted = True | False`` on all code cells,
        depending on whether the stored signature can be verified.
        
        This function is the inverse of check_cells
        """
        if nb.nbformat < 3:
            return

        for cell in yield_code_cells(nb):
            cell['metadata']['trusted'] = trusted

    def _check_cell(self, cell, nbformat_version):
        """Do we trust an individual cell?
        
        Return True if:
        
        - cell is explicitly trusted
        - cell has no potentially unsafe rich output
        
        If a cell has no output, or only simple print statements,
        it will always be trusted.
        """
        # explicitly trusted
        if cell['metadata'].pop("trusted", False):
            return True

        # explicitly safe output
        if nbformat_version >= 4:
            unsafe_output_types = ['execute_result', 'display_data']
            safe_keys = {"output_type", "execution_count", "metadata"}
        else:  # v3
            unsafe_output_types = ['pyout', 'display_data']
            safe_keys = {"output_type", "prompt_number", "metadata"}

        for output in cell['outputs']:
            output_type = output['output_type']
            if output_type in unsafe_output_types:
                # if there are any data keys not in the safe whitelist
                output_keys = set(output)
                if output_keys.difference(safe_keys):
                    return False

        return True

    def check_cells(self, nb):
        """Return whether all code cells are trusted
        
        If there are no code cells, return True.
        
        This function is the inverse of mark_cells.
        """
        if nb.nbformat < 3:
            return False
        trusted = True
        for cell in yield_code_cells(nb):
            # only distrust a cell if it actually has some output to distrust
            if not self._check_cell(cell, nb.nbformat):
                trusted = False

        return trusted
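
A hedged end-to-end sketch using only the methods above; nb is an assumed, already-parsed notebook node (e.g. from IPython.nbformat.read).

notary = NotebookNotary()
notary.sign(nb)                       # store the signature as trusted
trusted = notary.check_signature(nb)  # True once the signature is stored
notary.mark_cells(nb, trusted)        # sets cell.metadata.trusted on code cells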
Example #6
File: execute.py Project: hack1nt0/CM
class ExecutePreprocessor(Preprocessor):
    """
    Executes all the cells in a notebook
    """
    
    timeout = Integer(30, config=True,
        help="The time to wait (in seconds) for output from executions."
    )

    interrupt_on_timeout = Bool(
        False, config=True,
        help=dedent(
            """
            If execution of a cell times out, interrupt the kernel and 
            continue executing other cells rather than throwing an error and 
            stopping.
            """
        )
    )
    
    extra_arguments = List(Unicode)

    def preprocess(self, nb, resources):
        path = resources.get('metadata', {}).get('path', '')
        if path == '':
            path = None

        from IPython.kernel.manager import start_new_kernel
        kernel_name = nb.metadata.get('kernelspec', {}).get('name', 'python')
        self.log.info("Executing notebook with kernel: %s" % kernel_name)
        self.km, self.kc = start_new_kernel(
            kernel_name=kernel_name,
            extra_arguments=self.extra_arguments,
            stderr=open(os.devnull, 'w'),
            cwd=path)
        self.kc.allow_stdin = False

        try:
            nb, resources = super(ExecutePreprocessor, self).preprocess(nb, resources)
        finally:
            self.kc.stop_channels()
            self.km.shutdown_kernel(now=True)

        return nb, resources

    def preprocess_cell(self, cell, resources, cell_index):
        """
        Apply a transformation on each code cell. See base.py for details.
        """
        if cell.cell_type != 'code':
            return cell, resources
        try:
            outputs = self.run_cell(cell)
        except Exception as e:
            self.log.error("failed to run cell: " + repr(e))
            self.log.error(str(cell.source))
            raise
        cell.outputs = outputs
        return cell, resources

    def run_cell(self, cell):
        msg_id = self.kc.execute(cell.source)
        self.log.debug("Executing cell:\n%s", cell.source)
        # wait for finish, with timeout
        while True:
            try:
                msg = self.kc.shell_channel.get_msg(timeout=self.timeout)
            except Empty:
                self.log.error("Timeout waiting for execute reply")
                if self.interrupt_on_timeout:
                    self.log.error("Interrupting kernel")
                    self.km.interrupt_kernel()
                    break
                else:
                    raise

            if msg['parent_header'].get('msg_id') == msg_id:
                break
            else:
                # not our reply
                continue
        
        outs = []

        while True:
            try:
                msg = self.kc.iopub_channel.get_msg(timeout=self.timeout)
            except Empty:
                self.log.warn("Timeout waiting for IOPub output")
                break
            if msg['parent_header'].get('msg_id') != msg_id:
                # not an output from our execution
                continue

            msg_type = msg['msg_type']
            self.log.debug("output: %s", msg_type)
            content = msg['content']

            # set the prompt number for the input and the output
            if 'execution_count' in content:
                cell['execution_count'] = content['execution_count']

            if msg_type == 'status':
                if content['execution_state'] == 'idle':
                    break
                else:
                    continue
            elif msg_type == 'execute_input':
                continue
            elif msg_type == 'clear_output':
                outs = []
                continue
            elif msg_type.startswith('comm'):
                continue

            try:
                out = output_from_msg(msg)
            except ValueError:
                self.log.error("unhandled iopub msg: " + msg_type)
            else:
                outs.append(out)

        return outs
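
A hedged driver sketch for the preprocessor above; the file names are placeholders, and the read/write calls assume the IPython 3.x nbformat API.

from IPython import nbformat

nb = nbformat.read('input.ipynb', as_version=4)   # placeholder path
ep = ExecutePreprocessor(timeout=60, interrupt_on_timeout=True)
nb, resources = ep.preprocess(nb, {'metadata': {'path': '.'}})
nbformat.write(nb, 'output.ipynb')                # placeholder path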
Example #7
File: execute.py Project: tanmaykm/ipython
class ExecutePreprocessor(Preprocessor):
    """
    Executes all the cells in a notebook
    """

    timeout = Integer(
        30,
        config=True,
        help="The time to wait (in seconds) for output from executions.")

    extra_arguments = List(Unicode)

    def preprocess(self, nb, resources):
        from IPython.kernel import run_kernel
        kernel_name = nb.metadata.get('kernelspec', {}).get('name', 'python')
        self.log.info("Executing notebook with kernel: %s" % kernel_name)
        with run_kernel(kernel_name=kernel_name,
                        extra_arguments=self.extra_arguments,
                        stderr=open(os.devnull, 'w')) as kc:
            self.kc = kc
            self.kc.allow_stdin = False
            nb, resources = super(ExecutePreprocessor,
                                  self).preprocess(nb, resources)
        return nb, resources

    def preprocess_cell(self, cell, resources, cell_index):
        """
        Apply a transformation on each code cell. See base.py for details.
        """
        if cell.cell_type != 'code':
            return cell, resources
        try:
            outputs = self.run_cell(cell)
        except Exception as e:
            self.log.error("failed to run cell: " + repr(e))
            self.log.error(str(cell.source))
            raise
        cell.outputs = outputs
        return cell, resources

    def run_cell(self, cell):
        msg_id = self.kc.execute(cell.source)
        self.log.debug("Executing cell:\n%s", cell.source)
        # wait for finish, with timeout
        while True:
            try:
                msg = self.kc.shell_channel.get_msg(timeout=self.timeout)
            except Empty:
                self.log.error("Timeout waiting for execute reply")
                raise
            if msg['parent_header'].get('msg_id') == msg_id:
                break
            else:
                # not our reply
                continue

        outs = []

        while True:
            try:
                msg = self.kc.iopub_channel.get_msg(timeout=self.timeout)
            except Empty:
                self.log.warn("Timeout waiting for IOPub output")
                break
            if msg['parent_header'].get('msg_id') != msg_id:
                # not an output from our execution
                continue

            msg_type = msg['msg_type']
            self.log.debug("output: %s", msg_type)
            content = msg['content']

            # set the prompt number for the input and the output
            if 'execution_count' in content:
                cell['execution_count'] = content['execution_count']

            if msg_type == 'status':
                if content['execution_state'] == 'idle':
                    break
                else:
                    continue
            elif msg_type == 'execute_input':
                continue
            elif msg_type == 'clear_output':
                outs = []
                continue

            try:
                out = output_from_msg(msg)
            except ValueError:
                self.log.error("unhandled iopub msg: " + msg_type)
            else:
                outs.append(out)

        return outs
Example #8
class HeartMonitor(LoggingConfigurable):
    """A basic HeartMonitor class
    pingstream: a PUB stream
    pongstream: a ROUTER stream
    period: the period of the heartbeat in milliseconds"""

    debug = Bool(
        False,
        config=True,
        help="""Whether to include every heartbeat in debugging output.
        
        Has to be set explicitly, because there will be *a lot* of output.
        """)
    period = Integer(
        3000,
        config=True,
        help='The frequency at which the Hub pings the engines for heartbeats '
        '(in ms)',
    )
    max_heartmonitor_misses = Integer(
        10,
        config=True,
        help=
        'Allowed consecutive missed pings from controller Hub to engine before unregistering.',
    )

    pingstream = Instance('zmq.eventloop.zmqstream.ZMQStream')
    pongstream = Instance('zmq.eventloop.zmqstream.ZMQStream')
    loop = Instance('zmq.eventloop.ioloop.IOLoop')

    def _loop_default(self):
        return ioloop.IOLoop.instance()

    # not settable:
    hearts = Set()
    responses = Set()
    on_probation = Dict()
    last_ping = CFloat(0)
    _new_handlers = Set()
    _failure_handlers = Set()
    lifetime = CFloat(0)
    tic = CFloat(0)

    def __init__(self, **kwargs):
        super(HeartMonitor, self).__init__(**kwargs)

        self.pongstream.on_recv(self.handle_pong)

    def start(self):
        self.tic = time.time()
        self.caller = ioloop.PeriodicCallback(self.beat, self.period,
                                              self.loop)
        self.caller.start()

    def add_new_heart_handler(self, handler):
        """add a new handler for new hearts"""
        self.log.debug("heartbeat::new_heart_handler: %s", handler)
        self._new_handlers.add(handler)

    def add_heart_failure_handler(self, handler):
        """add a new handler for heart failure"""
        self.log.debug("heartbeat::new heart failure handler: %s", handler)
        self._failure_handlers.add(handler)

    def beat(self):
        self.pongstream.flush()
        self.last_ping = self.lifetime

        toc = time.time()
        self.lifetime += toc - self.tic
        self.tic = toc
        if self.debug:
            self.log.debug("heartbeat::sending %s", self.lifetime)
        goodhearts = self.hearts.intersection(self.responses)
        missed_beats = self.hearts.difference(goodhearts)
        newhearts = self.responses.difference(goodhearts)
        for heart in newhearts:
            self.handle_new_heart(heart)
        heartfailures, on_probation = self._check_missed(
            missed_beats, self.on_probation, self.hearts)
        for failure in heartfailures:
            self.handle_heart_failure(failure)
        self.on_probation = on_probation
        self.responses = set()
        #print self.on_probation, self.hearts
        # self.log.debug("heartbeat::beat %.3f, %i beating hearts", self.lifetime, len(self.hearts))
        self.pingstream.send(str_to_bytes(str(self.lifetime)))
        # flush stream to force immediate socket send
        self.pingstream.flush()

    def _check_missed(self, missed_beats, on_probation, hearts):
        """Update heartbeats on probation, identifying any that have too many misses.
        """
        failures = []
        new_probation = {}
        for cur_heart in (b for b in missed_beats if b in hearts):
            miss_count = on_probation.get(cur_heart, 0) + 1
            self.log.info("heartbeat::missed %s : %s" %
                          (cur_heart, miss_count))
            if miss_count > self.max_heartmonitor_misses:
                failures.append(cur_heart)
            else:
                new_probation[cur_heart] = miss_count
        return failures, new_probation

    def handle_new_heart(self, heart):
        if self._new_handlers:
            for handler in self._new_handlers:
                handler(heart)
        else:
            self.log.info("heartbeat::yay, got new heart %s!", heart)
        self.hearts.add(heart)

    def handle_heart_failure(self, heart):
        if self._failure_handlers:
            for handler in self._failure_handlers:
                try:
                    handler(heart)
                except Exception:
                    self.log.error("heartbeat::Bad Handler! %s",
                                   handler,
                                   exc_info=True)
        else:
            self.log.info("heartbeat::Heart %s failed :(", heart)
        self.hearts.remove(heart)

    @log_errors
    def handle_pong(self, msg):
        "a heart just beat"
        current = str_to_bytes(str(self.lifetime))
        last = str_to_bytes(str(self.last_ping))
        if msg[1] == current:
            delta = time.time() - self.tic
            if self.debug:
                self.log.debug("heartbeat::heart %r took %.2f ms to respond",
                               msg[0], 1000 * delta)
            self.responses.add(msg[0])
        elif msg[1] == last:
            delta = time.time() - self.tic + (self.lifetime - self.last_ping)
            self.log.warn(
                "heartbeat::heart %r missed a beat, and took %.2f ms to respond",
                msg[0], 1000 * delta)
            self.responses.add(msg[0])
        else:
            self.log.warn(
                "heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)",
                msg[1], self.lifetime)
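
A hedged wiring sketch: the Hub normally constructs the monitor; pingstream and pongstream below are assumed, pre-built ZMQStream objects on a PUB and a ROUTER socket respectively.

def on_new_heart(heart):
    print("engine registered:", heart)

def on_heart_failure(heart):
    print("engine lost:", heart)

monitor = HeartMonitor(pingstream=pingstream,  # assumed PUB stream
                       pongstream=pongstream,  # assumed ROUTER stream
                       period=3000)
monitor.add_new_heart_handler(on_new_heart)
monitor.add_heart_failure_handler(on_heart_failure)
monitor.start()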
Example #9
class EngineFactory(RegistrationFactory):
    """IPython engine"""

    # configurables:
    out_stream_factory = Type('IPython.kernel.zmq.iostream.OutStream',
                              config=True,
                              help="""The OutStream for handling stdout/err.
        Typically 'IPython.kernel.zmq.iostream.OutStream'""")
    display_hook_factory = Type(
        'IPython.kernel.zmq.displayhook.ZMQDisplayHook',
        config=True,
        help="""The class for handling displayhook.
        Typically 'IPython.kernel.zmq.displayhook.ZMQDisplayHook'""")
    location = Unicode(
        config=True,
        help="""The location (an IP address) of the controller.  This is
        used for disambiguating URLs, to determine whether
        loopback should be used to connect or the public address.""")
    timeout = Float(
        5.0,
        config=True,
        help="""The time (in seconds) to wait for the Controller to respond
        to registration requests before giving up.""")
    max_heartbeat_misses = Integer(
        50,
        config=True,
        help="""The maximum number of times a check for the heartbeat ping of a 
        controller can be missed before shutting down the engine.
        
        If set to 0, the check is disabled.""")
    sshserver = Unicode(
        config=True,
        help=
        """The SSH server to use for tunneling connections to the Controller."""
    )
    sshkey = Unicode(
        config=True,
        help=
        """The SSH private key file to use when tunneling connections to the Controller."""
    )
    paramiko = Bool(
        sys.platform == 'win32',
        config=True,
        help="""Whether to use paramiko instead of openssh for tunnels.""")

    @property
    def tunnel_mod(self):
        from zmq.ssh import tunnel
        return tunnel

    # not configurable:
    connection_info = Dict()
    user_ns = Dict()
    id = Integer(allow_none=True)
    registrar = Instance('zmq.eventloop.zmqstream.ZMQStream')
    kernel = Instance(Kernel)
    hb_check_period = Integer()

    # States for the heartbeat monitoring
    # Initial values for monitored and pinged must satisfy "monitored > pinged == False" so that
    # during the first check no "missed" ping is reported. Must be floats for Python 3 compatibility.
    _hb_last_pinged = 0.0
    _hb_last_monitored = 0.0
    _hb_missed_beats = 0
    # The zmq Stream which receives the pings from the Heart
    _hb_listener = None

    bident = CBytes()
    ident = Unicode()

    def _ident_changed(self, name, old, new):
        self.bident = cast_bytes(new)

    using_ssh = Bool(False)

    def __init__(self, **kwargs):
        super(EngineFactory, self).__init__(**kwargs)
        self.ident = self.session.session

    def init_connector(self):
        """construct connection function, which handles tunnels."""
        self.using_ssh = bool(self.sshkey or self.sshserver)

        if self.sshkey and not self.sshserver:
            # We are using ssh directly to the controller, tunneling localhost to localhost
            self.sshserver = self.url.split('://')[1].split(':')[0]

        if self.using_ssh:
            if self.tunnel_mod.try_passwordless_ssh(self.sshserver,
                                                    self.sshkey,
                                                    self.paramiko):
                password = False
            else:
                password = getpass("SSH Password for %s: " % self.sshserver)
        else:
            password = False

        def connect(s, url):
            url = disambiguate_url(url, self.location)
            if self.using_ssh:
                self.log.debug("Tunneling connection to %s via %s", url,
                               self.sshserver)
                return self.tunnel_mod.tunnel_connection(
                    s,
                    url,
                    self.sshserver,
                    keyfile=self.sshkey,
                    paramiko=self.paramiko,
                    password=password,
                )
            else:
                return s.connect(url)

        def maybe_tunnel(url):
            """like connect, but don't complete the connection (for use by heartbeat)"""
            url = disambiguate_url(url, self.location)
            if self.using_ssh:
                self.log.debug("Tunneling connection to %s via %s", url,
                               self.sshserver)
                url, tunnelobj = self.tunnel_mod.open_tunnel(
                    url,
                    self.sshserver,
                    keyfile=self.sshkey,
                    paramiko=self.paramiko,
                    password=password,
                )
            return str(url)

        return connect, maybe_tunnel

    def register(self):
        """send the registration_request"""

        self.log.info("Registering with controller at %s" % self.url)
        ctx = self.context
        connect, maybe_tunnel = self.init_connector()
        reg = ctx.socket(zmq.DEALER)
        reg.setsockopt(zmq.IDENTITY, self.bident)
        connect(reg, self.url)
        self.registrar = zmqstream.ZMQStream(reg, self.loop)

        content = dict(uuid=self.ident)
        self.registrar.on_recv(
            lambda msg: self.complete_registration(msg, connect, maybe_tunnel))
        # print (self.session.key)
        self.session.send(self.registrar,
                          "registration_request",
                          content=content)

    def _report_ping(self, msg):
        """Callback for when the heartmonitor.Heart receives a ping"""
        #self.log.debug("Received a ping: %s", msg)
        self._hb_last_pinged = time.time()

    def complete_registration(self, msg, connect, maybe_tunnel):
        # print msg
        self._abort_dc.stop()
        ctx = self.context
        loop = self.loop
        identity = self.bident
        idents, msg = self.session.feed_identities(msg)
        msg = self.session.unserialize(msg)
        content = msg['content']
        info = self.connection_info

        def url(key):
            """get zmq url for given channel"""
            return str(info["interface"] + ":%i" % info[key])

        if content['status'] == 'ok':
            self.id = int(content['id'])

            # launch heartbeat
            # possibly forward hb ports with tunnels
            hb_ping = maybe_tunnel(url('hb_ping'))
            hb_pong = maybe_tunnel(url('hb_pong'))

            hb_monitor = None
            if self.max_heartbeat_misses > 0:
                # Add a monitor socket which will record the last time a ping was seen
                mon = self.context.socket(zmq.SUB)
                mport = mon.bind_to_random_port('tcp://%s' % localhost())
                mon.setsockopt(zmq.SUBSCRIBE, b"")
                self._hb_listener = zmqstream.ZMQStream(mon, self.loop)
                self._hb_listener.on_recv(self._report_ping)

                hb_monitor = "tcp://%s:%i" % (localhost(), mport)

            heart = Heart(hb_ping, hb_pong, hb_monitor, heart_id=identity)
            heart.start()

            # create Shell Connections (MUX, Task, etc.):
            shell_addrs = url('mux'), url('task')

            # Use only one shell stream for mux and tasks
            stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
            stream.setsockopt(zmq.IDENTITY, identity)
            shell_streams = [stream]
            for addr in shell_addrs:
                connect(stream, addr)

            # control stream:
            control_addr = url('control')
            control_stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
            control_stream.setsockopt(zmq.IDENTITY, identity)
            connect(control_stream, control_addr)

            # create iopub stream:
            iopub_addr = url('iopub')
            iopub_socket = ctx.socket(zmq.PUB)
            iopub_socket.setsockopt(zmq.IDENTITY, identity)
            connect(iopub_socket, iopub_addr)

            # disable history:
            self.config.HistoryManager.hist_file = ':memory:'

            # Redirect output streams and set a display hook.
            if self.out_stream_factory:
                sys.stdout = self.out_stream_factory(self.session,
                                                     iopub_socket, u'stdout')
                sys.stdout.topic = cast_bytes('engine.%i.stdout' % self.id)
                sys.stderr = self.out_stream_factory(self.session,
                                                     iopub_socket, u'stderr')
                sys.stderr.topic = cast_bytes('engine.%i.stderr' % self.id)
            if self.display_hook_factory:
                sys.displayhook = self.display_hook_factory(
                    self.session, iopub_socket)
                sys.displayhook.topic = cast_bytes('engine.%i.execute_result' %
                                                   self.id)

            self.kernel = Kernel(parent=self,
                                 int_id=self.id,
                                 ident=self.ident,
                                 session=self.session,
                                 control_stream=control_stream,
                                 shell_streams=shell_streams,
                                 iopub_socket=iopub_socket,
                                 loop=loop,
                                 user_ns=self.user_ns,
                                 log=self.log)

            self.kernel.shell.display_pub.topic = cast_bytes(
                'engine.%i.displaypub' % self.id)

            # periodically check the heartbeat pings of the controller
            # Should be started here and not in "start()" so that the right period can be taken
            # from the hub's HeartBeatMonitor.period
            if self.max_heartbeat_misses > 0:
                # Use a slightly bigger check period than the hub's signal period, to avoid warning unnecessarily
                self.hb_check_period = int(content['hb_period']) + 10
                self.log.info(
                    "Starting to monitor the heartbeat signal from the hub every %i ms.",
                    self.hb_check_period)
                self._hb_reporter = ioloop.PeriodicCallback(
                    self._hb_monitor, self.hb_check_period, self.loop)
                self._hb_reporter.start()
            else:
                self.log.info(
                    "Monitoring of the heartbeat signal from the hub is not enabled."
                )

            # FIXME: This is a hack until IPKernelApp and IPEngineApp can be fully merged
            app = IPKernelApp(parent=self,
                              shell=self.kernel.shell,
                              kernel=self.kernel,
                              log=self.log)
            app.init_profile_dir()
            app.init_code()

            self.kernel.start()
        else:
            self.log.fatal("Registration Failed: %s" % msg)
            raise Exception("Registration Failed: %s" % msg)

        self.log.info("Completed registration with id %i" % self.id)

    def abort(self):
        self.log.fatal("Registration timed out after %.1f seconds" %
                       self.timeout)
        if self.url.startswith('127.'):
            self.log.fatal("""
            If the controller and engines are not on the same machine,
            you will have to instruct the controller to listen on an external IP (in ipcontroller_config.py):
                c.HubFactory.ip='*' # for all interfaces, internal and external
                c.HubFactory.ip='192.168.1.101' # or any interface that the engines can see
            or tunnel connections via ssh.
            """)
        self.session.send(self.registrar,
                          "unregistration_request",
                          content=dict(id=self.id))
        time.sleep(1)
        sys.exit(255)

    def _hb_monitor(self):
        """Callback to monitor the heartbeat from the controller"""
        self._hb_listener.flush()
        if self._hb_last_monitored > self._hb_last_pinged:
            self._hb_missed_beats += 1
            self.log.warn(
                "No heartbeat in the last %s ms (%s time(s) in a row).",
                self.hb_check_period, self._hb_missed_beats)
        else:
            #self.log.debug("Heartbeat received (after missing %s beats).", self._hb_missed_beats)
            self._hb_missed_beats = 0

        if self._hb_missed_beats >= self.max_heartbeat_misses:
            self.log.fatal(
                "Maximum number of heartbeats misses reached (%s times %s ms), shutting down.",
                self.max_heartbeat_misses, self.hb_check_period)
            self.session.send(self.registrar,
                              "unregistration_request",
                              content=dict(id=self.id))
            self.loop.stop()

        self._hb_last_monitored = time.time()

    def start(self):
        dc = ioloop.DelayedCallback(self.register, 0, self.loop)
        dc.start()
        self._abort_dc = ioloop.DelayedCallback(self.abort,
                                                self.timeout * 1000, self.loop)
        self._abort_dc.start()
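The heartbeat watchdog above boils down to comparing two timestamps on a timer: the last time a ping arrived versus the last time we checked. The sketch below isolates that pattern, assuming only pyzmq's tornado-based ioloop; the class name and defaults are illustrative, not part of IPython.

import time
from zmq.eventloop import ioloop

class PingWatchdog(object):
    """Count consecutive check intervals without a ping; stop the loop on too many."""
    def __init__(self, loop, period_ms=1000, max_misses=10):
        self.loop = loop
        self.max_misses = max_misses
        self.missed = 0
        self.last_pinged = self.last_checked = time.time()
        self._cb = ioloop.PeriodicCallback(self.check, period_ms, loop)

    def ping(self):
        # hook this up as the on_recv callback of the monitor socket
        self.last_pinged = time.time()

    def check(self):
        if self.last_checked > self.last_pinged:
            self.missed += 1  # no ping has arrived since the last check
        else:
            self.missed = 0   # ping seen; reset the counter
        if self.missed >= self.max_misses:
            self.loop.stop()
        self.last_checked = time.time()

    def start(self):
        self._cb.start()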
Example #10
0
File: launcher.py  Project: waldman/ipython
class BatchSystemLauncher(BaseLauncher):
    """Launch an external process using a batch system.

    This class is designed to work with UNIX batch systems like PBS, LSF,
    GridEngine, etc.  The overall model is that there are different commands
    like qsub, qdel, etc. that handle the starting and stopping of the process.

    This class also has the notion of a batch script. The ``batch_template``
    attribute can be set to a string that is a template for the batch script.
    This template is instantiated using string formatting. Thus the template can
    use {n} for the number of instances. Subclasses can add additional variables
    to the template dict.
    """

    # Subclasses must fill these in.  See PBSEngineSet
    submit_command = List(
        [''],
        config=True,
        help="The name of the command line program used to submit jobs.")
    delete_command = List(
        [''],
        config=True,
        help="The name of the command line program used to delete jobs.")
    job_id_regexp = CRegExp(
        '',
        config=True,
        help=
        """A regular expression used to get the job id from the output of the
        submit_command.""")
    batch_template = Unicode(
        '',
        config=True,
        help="The string that is the batch script template itself.")
    batch_template_file = Unicode(
        u'', config=True, help="The file that contains the batch template.")
    batch_file_name = Unicode(
        u'batch_script',
        config=True,
        help="The filename of the instantiated batch script.")
    queue = Unicode(u'', config=True, help="The PBS Queue.")

    def _queue_changed(self, name, old, new):
        self.context[name] = new

    n = Integer(1)
    _n_changed = _queue_changed

    # not configurable, override in subclasses
    # PBS Job Array regex
    job_array_regexp = CRegExp('')
    job_array_template = Unicode('')
    # PBS Queue regex
    queue_regexp = CRegExp('')
    queue_template = Unicode('')
    # The default batch template, override in subclasses
    default_template = Unicode('')
    # The full path to the instantiated batch script.
    batch_file = Unicode(u'')
    # the format dict used with batch_template:
    context = Dict()

    def _context_default(self):
        """load the default context with the default values for the basic keys

        because the _trait_changed methods only load the context if they
        are set to something other than the default value.
        """
        return dict(n=1, queue=u'', profile_dir=u'', cluster_id=u'')

    # the Formatter instance for rendering the templates:
    formatter = Instance(EvalFormatter, (), {})

    def find_args(self):
        return self.submit_command + [self.batch_file]

    def __init__(self, work_dir=u'.', config=None, **kwargs):
        super(BatchSystemLauncher, self).__init__(work_dir=work_dir,
                                                  config=config,
                                                  **kwargs)
        self.batch_file = os.path.join(self.work_dir, self.batch_file_name)

    def parse_job_id(self, output):
        """Take the output of the submit command and return the job id."""
        m = self.job_id_regexp.search(output)
        if m is not None:
            job_id = m.group()
        else:
            raise LauncherError("Job id couldn't be determined: %s" % output)
        self.job_id = job_id
        self.log.info('Job submitted with job id: %r', job_id)
        return job_id

    def write_batch_script(self, n):
        """Instantiate and write the batch script to the work_dir."""
        self.n = n
        # first priority is batch_template if set
        if self.batch_template_file and not self.batch_template:
            # second priority is batch_template_file
            with open(self.batch_template_file) as f:
                self.batch_template = f.read()
        if not self.batch_template:
            # third (last) priority is default_template
            self.batch_template = self.default_template

            # add job array or queue lines to the default template
            # note that this happens *only* when the user did not specify a template.
            # print self.job_array_regexp.search(self.batch_template)
            if not self.job_array_regexp.search(self.batch_template):
                self.log.debug("adding job array settings to batch script")
                firstline, rest = self.batch_template.split('\n', 1)
                self.batch_template = u'\n'.join(
                    [firstline, self.job_array_template, rest])

            # print self.queue_regexp.search(self.batch_template)
            if self.queue and not self.queue_regexp.search(
                    self.batch_template):
                self.log.debug("adding PBS queue settings to batch script")
                firstline, rest = self.batch_template.split('\n', 1)
                self.batch_template = u'\n'.join(
                    [firstline, self.queue_template, rest])

        script_as_string = self.formatter.format(self.batch_template,
                                                 **self.context)
        self.log.debug('Writing batch script: %s', self.batch_file)

        with open(self.batch_file, 'w') as f:
            f.write(script_as_string)
        os.chmod(self.batch_file, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)

    def start(self, n):
        """Start n copies of the process using a batch system."""
        self.log.debug("Starting %s: %r", self.__class__.__name__, self.args)
        # profile_dir is stored in the context, so it can be used in the
        # batch script template as {profile_dir}
        self.write_batch_script(n)
        output = check_output(self.args, env=os.environ)

        job_id = self.parse_job_id(output)
        self.notify_start(job_id)
        return job_id

    def stop(self):
        output = check_output(self.delete_command + [self.job_id],
                              env=os.environ)
        self.notify_stop(
            dict(job_id=self.job_id,
                 output=output))  # Pass the output of the kill cmd
        return output
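To make the template hooks above concrete, here is a hypothetical PBS-flavoured subclass. The #PBS directives, regexes, and commands are illustrative (not copied from IPython's own PBS launchers); List, Unicode, and CRegExp are the same traitlets used in the class above.

class ExamplePBSLauncher(BatchSystemLauncher):
    submit_command = List(['qsub'], config=True,
                          help="The PBS submit command ['qsub']")
    delete_command = List(['qdel'], config=True,
                          help="The PBS delete command ['qdel']")
    job_id_regexp = CRegExp(r'\d+', config=True,
                            help="Regex extracting the job id from qsub output")
    # recognize / inject '#PBS -t' job array lines:
    job_array_regexp = CRegExp(r'#PBS\s+-t\s+\S+')
    job_array_template = Unicode('#PBS -t 1-{n}')
    # recognize / inject '#PBS -q' queue lines:
    queue_regexp = CRegExp(r'#PBS\s+-q\s+\S+')
    queue_template = Unicode('#PBS -q {queue}')
    default_template = Unicode("""#!/bin/sh
#PBS -V
#PBS -N ipengine
ipengine --profile-dir="{profile_dir}"
""")

write_batch_script(n) then renders {n}, {queue}, and {profile_dir} from the context dict before the script is submitted with qsub.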
Example #11
0
File: launcher.py  Project: waldman/ipython
class LocalProcessLauncher(BaseLauncher):
    """Start and stop an external process in an asynchronous manner.

    This will launch the external process with a working directory of
    ``self.work_dir``.
    """

    # This is used to construct self.args, which is passed to
    # spawnProcess.
    cmd_and_args = List([])
    poll_frequency = Integer(100)  # in ms

    def __init__(self, work_dir=u'.', config=None, **kwargs):
        super(LocalProcessLauncher, self).__init__(work_dir=work_dir,
                                                   config=config,
                                                   **kwargs)
        self.process = None
        self.poller = None

    def find_args(self):
        return self.cmd_and_args

    def start(self):
        self.log.debug("Starting %s: %r", self.__class__.__name__, self.args)
        if self.state == 'before':
            self.process = Popen(self.args,
                                 stdout=PIPE,
                                 stderr=PIPE,
                                 stdin=PIPE,
                                 env=os.environ,
                                 cwd=self.work_dir)
            if WINDOWS:
                self.stdout = forward_read_events(self.process.stdout)
                self.stderr = forward_read_events(self.process.stderr)
            else:
                self.stdout = self.process.stdout.fileno()
                self.stderr = self.process.stderr.fileno()
            self.loop.add_handler(self.stdout, self.handle_stdout,
                                  self.loop.READ)
            self.loop.add_handler(self.stderr, self.handle_stderr,
                                  self.loop.READ)
            self.poller = ioloop.PeriodicCallback(self.poll,
                                                  self.poll_frequency,
                                                  self.loop)
            self.poller.start()
            self.notify_start(self.process.pid)
        else:
            s = 'The process was already started and has state: %r' % self.state
            raise ProcessStateError(s)

    def stop(self):
        return self.interrupt_then_kill()

    def signal(self, sig):
        if self.state == 'running':
            if WINDOWS and sig != SIGINT:
                # use Windows tree-kill for better child cleanup
                check_output(
                    ['taskkill', '-pid',
                     str(self.process.pid), '-t', '-f'])
            else:
                self.process.send_signal(sig)

    def interrupt_then_kill(self, delay=2.0):
        """Send INT, wait a delay and then send KILL."""
        try:
            self.signal(SIGINT)
        except Exception:
            self.log.debug("interrupt failed")
            pass
        self.killer = ioloop.DelayedCallback(lambda: self.signal(SIGKILL),
                                             delay * 1000, self.loop)
        self.killer.start()

    # callbacks, etc:

    def handle_stdout(self, fd, events):
        if WINDOWS:
            line = self.stdout.recv()
        else:
            line = self.process.stdout.readline()
        # a stopped process will be readable but return empty strings
        if line:
            self.log.debug(line[:-1])
        else:
            self.poll()

    def handle_stderr(self, fd, events):
        if WINDOWS:
            line = self.stderr.recv()
        else:
            line = self.process.stderr.readline()
        # a stopped process will be readable but return empty strings
        if line:
            self.log.debug(line[:-1])
        else:
            self.poll()

    def poll(self):
        status = self.process.poll()
        if status is not None:
            self.poller.stop()
            self.loop.remove_handler(self.stdout)
            self.loop.remove_handler(self.stderr)
            self.notify_stop(dict(exit_code=status, pid=self.process.pid))
        return status
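The interrupt_then_kill escalation above is a generally useful pattern. A minimal stdlib-only sketch of the same idea, outside the launcher machinery (the function name and default delay are illustrative):

import signal
import threading

def interrupt_then_kill(proc, delay=2.0):
    """Send SIGINT to a subprocess.Popen `proc`; if it is still alive
    after `delay` seconds, SIGKILL it."""
    try:
        proc.send_signal(signal.SIGINT)
    except OSError:
        pass  # process already exited
    def _kill():
        if proc.poll() is None:  # still running after the grace period
            proc.kill()
    threading.Timer(delay, _kill).start()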
Example #12
0
class DictDB(BaseDB):
    """Basic in-memory dict-based object for saving Task Records.

    This is the first object to present the DB interface
    for logging tasks out of memory.

    The interface is based on MongoDB, so adding a MongoDB
    backend should be straightforward.
    """

    _records = Dict()
    _culled_ids = set()  # set of ids which have been culled
    _buffer_bytes = Integer(0)  # running total of the bytes in the DB

    size_limit = Integer(
        1024**3,
        config=True,
        help="""The maximum total size (in bytes) of the buffers stored in the db
        
        When the db exceeds this size, the oldest records will be culled until
        the total size is under size_limit * (1-cull_fraction).
        default: 1 GB
        """)
    record_limit = Integer(1024,
                           config=True,
                           help="""The maximum number of records in the db
        
        When the history exceeds this size, the first record_limit * cull_fraction
        records will be culled.
        """)
    cull_fraction = Float(
        0.1,
        config=True,
        help=
        """The fraction by which the db should culled when one of the limits is exceeded
        
        In general, the db size will spend most of its time with a size in the range:
        
        [limit * (1-cull_fraction), limit]
        
        for each of size_limit and record_limit.
        """)

    def _match_one(self, rec, tests):
        """Check if a specific record matches tests."""
        for key, test in tests.items():
            if not test(rec.get(key, None)):
                return False
        return True

    def _match(self, check):
        """Find all the matches for a check dict."""
        matches = []
        tests = {}
        for k, v in check.items():
            if isinstance(v, dict):
                tests[k] = CompositeFilter(v)
            else:
                # bind v per-iteration; a bare closure would capture only the
                # final loop value of v and compare every key against it
                tests[k] = lambda o, v=v: o == v

        for rec in self._records.values():
            if self._match_one(rec, tests):
                matches.append(copy(rec))
        return matches

    def _extract_subdict(self, rec, keys):
        """extract subdict of keys"""
        d = {}
        d['msg_id'] = rec['msg_id']
        for key in keys:
            d[key] = rec[key]
        return copy(d)

    # methods for monitoring size / culling history

    def _add_bytes(self, rec):
        for key in ('buffers', 'result_buffers'):
            for buf in rec.get(key) or []:
                self._buffer_bytes += len(buf)

        self._maybe_cull()

    def _drop_bytes(self, rec):
        for key in ('buffers', 'result_buffers'):
            for buf in rec.get(key) or []:
                self._buffer_bytes -= len(buf)

    def _cull_oldest(self, n=1):
        """cull the oldest N records"""
        for msg_id in self.get_history()[:n]:
            self.log.debug("Culling record: %r", msg_id)
            self._culled_ids.add(msg_id)
            self.drop_record(msg_id)

    def _maybe_cull(self):
        # cull by count:
        if len(self._records) > self.record_limit:
            to_cull = int(self.cull_fraction * self.record_limit)
            self.log.info("%i records exceeds limit of %i, culling oldest %i",
                          len(self._records), self.record_limit, to_cull)
            self._cull_oldest(to_cull)

        # cull by size:
        if self._buffer_bytes > self.size_limit:
            limit = self.size_limit * (1 - self.cull_fraction)

            before = self._buffer_bytes
            before_count = len(self._records)
            culled = 0
            while self._buffer_bytes > limit:
                self._cull_oldest(1)
                culled += 1

            self.log.info(
                "%i records with total buffer size %i exceeds limit: %i. Culled oldest %i records.",
                before_count, before, self.size_limit, culled)

    # public API methods:

    def add_record(self, msg_id, rec):
        """Add a new Task Record, by msg_id."""
        if msg_id in self._records:
            raise KeyError("Already have msg_id %r" % (msg_id))
        self._records[msg_id] = rec
        self._add_bytes(rec)
        self._maybe_cull()

    def get_record(self, msg_id):
        """Get a specific Task Record, by msg_id."""
        if msg_id in self._culled_ids:
            raise KeyError("Record %r has been culled for size" % msg_id)
        if msg_id not in self._records:
            raise KeyError("No such msg_id %r" % (msg_id))
        return copy(self._records[msg_id])

    def update_record(self, msg_id, rec):
        """Update the data in an existing record."""
        if msg_id in self._culled_ids:
            raise KeyError("Record %r has been culled for size" % msg_id)
        _rec = self._records[msg_id]
        self._drop_bytes(_rec)
        _rec.update(rec)
        self._add_bytes(_rec)

    def drop_matching_records(self, check):
        """Remove a record from the DB."""
        matches = self._match(check)
        for rec in matches:
            self._drop_bytes(rec)
            del self._records[rec['msg_id']]

    def drop_record(self, msg_id):
        """Remove a record from the DB."""
        rec = self._records[msg_id]
        self._drop_bytes(rec)
        del self._records[msg_id]

    def find_records(self, check, keys=None):
        """Find records matching a query dict, optionally extracting subset of keys.

        Returns a list of the matching records.

        Parameters
        ----------

        check: dict
            mongodb-style query argument
        keys: list of strs [optional]
            if specified, the subset of keys to extract.  msg_id will *always* be
            included.
        """
        matches = self._match(check)
        if keys:
            return [self._extract_subdict(rec, keys) for rec in matches]
        else:
            return matches

    def get_history(self):
        """get all msg_ids, ordered by time submitted."""
        msg_ids = list(self._records.keys())
        # Remove any that do not have a submitted timestamp.
        # This is extremely unlikely to happen,
        # but it seems to come up in some tests on VMs.
        msg_ids = [
            m for m in msg_ids if self._records[m]['submitted'] is not None
        ]
        return sorted(msg_ids, key=lambda m: self._records[m]['submitted'])
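A short usage sketch of the query interface. The record fields shown are the ones the class itself touches (msg_id, submitted, buffers, result_buffers), and '$gt' is assumed to be among the MongoDB-style operators CompositeFilter understands:

db = DictDB()
db.add_record('abc123', dict(msg_id='abc123', submitted=1,
                             buffers=[b'payload'], result_buffers=[]))
# plain equality query:
pending = db.find_records({'submitted': 1})
# operator query, dispatched to CompositeFilter:
recent = db.find_records({'submitted': {'$gt': 0}}, keys=['submitted'])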
Example #13
0
class LoadBalancedView(View):
    """An load-balancing View that only executes via the Task scheduler.

    Load-balanced views can be created with the client's `view` method:

    >>> v = client.load_balanced_view()

    or targets can be specified, to restrict the potential destinations:

    >>> v = client.load_balanced_view([1,3])

    which restricts load-balancing to engines 1 and 3.

    """

    follow = Any()
    after = Any()
    timeout = CFloat()
    retries = Integer(0)

    _task_scheme = Any()
    _flag_names = List(
        ['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])

    def __init__(self, client=None, socket=None, **flags):
        super(LoadBalancedView, self).__init__(client=client,
                                               socket=socket,
                                               **flags)
        self._task_scheme = client._task_scheme

    def _validate_dependency(self, dep):
        """validate a dependency.

        For use in `set_flags`.
        """
        if dep is None or isinstance(dep,
                                     (basestring, AsyncResult, Dependency)):
            return True
        elif isinstance(dep, (list, set, tuple)):
            for d in dep:
                if not isinstance(d, (basestring, AsyncResult)):
                    return False
        elif isinstance(dep, dict):
            if set(dep.keys()) != set(Dependency().as_dict().keys()):
                return False
            if not isinstance(dep['msg_ids'], list):
                return False
            for d in dep['msg_ids']:
                if not isinstance(d, basestring):
                    return False
        else:
            return False

        return True

    def _render_dependency(self, dep):
        """helper for building jsonable dependencies from various input forms."""
        if isinstance(dep, Dependency):
            return dep.as_dict()
        elif isinstance(dep, AsyncResult):
            return dep.msg_ids
        elif dep is None:
            return []
        else:
            # pass to Dependency constructor
            return list(Dependency(dep))

    def set_flags(self, **kwargs):
        """set my attribute flags by keyword.

        A View is a wrapper for the Client's apply method, but with attributes
        that specify keyword arguments; those attributes can be set by keyword
        argument with this method.

        Parameters
        ----------

        block : bool
            whether to wait for results
        track : bool
            whether to create a MessageTracker to allow the user to
            safely edit after arrays and buffers during non-copying
            sends.

        after : Dependency or collection of msg_ids
            Only for load-balanced execution (targets=None)
            Specify a list of msg_ids as a time-based dependency.
            This job will only be run *after* the dependencies
            have been met.

        follow : Dependency or collection of msg_ids
            Only for load-balanced execution (targets=None)
            Specify a list of msg_ids as a location-based dependency.
            This job will only be run on an engine where this dependency
            is met.

        timeout : float/int or None
            Only for load-balanced execution (targets=None)
            Specify an amount of time (in seconds) for the scheduler to
            wait for dependencies to be met before failing with a
            DependencyTimeout.

        retries : int
            Number of times a task will be retried on failure.
        """

        super(LoadBalancedView, self).set_flags(**kwargs)
        for name in ('follow', 'after'):
            if name in kwargs:
                value = kwargs[name]
                if self._validate_dependency(value):
                    setattr(self, name, value)
                else:
                    raise ValueError("Invalid dependency: %r" % value)
        if 'timeout' in kwargs:
            t = kwargs['timeout']
            if not isinstance(t, (int, long, float, type(None))):
                raise TypeError("Invalid type for timeout: %r" % type(t))
            if t is not None:
                if t < 0:
                    raise ValueError("Invalid timeout: %s" % t)
            self.timeout = t

    @sync_results
    @save_ids
    def _really_apply(self,
                      f,
                      args=None,
                      kwargs=None,
                      block=None,
                      track=None,
                      after=None,
                      follow=None,
                      timeout=None,
                      targets=None,
                      retries=None):
        """calls f(*args, **kwargs) on a remote engine, returning the result.

        This method temporarily sets all of `apply`'s flags for a single call.

        Parameters
        ----------

        f : callable

        args : list [default: empty]

        kwargs : dict [default: empty]

        block : bool [default: self.block]
            whether to block
        track : bool [default: self.track]
            whether to ask zmq to track the message, for safe non-copying sends

        after : Dependency or collection of msg_ids [default: self.after]
            time-based dependency: run only after these msg_ids have completed
        follow : Dependency or collection of msg_ids [default: self.follow]
            location-based dependency: run only on an engine where these
            msg_ids were run
        timeout : float/int or None [default: self.timeout]
            seconds to wait for dependencies before failing with DependencyTimeout
        targets : int, list of ints, or None [default: self.targets]
            restrict the potential destination engines
        retries : int [default: self.retries]
            number of times the task will be retried on failure

        Returns
        -------

        if self.block is False:
            returns AsyncResult
        else:
            returns actual result of f(*args, **kwargs) on the engine(s)
            This will be a list if self.targets is also a list (even of length 1),
            or the single result if self.targets is an integer engine id
        """

        # validate whether we can run
        if self._socket.closed:
            msg = "Task farming is disabled"
            if self._task_scheme == 'pure':
                msg += " because the pure ZMQ scheduler cannot handle"
                msg += " disappearing engines."
            raise RuntimeError(msg)

        if self._task_scheme == 'pure':
            # pure zmq scheme doesn't support extra features
            msg = "Pure ZMQ scheduler doesn't support the following flags:"
            "follow, after, retries, targets, timeout"
            if (follow or after or retries or targets or timeout):
                # hard fail on Scheduler flags
                raise RuntimeError(msg)
            if isinstance(f, dependent):
                # soft warn on functional dependencies
                warnings.warn(msg, RuntimeWarning)

        # build args
        args = [] if args is None else args
        kwargs = {} if kwargs is None else kwargs
        block = self.block if block is None else block
        track = self.track if track is None else track
        after = self.after if after is None else after
        retries = self.retries if retries is None else retries
        follow = self.follow if follow is None else follow
        timeout = self.timeout if timeout is None else timeout
        targets = self.targets if targets is None else targets

        if not isinstance(retries, int):
            raise TypeError('retries must be int, not %r' % type(retries))

        if targets is None:
            idents = []
        else:
            idents = self.client._build_targets(targets)[0]
            # ensure *not* bytes
            idents = [ident.decode() for ident in idents]

        after = self._render_dependency(after)
        follow = self._render_dependency(follow)
        subheader = dict(after=after,
                         follow=follow,
                         timeout=timeout,
                         targets=idents,
                         retries=retries)

        msg = self.client.send_apply_request(self._socket,
                                             f,
                                             args,
                                             kwargs,
                                             track=track,
                                             subheader=subheader)
        tracker = None if track is False else msg['tracker']

        ar = AsyncResult(self.client,
                         msg['header']['msg_id'],
                         fname=getname(f),
                         targets=None,
                         tracker=tracker)

        if block:
            try:
                return ar.get()
            except KeyboardInterrupt:
                pass
        return ar

    @spin_after
    @save_ids
    def map(self, f, *sequences, **kwargs):
        """view.map(f, *sequences, block=self.block, chunksize=1, ordered=True) => list|AsyncMapResult

        Parallel version of builtin `map`, load-balanced by this View.

        `block`, `chunksize`, and `ordered` can be specified by keyword only.

        Each `chunksize` elements will be a separate task, and will be
        load-balanced. This lets individual elements be available for iteration
        as soon as they arrive.

        Parameters
        ----------

        f : callable
            function to be mapped
        *sequences: one or more sequences of matching length
            the sequences to be distributed and passed to `f`
        block : bool [default self.block]
            whether to wait for the result or not
        track : bool
            whether to create a MessageTracker to allow the user to
            safely edit after arrays and buffers during non-copying
            sends.
        chunksize : int [default 1]
            how many elements should be in each task.
        ordered : bool [default True]
            Whether the results should be gathered as they arrive, or enforce
            the order of submission.
            
            Only applies when iterating through AsyncMapResult as results arrive.
            Has no effect when block=True.

        Returns
        -------

        if block=False:
            AsyncMapResult
                An object like AsyncResult, but which reassembles the sequence of results
                into a single list. AsyncMapResults can be iterated through before all
                results are complete.
        else:
            the result of map(f,*sequences)
        """

        # default
        block = kwargs.get('block', self.block)
        chunksize = kwargs.get('chunksize', 1)
        ordered = kwargs.get('ordered', True)

        keyset = set(kwargs.keys())
        # use difference(), not difference_update() (which mutates in place and
        # returns None), and allow the 'ordered' keyword read above
        extra_keys = keyset.difference(set(['block', 'chunksize', 'ordered']))
        if extra_keys:
            raise TypeError("Invalid kwargs: %s" % list(extra_keys))

        assert len(sequences) > 0, "must have some sequences to map onto!"

        pf = ParallelFunction(self,
                              f,
                              block=block,
                              chunksize=chunksize,
                              ordered=ordered)
        return pf.map(*sequences)
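A usage sketch tying set_flags and map together, assuming a running IPython.parallel cluster (the Client import path matches this IPython vintage):

from IPython.parallel import Client

client = Client()
view = client.load_balanced_view()

stage_one = view.apply_async(lambda: 'first')
# run the next task only after stage_one completes, retrying twice on failure:
view.set_flags(after=stage_one, retries=2, timeout=30)
stage_two = view.apply_async(lambda: 'second')

# chunked, load-balanced map; iterate results as they arrive:
amr = view.map(lambda x: x * x, range(100), chunksize=10, ordered=False)
squares = list(amr)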
Example #14
0
class TerminalInteractiveShell(InteractiveShell):

    autoedit_syntax = CBool(False,
                            config=True,
                            help="auto editing of files with syntax errors.")
    banner = Unicode('')
    banner1 = Unicode(
        default_banner,
        config=True,
        help="""The part of the banner to be printed before the profile""")
    banner2 = Unicode(
        '',
        config=True,
        help="""The part of the banner to be printed after the profile""")
    confirm_exit = CBool(
        True,
        config=True,
        help="""
        Set to confirm when you try to exit IPython with an EOF (Control-D
        in Unix, Control-Z/Enter in Windows). By typing 'exit' or 'quit',
        you can force a direct exit without any confirmation.""",
    )
    # This display_banner only controls whether or not self.show_banner()
    # is called when mainloop/interact are called.  The default is False
    # because for the terminal based application, the banner behavior
    # is controlled by Global.display_banner, which IPythonApp looks at
    # to determine if *it* should call show_banner() by hand or not.
    display_banner = CBool(False)  # This isn't configurable!
    embedded = CBool(False)
    embedded_active = CBool(False)
    editor = Unicode(
        get_default_editor(),
        config=True,
        help="Set the editor used by IPython (default to $EDITOR/vi/notepad).")
    pager = Unicode('less',
                    config=True,
                    help="The shell program to be used for paging.")

    screen_length = Integer(
        0,
        config=True,
        help="""Number of lines of your screen, used to control printing of very
        long strings.  Strings longer than this number of lines will be sent
        through a pager instead of directly printed.  The default value for
        this is 0, which means IPython will auto-detect your screen size every
        time it needs to print certain potentially long strings (this doesn't
        change the behavior of the 'print' keyword, it's only triggered
        internally). If for some reason this isn't working well (it needs
        curses support), specify it yourself. Otherwise don't change the
        default.""",
    )
    term_title = CBool(False,
                       config=True,
                       help="Enable auto setting the terminal title.")

    # In the terminal, GUI control is done via PyOS_InputHook
    from IPython.lib.inputhook import enable_gui
    enable_gui = staticmethod(enable_gui)

    def __init__(self,
                 config=None,
                 ipython_dir=None,
                 profile_dir=None,
                 user_ns=None,
                 user_module=None,
                 custom_exceptions=((), None),
                 usage=None,
                 banner1=None,
                 banner2=None,
                 display_banner=None):

        super(TerminalInteractiveShell,
              self).__init__(config=config,
                             profile_dir=profile_dir,
                             user_ns=user_ns,
                             user_module=user_module,
                             custom_exceptions=custom_exceptions)
        # use os.system instead of utils.process.system by default,
        # because piped system doesn't make sense in the Terminal:
        self.system = self.system_raw

        self.init_term_title()
        self.init_usage(usage)
        self.init_banner(banner1, banner2, display_banner)

    #-------------------------------------------------------------------------
    # Things related to the terminal
    #-------------------------------------------------------------------------

    @property
    def usable_screen_length(self):
        if self.screen_length == 0:
            return 0
        else:
            num_lines_bot = self.separate_in.count('\n') + 1
            return self.screen_length - num_lines_bot

    def init_term_title(self):
        # Enable or disable the terminal title.
        if self.term_title:
            toggle_set_term_title(True)
            set_term_title('IPython: ' + abbrev_cwd())
        else:
            toggle_set_term_title(False)

    #-------------------------------------------------------------------------
    # Things related to aliases
    #-------------------------------------------------------------------------

    def init_alias(self):
        # The parent class defines aliases that can be safely used with any
        # frontend.
        super(TerminalInteractiveShell, self).init_alias()

        # Now define aliases that only make sense on the terminal, because they
        # need direct access to the console in a way that we can't emulate in
        # GUI or web frontend
        if os.name == 'posix':
            aliases = [('clear', 'clear'), ('more', 'more'), ('less', 'less'),
                       ('man', 'man')]
        elif os.name == 'nt':
            aliases = [('cls', 'cls')]
        else:
            # avoid a NameError on other platforms
            aliases = []

        for name, cmd in aliases:
            self.alias_manager.define_alias(name, cmd)

    #-------------------------------------------------------------------------
    # Things related to the banner and usage
    #-------------------------------------------------------------------------

    def _banner1_changed(self):
        self.compute_banner()

    def _banner2_changed(self):
        self.compute_banner()

    def _term_title_changed(self, name, new_value):
        self.init_term_title()

    def init_banner(self, banner1, banner2, display_banner):
        if banner1 is not None:
            self.banner1 = banner1
        if banner2 is not None:
            self.banner2 = banner2
        if display_banner is not None:
            self.display_banner = display_banner
        self.compute_banner()

    def show_banner(self, banner=None):
        if banner is None:
            banner = self.banner
        self.write(banner)

    def compute_banner(self):
        self.banner = self.banner1
        if self.profile and self.profile != 'default':
            self.banner += '\nIPython profile: %s\n' % self.profile
        if self.banner2:
            self.banner += '\n' + self.banner2

    def init_usage(self, usage=None):
        if usage is None:
            self.usage = interactive_usage
        else:
            self.usage = usage

    #-------------------------------------------------------------------------
    # Mainloop and code execution logic
    #-------------------------------------------------------------------------

    def mainloop(self, display_banner=None):
        """Start the mainloop.

        If an optional banner argument is given, it will override the
        internally created default banner.
        """

        with nested(self.builtin_trap, self.display_trap):

            while 1:
                try:
                    self.interact(display_banner=display_banner)
                    #self.interact_with_readline()
                    # XXX for testing of a readline-decoupled repl loop, call
                    # interact_with_readline above
                    break
                except KeyboardInterrupt:
                    # this should not be necessary, but KeyboardInterrupt
                    # handling seems rather unpredictable...
                    self.write("\nKeyboardInterrupt in interact()\n")

    def _replace_rlhist_multiline(self, source_raw, hlen_before_cell):
        """Store multiple lines as a single entry in history"""

        # do nothing without readline or disabled multiline
        if not self.has_readline or not self.multiline_history:
            return hlen_before_cell

        # windows rl has no remove_history_item
        if not hasattr(self.readline, "remove_history_item"):
            return hlen_before_cell

        # skip empty cells
        if not source_raw.rstrip():
            return hlen_before_cell

        # if nothing changed, do nothing (e.g. when readline drops consecutive dups)
        hlen = self.readline.get_current_history_length()
        if hlen == hlen_before_cell:
            return hlen_before_cell

        for i in range(hlen - hlen_before_cell):
            self.readline.remove_history_item(hlen - i - 1)
        stdin_encoding = get_stream_enc(sys.stdin, 'utf-8')
        self.readline.add_history(
            py3compat.unicode_to_str(source_raw.rstrip(), stdin_encoding))
        return self.readline.get_current_history_length()

    def interact(self, display_banner=None):
        """Closely emulate the interactive Python console."""

        # batch run -> do not interact
        if self.exit_now:
            return

        if display_banner is None:
            display_banner = self.display_banner

        if isinstance(display_banner, basestring):
            self.show_banner(display_banner)
        elif display_banner:
            self.show_banner()

        more = False

        if self.has_readline:
            self.readline_startup_hook(self.pre_readline)
            hlen_b4_cell = self.readline.get_current_history_length()
        else:
            hlen_b4_cell = 0
        # exit_now is set by a call to %Exit or %Quit, through the
        # ask_exit callback.

        while not self.exit_now:
            self.hooks.pre_prompt_hook()
            if more:
                try:
                    prompt = self.prompt_manager.render('in2')
                except:
                    self.showtraceback()
                if self.autoindent:
                    self.rl_do_indent = True

            else:
                try:
                    prompt = self.separate_in + self.prompt_manager.render(
                        'in')
                except:
                    self.showtraceback()
            try:
                line = self.raw_input(prompt)
                if self.exit_now:
                    # quick exit on sys.std[in|out] close
                    break
                if self.autoindent:
                    self.rl_do_indent = False

            except KeyboardInterrupt:
                # double-guard against KeyboardInterrupts raised during the interrupt handling itself
                try:
                    self.write('\nKeyboardInterrupt\n')
                    source_raw = self.input_splitter.source_raw_reset()[1]
                    hlen_b4_cell = \
                        self._replace_rlhist_multiline(source_raw, hlen_b4_cell)
                    more = False
                except KeyboardInterrupt:
                    pass
            except EOFError:
                if self.autoindent:
                    self.rl_do_indent = False
                    if self.has_readline:
                        self.readline_startup_hook(None)
                self.write('\n')
                self.exit()
            except bdb.BdbQuit:
                warn(
                    'The Python debugger has exited with a BdbQuit exception.\n'
                    'Because of how pdb handles the stack, it is impossible\n'
                    'for IPython to properly format this particular exception.\n'
                    'IPython will resume normal operation.')
            except:
                # exceptions here are VERY RARE, but they can be triggered
                # asynchronously by signal handlers, for example.
                self.showtraceback()
            else:
                self.input_splitter.push(line)
                more = self.input_splitter.push_accepts_more()
                if (self.SyntaxTB.last_syntax_error and self.autoedit_syntax):
                    self.edit_syntax_error()
                if not more:
                    source_raw = self.input_splitter.source_raw_reset()[1]
                    self.run_cell(source_raw, store_history=True)
                    hlen_b4_cell = \
                        self._replace_rlhist_multiline(source_raw, hlen_b4_cell)

        # Turn off the exit flag, so the mainloop can be restarted if desired
        self.exit_now = False

    def raw_input(self, prompt=''):
        """Write a prompt and read a line.

        The returned line does not include the trailing newline.
        When the user enters the EOF key sequence, EOFError is raised.

        Optional inputs:

          - prompt(''): a string to be printed to prompt the user.
        """
        # Code run by the user may have modified the readline completer state.
        # We must ensure that our completer is back in place.

        if self.has_readline:
            self.set_readline_completer()

        try:
            line = py3compat.str_to_unicode(self.raw_input_original(prompt))
        except ValueError:
            warn("\n********\nYou or a %run:ed script called sys.stdin.close()"
                 " or sys.stdout.close()!\nExiting IPython!")
            self.ask_exit()
            return ""

        # Try to be reasonably smart about not re-indenting pasted input more
        # than necessary.  We do this by trimming out the auto-indent initial
        # spaces, if the user's actual input started itself with whitespace.
        if self.autoindent:
            if num_ini_spaces(line) > self.indent_current_nsp:
                line = line[self.indent_current_nsp:]
                self.indent_current_nsp = 0

        return line

    #-------------------------------------------------------------------------
    # Methods to support auto-editing of SyntaxErrors.
    #-------------------------------------------------------------------------

    def edit_syntax_error(self):
        """The bottom half of the syntax error handler called in the main loop.

        Loop until syntax error is fixed or user cancels.
        """

        while self.SyntaxTB.last_syntax_error:
            # copy and clear last_syntax_error
            err = self.SyntaxTB.clear_err_state()
            if not self._should_recompile(err):
                return
            try:
                # may set last_syntax_error again if a SyntaxError is raised
                self.safe_execfile(err.filename, self.user_ns)
            except:
                self.showtraceback()
            else:
                try:
                    f = open(err.filename)
                    try:
                        # This should be inside a display_trap block and I
                        # think it is.
                        sys.displayhook(f.read())
                    finally:
                        f.close()
                except:
                    self.showtraceback()

    def _should_recompile(self, e):
        """Utility routine for edit_syntax_error"""

        if e.filename in ('<ipython console>', '<input>', '<string>',
                          '<console>', '<BackgroundJob compilation>', None):

            return False
        try:
            if (self.autoedit_syntax and not self.ask_yes_no(
                    'Return to editor to correct syntax error? '
                    '[Y/n] ', 'y')):
                return False
        except EOFError:
            return False

        def int0(x):
            try:
                return int(x)
            except TypeError:
                return 0

        # always pass integer line and offset values to editor hook
        try:
            self.hooks.fix_error_editor(e.filename, int0(e.lineno),
                                        int0(e.offset), e.msg)
        except TryNext:
            warn('Could not open editor')
            return False
        return True

    #-------------------------------------------------------------------------
    # Things related to exiting
    #-------------------------------------------------------------------------

    def ask_exit(self):
        """ Ask the shell to exit. Can be overiden and used as a callback. """
        self.exit_now = True

    def exit(self):
        """Handle interactive exit.

        This method calls the ask_exit callback."""
        if self.confirm_exit:
            if self.ask_yes_no('Do you really want to exit ([y]/n)?', 'y'):
                self.ask_exit()
        else:
            self.ask_exit()

    #-------------------------------------------------------------------------
    # Things related to magics
    #-------------------------------------------------------------------------

    def init_magics(self):
        super(TerminalInteractiveShell, self).init_magics()
        self.register_magics(TerminalMagics)

    def showindentationerror(self):
        super(TerminalInteractiveShell, self).showindentationerror()
        print(
            "If you want to paste code into IPython, try the "
            "%paste and %cpaste magic functions.")
Example #15
0
class RemoteSpawner(Spawner):
    """A Spawner that just uses Popen to start local processes."""

    INTERRUPT_TIMEOUT = Integer(10, config=True,
                                help="Seconds to wait for process to halt after SIGINT before proceeding to SIGTERM")
    TERM_TIMEOUT = Integer(5, config=True,
                           help="Seconds to wait for process to halt after SIGTERM before proceeding to SIGKILL")
    KILL_TIMEOUT = Integer(5, config=True,
                           help="Seconds to wait for process to halt after SIGKILL before giving up")

    server_url = Unicode("localhost", config=True,
                         help="URL of the remote server")
    server_user = Unicode("jupyterhub", config=True,
                          help="user with passwordless SSH access to the server")
    user_keyfile = Unicode("", config=True,
                           help="private key file for the SSH user")

    channel = Instance(paramiko.client.SSHClient)
    pid = Integer(0)

    def make_preexec_fn(self, name):
        """make preexec fn"""
        return set_user_setuid(name)

    def load_state(self, state):
        """load pid from state"""
        super(RemoteSpawner, self).load_state(state)
        if 'pid' in state:
            self.pid = state['pid']

    def get_state(self):
        """add pid to state"""
        state = super(RemoteSpawner, self).get_state()
        if self.pid:
            state['pid'] = self.pid
        return state

    def clear_state(self):
        """clear pid state"""
        super(RemoteSpawner, self).clear_state()
        self.pid = 0

    def user_env(self, env):
        """get user environment"""
        env['USER'] = self.user.name
        env['HOME'] = pwd.getpwnam(self.user.name).pw_dir
        return env

    def _env_default(self):
        env = super()._env_default()
        return self.user_env(env)

    @gen.coroutine
    def start(self):
        """Start the process"""
        self.user.server.port = random_port()
        self.user.server.ip = self.server_url
        cmd = []
        env = self.env.copy()

        cmd.extend(self.cmd)
        cmd.extend(self.get_args())

        self.log.debug("Env: %s", str(env))
        self.channel = paramiko.SSHClient()
        self.channel.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        self.channel.connect(self.server_url,
                             username=self.server_user,
                             key_filename=self.user_keyfile)
        # self.proc = Popen(cmd, env=env, \
        #    preexec_fn=self.make_preexec_fn(self.user.name),
        #                 )
        self.log.info("Spawning %s", ' '.join(cmd))
        for item in env.items():
            cmd.insert(0, 'export %s="%s";' % item)
        self.pid, stdin, stdout, stderr = execute(self, self.channel,
                                                  ' '.join(cmd),
                                                  self.user.name)
        self.log.info("Process PID is %d" % self.pid)
        #self.log.info("Setting up SSH tunnel")
        #setup_ssh_tunnel(self.user.server.port, self.server_user, self.server_url, self.user_keyfile)
        #self.log.debug("Error %s", ''.join(stderr.readlines()))

    @gen.coroutine
    def poll(self):
        """Poll the process"""
        # for now just assume it is ok
        return None
        ### if we started the process, poll with Popen
        ##if self.proc is not None:
        ##    status = self.proc.poll()
        ##    if status is not None:
        ##        # clear state if the process is done
        ##        self.clear_state()
        ##    return status

        ### if we resumed from stored state,
        ### we don't have the Popen handle anymore, so rely on self.pid

        ##if not self.pid:
        ##    # no pid, not running
        ##    self.clear_state()
        ##    return 0

        ### send signal 0 to check if PID exists
        ### this doesn't work on Windows, but that's okay because we don't support Windows.
        ##alive = yield self._signal(0)
        ##if not alive:
        ##    self.clear_state()
        ##    return 0
        ##else:
        ##    return None

    @gen.coroutine
    def _signal(self, sig):
        """simple implementation of signal

        we can use it when we are using setuid (we are root)"""
        # try:
        #    os.kill(self.pid, sig)
        #except OSError as e:
        #    if e.errno == errno.ESRCH:
        #        return False # process is gone
        #    else:
        #        raise
        return True  # process exists

    @gen.coroutine
    def stop(self, now=False):
        """stop the subprocess

        if `now`, skip waiting for clean shutdown
        """
        return
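Stripped of the Spawner plumbing, the remote-start recipe above is: connect with paramiko using a key file, prefix the command with exported environment variables, and run it. A self-contained sketch (host, user, key path, and command are placeholders):

import paramiko

client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect('server.example.com', username='jupyterhub',
               key_filename='/home/jupyterhub/.ssh/id_rsa')

env = {'USER': 'alice', 'HOME': '/home/alice'}
exports = ''.join('export %s="%s"; ' % item for item in env.items())
stdin, stdout, stderr = client.exec_command(exports + 'jupyterhub-singleuser')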
Example #16
0
File: session.py  Project: vigten/ipython
class Session(Configurable):
    """Object for handling serialization and sending of messages.

    The Session object handles building messages and sending them
    with ZMQ sockets or ZMQStream objects.  Objects can communicate with each
    other over the network via Session objects, and only need to work with the
    dict-based IPython message spec. The Session will handle
    serialization/deserialization, security, and metadata.

    Sessions support configurable serialization via packer/unpacker traits,
    and signing with HMAC digests via the key/keyfile traits.

    Parameters
    ----------

    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        importstrings for methods to serialize message parts.  If just
        'json' or 'pickle', predefined JSON and pickle packers will be used.
        Otherwise, the entire importstring must be used.

        The functions must accept at least valid JSON input, and output *bytes*.

        For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        You can also set the pack/unpack callables for serialization directly.
    session : bytes
        the ID of this Session object.  The default is to generate a new UUID.
    username : unicode
        username added to message headers.  The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature.  If unset, messages
        will not be signed or checked.
    keyfile : filepath
        The file containing a key.  If this is set, `key` will be initialized
        to the contents of the file.

    """

    debug = Bool(False, config=True, help="""Debug output in the Session""")

    packer = DottedObjectName(
        'json',
        config=True,
        help="""The name of the packer for serializing messages.
            Should be one of 'json', 'pickle', or an import name
            for a custom callable serializer.""")

    def _packer_changed(self, name, old, new):
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.unpacker = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.unpacker = new
        else:
            self.pack = import_item(str(new))

    unpacker = DottedObjectName(
        'json',
        config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""")

    def _unpacker_changed(self, name, old, new):
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.packer = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.packer = new
        else:
            self.unpack = import_item(str(new))

    session = CUnicode(u'',
                       config=True,
                       help="""The UUID identifying this session.""")

    def _session_default(self):
        u = unicode(uuid.uuid4())
        self.bsession = u.encode('ascii')
        return u

    def _session_changed(self, name, old, new):
        self.bsession = self.session.encode('ascii')

    # bsession is the session as bytes
    bsession = CBytes(b'')

    username = Unicode(
        os.environ.get('USER', u'username'),
        config=True,
        help="""Username for the Session. Default is your system username.""")

    metadata = Dict(
        {},
        config=True,
        help=
        """Metadata dictionary, which serves as the default top-level metadata dict for each message."""
    )

    # message signature related traits:

    key = CBytes(b'',
                 config=True,
                 help="""execution key, for extra authentication.""")

    def _key_changed(self, name, old, new):
        if new:
            self.auth = hmac.HMAC(new)
        else:
            self.auth = None

    auth = Instance(hmac.HMAC)
    digest_history = Set()

    keyfile = Unicode('',
                      config=True,
                      help="""path to file containing execution key.""")

    def _keyfile_changed(self, name, old, new):
        with open(new, 'rb') as f:
            self.key = f.read().strip()

    # for protecting against sends from forks
    pid = Integer()

    # serialization traits:

    pack = Any(default_packer)  # the actual packer function

    def _pack_changed(self, name, old, new):
        if not callable(new):
            raise TypeError("packer must be callable, not %s" % type(new))

    unpack = Any(default_unpacker)  # the actual unpacker function

    def _unpack_changed(self, name, old, new):
        # the unpacker's output is not checked - it is assumed to be the
        # inverse of pack; we only verify that it is callable
        if not callable(new):
            raise TypeError("unpacker must be callable, not %s" % type(new))

    # thresholds:
    copy_threshold = Integer(
        2**16,
        config=True,
        help=
        "Threshold (in bytes) beyond which a buffer should be sent without copying."
    )
    buffer_threshold = Integer(
        MAX_BYTES,
        config=True,
        help=
        "Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling."
    )
    item_threshold = Integer(
        MAX_ITEMS,
        config=True,
        help=
        """The maximum number of items for a container to be introspected for custom serialization.
        Containers larger than this are pickled outright.
        """)

    def __init__(self, **kwargs):
        """create a Session object

        Parameters
        ----------

        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts.  If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : unicode (must be ascii)
            the ID of this Session object.  The default is to generate a new
            UUID.
        bsession : bytes
            The session as bytes
        username : unicode
            username added to message headers.  The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature.  If unset, messages
            will not be signed or checked.
        keyfile : filepath
            The file containing a key.  If this is set, `key` will be
            initialized to the contents of the file.
        """
        super(Session, self).__init__(**kwargs)
        self._check_packers()
        self.none = self.pack({})
        # ensure self._session_default() if necessary, so bsession is defined:
        self.session
        self.pid = os.getpid()

    @property
    def msg_id(self):
        """always return new uuid"""
        return str(uuid.uuid4())

    def _check_packers(self):
        """check packers for binary data and datetime support."""
        pack = self.pack
        unpack = self.unpack

        # check simple serialization
        msg = dict(a=[1, 'hi'])
        try:
            packed = pack(msg)
        except Exception:
            raise ValueError("packer could not serialize a simple message")

        # ensure packed message is bytes
        if not isinstance(packed, bytes):
            raise ValueError("message packed to %r, but bytes are required" %
                             type(packed))

        # check that unpack is pack's inverse
        try:
            unpacked = unpack(packed)
        except Exception:
            raise ValueError("unpacker could not handle the packer's output")

        # check datetime support
        msg = dict(t=datetime.now())
        try:
            unpacked = unpack(pack(msg))
        except Exception:
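            # the packer can't handle datetime objects natively: wrap the
            # pair so dates are squashed to ISO8601 strings when packing and
            # extracted back into datetimes when unpacking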
            self.pack = lambda o: pack(squash_dates(o))
            self.unpack = lambda s: extract_dates(unpack(s))

    def msg_header(self, msg_type):
        return msg_header(self.msg_id, msg_type, self.username, self.session)

    def msg(self,
            msg_type,
            content=None,
            parent=None,
            header=None,
            metadata=None):
        """Return the nested message dict.

        This format is different from what is sent over the wire. The
        serialize/unserialize methods convert this nested message dict to the
        wire format, which is a list of message parts.
        """
        msg = {}
        header = self.msg_header(msg_type) if header is None else header
        msg['header'] = header
        msg['msg_id'] = header['msg_id']
        msg['msg_type'] = header['msg_type']
        msg['parent_header'] = {} if parent is None else extract_header(parent)
        msg['content'] = {} if content is None else content
        msg['metadata'] = self.metadata.copy()
        if metadata is not None:
            msg['metadata'].update(metadata)
        return msg

    def sign(self, msg_list):
        """Sign a message with HMAC digest. If no auth, return b''.

        Parameters
        ----------
        msg_list : list
            The [p_header,p_parent,p_content] part of the message list.
        """
        if self.auth is None:
            return b''
        h = self.auth.copy()
        for m in msg_list:
            h.update(m)
        return str_to_bytes(h.hexdigest())

    def serialize(self, msg, ident=None):
        """Serialize the message components to bytes.

        This is roughly the inverse of unserialize. The serialize/unserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg : dict or Message
            The nested message dict as returned by the self.msg method.

        Returns
        -------
        msg_list : list
            The list of bytes objects to be sent with the format:
            [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
             buffer1,buffer2,...]. In this list, the p_* entities are
            the packed or serialized versions, so if JSON is used, these
            are utf8 encoded JSON strings.
        """
        content = msg.get('content', {})
        if content is None:
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, bytes):
            # content is already packed, as in a relayed message
            pass
        elif isinstance(content, unicode):
            # should be bytes, but JSON often spits out unicode
            content = content.encode('utf8')
        else:
            raise TypeError("Content incorrect type: %s" % type(content))

        real_message = [
            self.pack(msg['header']),
            self.pack(msg['parent_header']),
            self.pack(msg['metadata']),
            content,
        ]

        to_send = []

        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)

        signature = self.sign(real_message)
        to_send.append(signature)

        to_send.extend(real_message)

        return to_send

    def send(self,
             stream,
             msg_or_type,
             content=None,
             parent=None,
             ident=None,
             buffers=None,
             track=False,
             header=None,
             metadata=None):
        """Build and send a message via stream or socket.

        The message format used by this function internally is as follows:

        [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
         buffer1,buffer2,...]

        The serialize/unserialize methods convert the nested message dict into this
        format.

        Parameters
        ----------

        stream : zmq.Socket or ZMQStream
            The socket-like object used to send the data.
        msg_or_type : str or Message/dict
            Normally, msg_or_type will be a msg_type unless a message is being
            sent more than once. If a header is supplied, this can be set to
            None and the msg_type will be pulled from the header.

        content : dict or None
            The content of the message (ignored if msg_or_type is a message).
        header : dict or None
            The header dict for the message (ignored if msg_or_type is a message).
        parent : Message or dict or None
            The parent or parent header describing the parent of this message
            (ignored if msg_or_type is a message).
        ident : bytes or list of bytes
            The zmq.IDENTITY routing path.
        metadata : dict or None
            The metadata describing the message
        buffers : list or None
            The already-serialized buffers to be appended to the message.
        track : bool
            Whether to track.  Only for use with Sockets, because ZMQStream
            objects cannot track messages.

        Returns
        -------
        msg : dict
            The constructed message.
        """
        if not isinstance(stream, zmq.Socket):
            # ZMQStreams and dummy sockets do not support tracking.
            track = False

        if isinstance(msg_or_type, (Message, dict)):
            # We got a Message or message dict, not a msg_type so don't
            # build a new Message.
            msg = msg_or_type
        else:
            msg = self.msg(msg_or_type,
                           content=content,
                           parent=parent,
                           header=header,
                           metadata=metadata)
        if not os.getpid() == self.pid:
            io.rprint("WARNING: attempted to send message from fork")
            io.rprint(msg)
            return
        buffers = [] if buffers is None else buffers
        to_send = self.serialize(msg, ident)
        to_send.extend(buffers)
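        # zero-copy sends only pay off for large frames: below copy_threshold,
        # letting zmq copy each buffer is cheaper than tracking it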
        longest = max([len(s) for s in to_send])
        copy = (longest < self.copy_threshold)

        if buffers and track and not copy:
            # only really track when we are doing zero-copy buffers
            tracker = stream.send_multipart(to_send, copy=False, track=True)
        else:
            # use dummy tracker, which will be done immediately
            tracker = DONE
            stream.send_multipart(to_send, copy=copy)

        if self.debug:
            pprint.pprint(msg)
            pprint.pprint(to_send)
            pprint.pprint(buffers)

        msg['tracker'] = tracker

        return msg

    def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
        """Send a raw message via ident path.

        This method is used to send an already serialized message.

        Parameters
        ----------
        stream : ZMQStream or Socket
            The ZMQ stream or socket to use for sending the message.
        msg_list : list
            The serialized list of messages to send. This only includes the
            [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
            the message.
        ident : ident or list
            A single ident or a list of idents to use in sending.
        """
        to_send = []
        if isinstance(ident, bytes):
            ident = [ident]
        if ident is not None:
            to_send.extend(ident)

        to_send.append(DELIM)
        to_send.append(self.sign(msg_list))
        to_send.extend(msg_list)
        stream.send_multipart(to_send, flags, copy=copy)

    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
        """Receive and unpack a message.

        Parameters
        ----------
        socket : ZMQStream or Socket
            The socket or stream to use in receiving.

        Returns
        -------
        [idents], msg
            [idents] is a list of idents and msg is a nested message dict of
            same format as self.msg returns.
        """
        if isinstance(socket, ZMQStream):
            socket = socket.socket
        try:
            msg_list = socket.recv_multipart(mode, copy=copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_multipart won't return None.
                return None, None
            else:
                raise
        # split multipart message into identity list and message dict
        # invalid large messages can cause very expensive string comparisons
        idents, msg_list = self.feed_identities(msg_list, copy)
        try:
            return idents, self.unserialize(msg_list,
                                            content=content,
                                            copy=copy)
        except Exception as e:
            # TODO: handle it
            raise e

    def feed_identities(self, msg_list, copy=True):
        """Split the identities from the rest of the message.

        Feed until DELIM is reached, then return the prefix as idents and
        remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
        but that would be silly.

        Parameters
        ----------
        msg_list : a list of Message or bytes objects
            The message to be split.
        copy : bool
            flag determining whether the arguments are bytes or Messages

        Returns
        -------
        (idents, msg_list) : two lists
            idents will always be a list of bytes, each of which is a ZMQ
            identity. msg_list will be a list of bytes or zmq.Messages of the
            form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
            should be unpackable/unserializable via self.unserialize at this
            point.
        """
        if copy:
            idx = msg_list.index(DELIM)
            return msg_list[:idx], msg_list[idx + 1:]
        else:
            failed = True
            for idx, m in enumerate(msg_list):
                if m.bytes == DELIM:
                    failed = False
                    break
            if failed:
                raise ValueError("DELIM not in msg_list")
            idents, msg_list = msg_list[:idx], msg_list[idx + 1:]
            return [m.bytes for m in idents], msg_list

    def unserialize(self, msg_list, content=True, copy=True):
        """Unserialize a msg_list to a nested message dict.

        This is roughly the inverse of serialize. The serialize/unserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg_list : list of bytes or Message objects
            The list of message parts of the form [HMAC,p_header,p_parent,
            p_metadata,p_content,buffer1,buffer2,...].
        content : bool (True)
            Whether to unpack the content dict (True), or leave it packed
            (False).
        copy : bool (True)
            Whether to return the bytes (True), or the non-copying Message
            object in each place (False).

        Returns
        -------
        msg : dict
            The nested message dict with top-level keys [header, parent_header,
            content, buffers].
        """
        minlen = 5
        message = {}
        if not len(msg_list) >= minlen:
            raise TypeError(
                "malformed message, must have at least %i elements" % minlen)
        if not copy:
            for i in range(minlen):
                msg_list[i] = msg_list[i].bytes
        if self.auth is not None:
            signature = msg_list[0]
            if not signature:
                raise ValueError("Unsigned Message")
            if signature in self.digest_history:
                raise ValueError("Duplicate Signature: %r" % signature)
            self.digest_history.add(signature)
            check = self.sign(msg_list[1:5])
            if not signature == check:
                raise ValueError("Invalid Signature: %r" % signature)
        header = self.unpack(msg_list[1])
        message['header'] = header
        message['msg_id'] = header['msg_id']
        message['msg_type'] = header['msg_type']
        message['parent_header'] = self.unpack(msg_list[2])
        message['metadata'] = self.unpack(msg_list[3])
        if content:
            message['content'] = self.unpack(msg_list[4])
        else:
            message['content'] = msg_list[4]

        message['buffers'] = msg_list[5:]
        return message
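The Session class above defines the whole wire protocol. The following standalone sketch (stdlib only; the helper names, the delimiter value, and the digest choice are illustrative assumptions, not this class's API) walks one message through the signing and identity-splitting steps that sign and feed_identities describe.

import hashlib
import hmac
import json

DELIM = b"<IDS|MSG>"  # identity/message delimiter (assumed value)
KEY = b"secret-key"   # execution key; an unset key would mean an empty signature

def pack(obj):
    # a json packer: accepts JSON-able input, returns *bytes*
    return json.dumps(obj).encode("utf8")

# the four packed message parts: header, parent_header, metadata, content
parts = [pack({"msg_id": "1", "msg_type": "execute_request"}),
         pack({}),
         pack({}),
         pack({"code": "1+1"})]

# sign: feed every part into one HMAC, send the hex digest as bytes
h = hmac.HMAC(KEY, digestmod=hashlib.sha256)
for p in parts:
    h.update(p)
signature = h.hexdigest().encode("ascii")

wire = [b"ident1", DELIM, signature] + parts

# feed_identities: everything before DELIM is zmq routing identity
idx = wire.index(DELIM)
idents, msg_list = wire[:idx], wire[idx + 1:]
assert idents == [b"ident1"]
assert msg_list[0] == signature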
Example #17
class NotebookApp(BaseIPythonApplication):

    name = 'ipython-notebook'

    description = """
        The IPython HTML Notebook.
        
        This launches a Tornado based HTML Notebook Server that serves up an
        HTML5/Javascript Notebook client.
    """
    examples = _examples

    classes = IPythonConsoleApp.classes + [
        MappingKernelManager, NotebookManager, FileNotebookManager,
        NotebookNotary
    ]
    flags = Dict(flags)
    aliases = Dict(aliases)

    subcommands = dict(
        list=(NbserverListApp, NbserverListApp.description.splitlines()[0]),
    )

    kernel_argv = List(Unicode)

    def _log_level_default(self):
        return logging.INFO

    def _log_format_default(self):
        """override default log format to include time"""
        return u"%(asctime)s.%(msecs).03d [%(name)s]%(highlevel)s %(message)s"

    # create requested profiles by default, if they don't exist:
    auto_create = Bool(True)

    # file to be opened in the notebook server
    file_to_run = Unicode('', config=True)

    def _file_to_run_changed(self, name, old, new):
        path, base = os.path.split(new)
        if path:
            self.file_to_run = base
            self.notebook_dir = path

    # Network related information.

    ip = Unicode(config=True,
                 help="The IP address the notebook server will listen on.")

    def _ip_default(self):
        return localhost()

    def _ip_changed(self, name, old, new):
        if new == u'*': self.ip = u''

    port = Integer(8888,
                   config=True,
                   help="The port the notebook server will listen on.")
    port_retries = Integer(
        50,
        config=True,
        help=
        "The number of additional ports to try if the specified port is not available."
    )

    certfile = Unicode(
        u'',
        config=True,
        help="""The full path to an SSL/TLS certificate file.""")

    keyfile = Unicode(
        u'',
        config=True,
        help="""The full path to a private key file for usage with SSL/TLS.""")

    cookie_secret = Bytes(b'',
                          config=True,
                          help="""The random bytes used to secure cookies.
        By default this is a new random number every time you start the Notebook.
        Set it to a value in a config file to enable logins to persist across server sessions.
        
        Note: Cookie secrets should be kept private, do not share config files with
        cookie_secret stored in plaintext (you can read the value from a file).
        """)

    def _cookie_secret_default(self):
        return os.urandom(1024)

    password = Unicode(u'',
                       config=True,
                       help="""Hashed password to use for web authentication.

                      To generate, type in a python/IPython shell:

                        from IPython.lib import passwd; passwd()

                      The string should be of the form type:salt:hashed-password.
                      """)

    open_browser = Bool(True,
                        config=True,
                        help="""Whether to open in a browser after starting.
                        The specific browser used is platform dependent and
                        determined by the python standard library `webbrowser`
                        module, unless it is overridden using the --browser
                        (NotebookApp.browser) configuration option.
                        """)

    browser = Unicode(u'',
                      config=True,
                      help="""Specify what command to use to invoke a web
                      browser when opening the notebook. If not specified, the
                      default browser will be determined by the `webbrowser`
                      standard library module, which allows setting of the
                      BROWSER environment variable to override it.
                      """)

    webapp_settings = Dict(
        config=True,
        help="Supply overrides for the tornado.web.Application that the "
        "IPython notebook uses.")

    jinja_environment_options = Dict(
        config=True,
        help="Supply extra arguments that will be passed to Jinja environment."
    )

    enable_mathjax = Bool(
        True,
        config=True,
        help="""Whether to enable MathJax for typesetting math/TeX

        MathJax is the javascript library IPython uses to render math/LaTeX. It is
        very large, so you may want to disable it if you have a slow internet
        connection, or for offline use of the notebook.

        When disabled, equations etc. will appear as their untransformed TeX source.
        """)

    def _enable_mathjax_changed(self, name, old, new):
        """set mathjax url to empty if mathjax is disabled"""
        if not new:
            self.mathjax_url = u''

    base_url = Unicode('/',
                       config=True,
                       help='''The base URL for the notebook server.

                               Leading and trailing slashes can be omitted,
                               and will automatically be added.
                               ''')

    def _base_url_changed(self, name, old, new):
        if not new.startswith('/'):
            self.base_url = '/' + new
        elif not new.endswith('/'):
            self.base_url = new + '/'

    base_project_url = Unicode('/',
                               config=True,
                               help="""DEPRECATED use base_url""")

    def _base_project_url_changed(self, name, old, new):
        self.log.warn("base_project_url is deprecated, use base_url")
        self.base_url = new

    extra_static_paths = List(
        Unicode,
        config=True,
        help="""Extra paths to search for serving static files.
        
        This allows adding javascript/css to be available from the notebook server machine,
        or overriding individual files in the IPython notebook's built-in static files.""")

    def _extra_static_paths_default(self):
        return [os.path.join(self.profile_dir.location, 'static')]

    @property
    def static_file_path(self):
        """return extra paths + the default location"""
        return self.extra_static_paths + [DEFAULT_STATIC_FILES_PATH]

    nbextensions_path = List(
        Unicode,
        config=True,
        help=
        """paths for Javascript extensions. By default, this is just IPYTHONDIR/nbextensions"""
    )

    def _nbextensions_path_default(self):
        return [os.path.join(get_ipython_dir(), 'nbextensions')]

    mathjax_url = Unicode("", config=True, help="""The url for MathJax.js.""")

    def _mathjax_url_default(self):
        if not self.enable_mathjax:
            return u''
        static_url_prefix = self.webapp_settings.get(
            "static_url_prefix", url_path_join(self.base_url, "static"))

        # try local mathjax, either in nbextensions/mathjax or static/mathjax
        for (url_prefix, search_path) in [
            (url_path_join(self.base_url,
                           "nbextensions"), self.nbextensions_path),
            (static_url_prefix, self.static_file_path),
        ]:
            self.log.debug("searching for local mathjax in %s", search_path)
            try:
                mathjax = filefind(os.path.join('mathjax', 'MathJax.js'),
                                   search_path)
            except IOError:
                continue
            else:
                url = url_path_join(url_prefix, u"mathjax/MathJax.js")
                self.log.info("Serving local MathJax from %s at %s", mathjax,
                              url)
                return url

        # no local mathjax, serve from CDN
        if self.certfile:
            # HTTPS: load from Rackspace CDN, because SSL certificate requires it
            host = u"https://c328740.ssl.cf1.rackcdn.com"
        else:
            host = u"http://cdn.mathjax.org"

        url = host + u"/mathjax/latest/MathJax.js"
        self.log.info("Using MathJax from CDN: %s", url)
        return url

    def _mathjax_url_changed(self, name, old, new):
        if new and not self.enable_mathjax:
            # enable_mathjax=False overrides mathjax_url
            self.mathjax_url = u''
        else:
            self.log.info("Using MathJax: %s", new)

    notebook_manager_class = DottedObjectName(
        'IPython.html.services.notebooks.filenbmanager.FileNotebookManager',
        config=True,
        help='The notebook manager class to use.')

    trust_xheaders = Bool(
        False,
        config=True,
        help=
        ("Whether to trust or not X-Scheme/X-Forwarded-Proto and X-Real-Ip/X-Forwarded-For headers"
         "sent by the upstream reverse proxy. Necessary if the proxy handles SSL"
         ))

    info_file = Unicode()

    def _info_file_default(self):
        info_file = "nbserver-%s.json" % os.getpid()
        return os.path.join(self.profile_dir.security_dir, info_file)

    notebook_dir = Unicode(
        py3compat.getcwd(),
        config=True,
        help="The directory to use for notebooks and kernels.")

    def _notebook_dir_changed(self, name, old, new):
        """Do a bit of validation of the notebook dir."""
        if not os.path.isabs(new):
            # If we receive a non-absolute path, make it absolute.
            self.notebook_dir = os.path.abspath(new)
            return
        if not os.path.isdir(new):
            raise TraitError("No such notebook dir: %r" % new)

        # setting App.notebook_dir implies setting notebook and kernel dirs as well
        self.config.FileNotebookManager.notebook_dir = new
        self.config.MappingKernelManager.root_dir = new

    def parse_command_line(self, argv=None):
        super(NotebookApp, self).parse_command_line(argv)

        if self.extra_args:
            arg0 = self.extra_args[0]
            f = os.path.abspath(arg0)
            self.argv.remove(arg0)
            if not os.path.exists(f):
                self.log.critical("No such file or directory: %s", f)
                self.exit(1)

            # Use config here, to ensure that it takes higher priority than
            # anything that comes from the profile.
            if os.path.isdir(f):
                self.config.NotebookApp.notebook_dir = f
            elif os.path.isfile(f):
                self.config.NotebookApp.file_to_run = f

    def init_kernel_argv(self):
        """construct the kernel arguments"""
        # Scrub frontend-specific flags
        self.kernel_argv = swallow_argv(self.argv, notebook_aliases,
                                        notebook_flags)
        if any(arg.startswith(u'--pylab') for arg in self.kernel_argv):
            self.log.warn('\n    '.join([
                "Starting all kernels in pylab mode is not recommended,",
                "and will be disabled in a future release.",
                "Please use the %matplotlib magic to enable matplotlib instead.",
                "pylab implies many imports, which can have confusing side effects",
                "and harm the reproducibility of your notebooks.",
            ]))
        # Kernel should inherit default config file from frontend
        self.kernel_argv.append("--IPKernelApp.parent_appname='%s'" %
                                self.name)
        # Kernel should get *absolute* path to profile directory
        self.kernel_argv.extend(["--profile-dir", self.profile_dir.location])

    def init_configurables(self):
        # force Session default to be secure
        default_secure(self.config)
        self.kernel_manager = MappingKernelManager(
            parent=self,
            log=self.log,
            kernel_argv=self.kernel_argv,
            connection_dir=self.profile_dir.security_dir,
        )
        kls = import_item(self.notebook_manager_class)
        self.notebook_manager = kls(parent=self, log=self.log)
        self.session_manager = SessionManager(parent=self, log=self.log)
        self.cluster_manager = ClusterManager(parent=self, log=self.log)
        self.cluster_manager.update_profiles()

    def init_logging(self):
        # This prevents double log messages because tornado uses a root logger that
        # self.log is a child of. The logging module dispatches log messages to a
        # logger and all of its ancestors until propagate is set to False.
        self.log.propagate = False

        # hook up tornado 3's loggers to our app handlers
        for name in ('access', 'application', 'general'):
            logger = logging.getLogger('tornado.%s' % name)
            logger.parent = self.log
            logger.setLevel(self.log.level)

    def init_webapp(self):
        """initialize tornado webapp and httpserver"""
        self.web_app = NotebookWebApplication(
            self, self.kernel_manager, self.notebook_manager,
            self.cluster_manager, self.session_manager, self.log,
            self.base_url, self.webapp_settings,
            self.jinja_environment_options)
        if self.certfile:
            ssl_options = dict(certfile=self.certfile)
            if self.keyfile:
                ssl_options['keyfile'] = self.keyfile
        else:
            ssl_options = None
        self.web_app.password = self.password
        self.http_server = httpserver.HTTPServer(self.web_app,
                                                 ssl_options=ssl_options,
                                                 xheaders=self.trust_xheaders)
        if not self.ip:
            warning = "WARNING: The notebook server is listening on all IP addresses"
            if ssl_options is None:
                self.log.critical(warning + " and not using encryption. This "
                                  "is not recommended.")
            if not self.password:
                self.log.critical(
                    warning + " and not using authentication. "
                    "This is highly insecure and not recommended.")
        success = None
        for port in random_ports(self.port, self.port_retries + 1):
            try:
                self.http_server.listen(port, self.ip)
            except socket.error as e:
                if e.errno == errno.EADDRINUSE:
                    self.log.info(
                        'The port %i is already in use, trying another random port.'
                        % port)
                    continue
                elif e.errno in (errno.EACCES,
                                 getattr(errno, 'WSAEACCES', errno.EACCES)):
                    self.log.warn("Permission to listen on port %i denied" %
                                  port)
                    continue
                else:
                    raise
            else:
                self.port = port
                success = True
                break
        if not success:
            self.log.critical(
                'ERROR: the notebook server could not be started because '
                'no available port could be found.')
            self.exit(1)

    @property
    def display_url(self):
        ip = self.ip if self.ip else '[all ip addresses on your system]'
        return self._url(ip)

    @property
    def connection_url(self):
        ip = self.ip if self.ip else localhost()
        return self._url(ip)

    def _url(self, ip):
        proto = 'https' if self.certfile else 'http'
        return "%s://%s:%i%s" % (proto, ip, self.port, self.base_url)

    def init_signal(self):
        if not sys.platform.startswith('win'):
            signal.signal(signal.SIGINT, self._handle_sigint)
        signal.signal(signal.SIGTERM, self._signal_stop)
        if hasattr(signal, 'SIGUSR1'):
            # Windows doesn't support SIGUSR1
            signal.signal(signal.SIGUSR1, self._signal_info)
        if hasattr(signal, 'SIGINFO'):
            # only on BSD-based systems
            signal.signal(signal.SIGINFO, self._signal_info)

    def _handle_sigint(self, sig, frame):
        """SIGINT handler spawns confirmation dialog"""
        # register more forceful signal handler for ^C^C case
        signal.signal(signal.SIGINT, self._signal_stop)
        # request confirmation dialog in bg thread, to avoid
        # blocking the App
        thread = threading.Thread(target=self._confirm_exit)
        thread.daemon = True
        thread.start()

    def _restore_sigint_handler(self):
        """callback for restoring original SIGINT handler"""
        signal.signal(signal.SIGINT, self._handle_sigint)

    def _confirm_exit(self):
        """confirm shutdown on ^C
        
        A second ^C, or answering 'y' within 5s will cause shutdown,
        otherwise original SIGINT handler will be restored.
        
        This doesn't work on Windows.
        """
        # FIXME: remove this delay when pyzmq dependency is >= 2.1.11
        time.sleep(0.1)
        info = self.log.info
        info('interrupted')
        print(self.notebook_info())
        sys.stdout.write("Shutdown this notebook server (y/[n])? ")
        sys.stdout.flush()
        r, w, x = select.select([sys.stdin], [], [], 5)
        if r:
            line = sys.stdin.readline()
            if line.lower().startswith('y'):
                self.log.critical("Shutdown confirmed")
                ioloop.IOLoop.instance().stop()
                return
        else:
            print("No answer for 5s:", end=' ')
        print("resuming operation...")
        # no answer, or answer is no:
        # set it back to original SIGINT handler
        # use IOLoop.add_callback because signal.signal must be called
        # from main thread
        ioloop.IOLoop.instance().add_callback(self._restore_sigint_handler)

    def _signal_stop(self, sig, frame):
        self.log.critical("received signal %s, stopping", sig)
        ioloop.IOLoop.instance().stop()

    def _signal_info(self, sig, frame):
        print(self.notebook_info())

    def init_components(self):
        """Check the components submodule, and warn if it's unclean"""
        status = submodule.check_submodule_status()
        if status == 'missing':
            self.log.warn(
                "components submodule missing, running `git submodule update`")
            submodule.update_submodules(submodule.ipython_parent())
        elif status == 'unclean':
            self.log.warn(
                "components submodule unclean, you may see 404s on static/components"
            )
            self.log.warn(
                "run `setup.py submodule` or `git submodule update` to update")

    @catch_config_error
    def initialize(self, argv=None):
        super(NotebookApp, self).initialize(argv)
        self.init_logging()
        self.init_kernel_argv()
        self.init_configurables()
        self.init_components()
        self.init_webapp()
        self.init_signal()

    def cleanup_kernels(self):
        """Shutdown all kernels.
        
        The kernels will shut themselves down when this process no longer exists,
        but explicit shutdown allows the KernelManagers to clean up the connection files.
        """
        self.log.info('Shutting down kernels')
        self.kernel_manager.shutdown_all()

    def notebook_info(self):
        "Return the current working directory and the server url information"
        info = self.notebook_manager.info_string() + "\n"
        info += "%d active kernels \n" % len(self.kernel_manager._kernels)
        return info + "The IPython Notebook is running at: %s" % self.display_url

    def server_info(self):
        """Return a JSONable dict of information about this server."""
        return {
            'url': self.connection_url,
            'hostname': self.ip if self.ip else 'localhost',
            'port': self.port,
            'secure': bool(self.certfile),
            'base_url': self.base_url,
            'notebook_dir': os.path.abspath(self.notebook_dir),
        }

    def write_server_info_file(self):
        """Write the result of server_info() to the JSON file info_file."""
        with open(self.info_file, 'w') as f:
            json.dump(self.server_info(), f, indent=2)

    def remove_server_info_file(self):
        """Remove the nbserver-<pid>.json file created for this server.
        
        Ignores the error raised when the file has already been removed.
        """
        try:
            os.unlink(self.info_file)
        except OSError as e:
            if e.errno != errno.ENOENT:
                raise

    def start(self):
        """ Start the IPython Notebook server app, after initialization
        
        This method takes no arguments so all configuration and initialization
        must be done prior to calling this method."""
        if self.subapp is not None:
            return self.subapp.start()

        info = self.log.info
        for line in self.notebook_info().split("\n"):
            info(line)
        info(
            "Use Control-C to stop this server and shut down all kernels (twice to skip confirmation)."
        )

        self.write_server_info_file()

        if self.open_browser or self.file_to_run:
            try:
                browser = webbrowser.get(self.browser or None)
            except webbrowser.Error as e:
                self.log.warn('No web browser found: %s.' % e)
                browser = None

            if self.file_to_run:
                fullpath = os.path.join(self.notebook_dir, self.file_to_run)
                if not os.path.exists(fullpath):
                    self.log.critical("%s does not exist" % fullpath)
                    self.exit(1)

                uri = url_path_join('notebooks', self.file_to_run)
            else:
                uri = 'tree'
            if browser:
                b = lambda: browser.open(
                    url_path_join(self.connection_url, uri), new=2)
                threading.Thread(target=b).start()
        try:
            ioloop.IOLoop.instance().start()
        except KeyboardInterrupt:
            info("Interrupted...")
        finally:
            self.cleanup_kernels()
            self.remove_server_info_file()
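The port-retry loop in init_webapp can be read in isolation. Below is a minimal sketch under stated assumptions: a plain socket stands in for tornado's HTTPServer, a sampled candidate list stands in for random_ports, and the helper name is hypothetical.

import errno
import random
import socket

def find_open_port(start=8888, retries=50):
    # first try the requested port, then a random sample of nearby ports
    candidates = [start] + random.sample(range(start + 1, start + 1000), retries)
    for port in candidates:
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            s.bind(("127.0.0.1", port))
        except socket.error as e:
            s.close()
            if e.errno in (errno.EADDRINUSE, errno.EACCES):
                continue  # taken or forbidden: try the next candidate
            raise
        else:
            s.close()
            return port  # bindable right now (note: not reserved)
    raise RuntimeError("no available port could be found")

print(find_open_port())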
Example #18
class MappingKernelManager(MultiKernelManager):
    """A KernelManager that handles notebok mapping and HTTP error handling"""

    kernel_argv = List(Unicode)
    
    time_to_dead = Float(3.0, config=True, help="""Kernel heartbeat interval in seconds.""")
    first_beat = Float(5.0, config=True, help="Delay (in seconds) before sending first heartbeat.")
    
    max_msg_size = Integer(65536, config=True, help="""
        The max raw message size accepted from the browser
        over a WebSocket connection.
    """)

    _notebook_mapping = Dict()

    #-------------------------------------------------------------------------
    # Methods for managing kernels and sessions
    #-------------------------------------------------------------------------

    def kernel_for_notebook(self, notebook_id):
        """Return the kernel_id for a notebook_id or None."""
        return self._notebook_mapping.get(notebook_id)

    def set_kernel_for_notebook(self, notebook_id, kernel_id):
        """Associate a notebook with a kernel."""
        if notebook_id is not None:
            self._notebook_mapping[notebook_id] = kernel_id

    def notebook_for_kernel(self, kernel_id):
        """Return the notebook_id for a kernel_id or None."""
        notebook_ids = [k for k, v in self._notebook_mapping.iteritems() if v == kernel_id]
        if len(notebook_ids) == 1:
            return notebook_ids[0]
        else:
            return None

    def delete_mapping_for_kernel(self, kernel_id):
        """Remove the kernel/notebook mapping for kernel_id."""
        notebook_id = self.notebook_for_kernel(kernel_id)
        if notebook_id is not None:
            del self._notebook_mapping[notebook_id]

    def start_kernel(self, notebook_id=None):
        """Start a kernel for a notebok an return its kernel_id.

        Parameters
        ----------
        notebook_id : uuid
            The uuid of the notebook to associate the new kernel with. If this
            is not None, the same kernel will be reused whenever this notebook
            requests a kernel.
        """
        kernel_id = self.kernel_for_notebook(notebook_id)
        if kernel_id is None:
            kwargs = dict()
            kwargs['extra_arguments'] = self.kernel_argv
            kernel_id = super(MappingKernelManager, self).start_kernel(**kwargs)
            self.set_kernel_for_notebook(notebook_id, kernel_id)
            self.log.info("Kernel started: %s" % kernel_id)
            self.log.debug("Kernel args: %r" % kwargs)
        else:
            self.log.info("Using existing kernel: %s" % kernel_id)
        return kernel_id

    def kill_kernel(self, kernel_id):
        """Kill a kernel and remove its notebook association."""
        self._check_kernel_id(kernel_id)
        super(MappingKernelManager, self).kill_kernel(kernel_id)
        self.delete_mapping_for_kernel(kernel_id)
        self.log.info("Kernel killed: %s" % kernel_id)

    def interrupt_kernel(self, kernel_id):
        """Interrupt a kernel."""
        self._check_kernel_id(kernel_id)
        super(MappingKernelManager, self).interrupt_kernel(kernel_id)
        self.log.info("Kernel interrupted: %s" % kernel_id)

    def restart_kernel(self, kernel_id):
        """Restart a kernel while keeping clients connected."""
        self._check_kernel_id(kernel_id)
        km = self.get_kernel(kernel_id)
        km.restart_kernel(now=True)
        self.log.info("Kernel restarted: %s" % kernel_id)
        return kernel_id
        
        # the following remains, in case the KM restart machinery is
        # somehow unacceptable
        # Get the notebook_id to preserve the kernel/notebook association.
        notebook_id = self.notebook_for_kernel(kernel_id)
        # Create the new kernel first so we can move the clients over.
        new_kernel_id = self.start_kernel()
        # Now kill the old kernel.
        self.kill_kernel(kernel_id)
        # Now save the new kernel/notebook association. We have to save it
        # after the old kernel is killed as that will delete the mapping.
        self.set_kernel_for_notebook(notebook_id, new_kernel_id)
        self.log.info("Kernel restarted: %s" % new_kernel_id)
        return new_kernel_id

    def create_iopub_stream(self, kernel_id):
        """Create a new iopub stream."""
        self._check_kernel_id(kernel_id)
        return super(MappingKernelManager, self).create_iopub_stream(kernel_id)

    def create_shell_stream(self, kernel_id):
        """Create a new shell stream."""
        self._check_kernel_id(kernel_id)
        return super(MappingKernelManager, self).create_shell_stream(kernel_id)

    def create_hb_stream(self, kernel_id):
        """Create a new hb stream."""
        self._check_kernel_id(kernel_id)
        return super(MappingKernelManager, self).create_hb_stream(kernel_id)

    def _check_kernel_id(self, kernel_id):
        """Check a that a kernel_id exists and raise 404 if not."""
        if kernel_id not in self:
            raise web.HTTPError(404, u'Kernel does not exist: %s' % kernel_id)
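Since only the forward notebook-to-kernel dict is stored, reverse lookups scan all items. A standalone sketch of that trade-off, with hypothetical ids:

# notebook_id -> kernel_id; the reverse direction is recovered by scanning,
# which is O(n) but keeps a single source of truth (mirrors
# notebook_for_kernel above)
mapping = {}

def kernel_for_notebook(notebook_id):
    return mapping.get(notebook_id)

def notebook_for_kernel(kernel_id):
    notebook_ids = [nb for nb, k in mapping.items() if k == kernel_id]
    return notebook_ids[0] if len(notebook_ids) == 1 else None

mapping["nb-1"] = "kernel-a"
assert kernel_for_notebook("nb-1") == "kernel-a"
assert notebook_for_kernel("kernel-a") == "nb-1"
assert notebook_for_kernel("kernel-b") is None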
Example #19
class IPKernelApp(BaseIPythonApplication, InteractiveShellApp,
                  ConnectionFileMixin):
    name = 'ipython-kernel'
    aliases = Dict(kernel_aliases)
    flags = Dict(kernel_flags)
    classes = [IPythonKernel, ZMQInteractiveShell, ProfileDir, Session]
    # the kernel class, as an importstring
    kernel_class = Type('ipython_kernel.ipkernel.IPythonKernel',
                        config=True,
                        klass='ipython_kernel.kernelbase.Kernel',
                        help="""The Kernel subclass to be used.

    This should allow easy re-use of the IPKernelApp entry point
    to configure and launch kernels other than IPython's own.
    """)
    kernel = Any()
    # don't restrict this even though current pollers are all Threads
    poller = Any()
    heartbeat = Instance(Heartbeat, allow_none=True)
    ports = Dict()

    # connection info:

    @property
    def abs_connection_file(self):
        if os.path.basename(self.connection_file) == self.connection_file:
            return os.path.join(self.profile_dir.security_dir,
                                self.connection_file)
        else:
            return self.connection_file

    # streams, etc.
    no_stdout = Bool(False,
                     config=True,
                     help="redirect stdout to the null device")
    no_stderr = Bool(False,
                     config=True,
                     help="redirect stderr to the null device")
    outstream_class = DottedObjectName(
        'ipython_kernel.iostream.OutStream',
        config=True,
        help="The importstring for the OutStream factory")
    displayhook_class = DottedObjectName(
        'ipython_kernel.displayhook.ZMQDisplayHook',
        config=True,
        help="The importstring for the DisplayHook factory")

    # polling
    parent_handle = Integer(
        int(os.environ.get('JPY_PARENT_PID') or 0),
        config=True,
        help="""kill this process if its parent dies.  On Windows, the argument
        specifies the HANDLE of the parent process, otherwise it is simply boolean.
        """)
    interrupt = Integer(int(os.environ.get('JPY_INTERRUPT_EVENT') or 0),
                        config=True,
                        help="""ONLY USED ON WINDOWS
        Interrupt this process when the parent is signaled.
        """)

    def init_crash_handler(self):
        # Install minimal exception handling
        sys.excepthook = FormattedTB(mode='Verbose',
                                     color_scheme='NoColor',
                                     ostream=sys.__stdout__)

    def init_poller(self):
        if sys.platform == 'win32':
            if self.interrupt or self.parent_handle:
                self.poller = ParentPollerWindows(self.interrupt,
                                                  self.parent_handle)
        elif self.parent_handle:
            self.poller = ParentPollerUnix()

    def _bind_socket(self, s, port):
        iface = '%s://%s' % (self.transport, self.ip)
        if self.transport == 'tcp':
            if port <= 0:
                port = s.bind_to_random_port(iface)
            else:
                s.bind("tcp://%s:%i" % (self.ip, port))
        elif self.transport == 'ipc':
            if port <= 0:
                port = 1
                path = "%s-%i" % (self.ip, port)
                while os.path.exists(path):
                    port = port + 1
                    path = "%s-%i" % (self.ip, port)
            else:
                path = "%s-%i" % (self.ip, port)
            s.bind("ipc://%s" % path)
        return port

    def write_connection_file(self):
        """write connection info to JSON file"""
        cf = self.abs_connection_file
        self.log.debug("Writing connection file: %s", cf)
        write_connection_file(cf,
                              ip=self.ip,
                              key=self.session.key,
                              transport=self.transport,
                              shell_port=self.shell_port,
                              stdin_port=self.stdin_port,
                              hb_port=self.hb_port,
                              iopub_port=self.iopub_port,
                              control_port=self.control_port)

    def cleanup_connection_file(self):
        cf = self.abs_connection_file
        self.log.debug("Cleaning up connection file: %s", cf)
        try:
            os.remove(cf)
        except (IOError, OSError):
            pass

        self.cleanup_ipc_files()

    def init_connection_file(self):
        if not self.connection_file:
            self.connection_file = "kernel-%s.json" % os.getpid()
        try:
            self.connection_file = filefind(
                self.connection_file, ['.', self.profile_dir.security_dir])
        except IOError:
            self.log.debug("Connection file not found: %s",
                           self.connection_file)
            # This means I own it, so I will clean it up:
            atexit.register(self.cleanup_connection_file)
            return
        try:
            self.load_connection_file()
        except Exception:
            self.log.error("Failed to load connection file: %r",
                           self.connection_file,
                           exc_info=True)
            self.exit(1)

    def init_sockets(self):
        # Create a context, a session, and the kernel sockets.
        self.log.info("Starting the kernel at pid: %i", os.getpid())
        context = zmq.Context.instance()
        # Uncomment this to try closing the context.
        # atexit.register(context.term)

        self.shell_socket = context.socket(zmq.ROUTER)
        self.shell_socket.linger = 1000
        self.shell_port = self._bind_socket(self.shell_socket, self.shell_port)
        self.log.debug("shell ROUTER Channel on port: %i" % self.shell_port)

        self.iopub_socket = context.socket(zmq.PUB)
        self.iopub_socket.linger = 1000
        self.iopub_port = self._bind_socket(self.iopub_socket, self.iopub_port)
        self.log.debug("iopub PUB Channel on port: %i" % self.iopub_port)

        self.stdin_socket = context.socket(zmq.ROUTER)
        self.stdin_socket.linger = 1000
        self.stdin_port = self._bind_socket(self.stdin_socket, self.stdin_port)
        self.log.debug("stdin ROUTER Channel on port: %i" % self.stdin_port)

        self.control_socket = context.socket(zmq.ROUTER)
        self.control_socket.linger = 1000
        self.control_port = self._bind_socket(self.control_socket,
                                              self.control_port)
        self.log.debug("control ROUTER Channel on port: %i" %
                       self.control_port)

    def init_heartbeat(self):
        """start the heart beating"""
        # heartbeat doesn't share context, because it mustn't be blocked
        # by the GIL, which is accessed by libzmq when freeing zero-copy messages
        hb_ctx = zmq.Context()
        self.heartbeat = Heartbeat(hb_ctx,
                                   (self.transport, self.ip, self.hb_port))
        self.hb_port = self.heartbeat.port
        self.log.debug("Heartbeat REP Channel on port: %i" % self.hb_port)
        self.heartbeat.start()

    def log_connection_info(self):
        """display connection info, and store ports"""
        basename = os.path.basename(self.connection_file)
        if basename == self.connection_file or \
            os.path.dirname(self.connection_file) == self.profile_dir.security_dir:
            # use shortname
            tail = basename
            if self.profile != 'default':
                tail += " --profile %s" % self.profile
        else:
            tail = self.connection_file
        lines = [
            "To connect another client to this kernel, use:",
            "    --existing %s" % tail,
        ]
        # log connection info
        # info-level, so often not shown.
        # frontends should use the %connect_info magic
        # to see the connection info
        for line in lines:
            self.log.info(line)
        # also raw print to the terminal if no parent_handle (`ipython kernel`)
        if not self.parent_handle:
            io.rprint(_ctrl_c_message)
            for line in lines:
                io.rprint(line)

        self.ports = dict(shell=self.shell_port,
                          iopub=self.iopub_port,
                          stdin=self.stdin_port,
                          hb=self.hb_port,
                          control=self.control_port)

    def init_blackhole(self):
        """redirects stdout/stderr to devnull if necessary"""
        if self.no_stdout or self.no_stderr:
            blackhole = open(os.devnull, 'w')
            if self.no_stdout:
                sys.stdout = sys.__stdout__ = blackhole
            if self.no_stderr:
                sys.stderr = sys.__stderr__ = blackhole

    def init_io(self):
        """Redirect input streams and set a display hook."""
        if self.outstream_class:
            outstream_factory = import_item(str(self.outstream_class))
            sys.stdout = outstream_factory(self.session, self.iopub_socket,
                                           u'stdout')
            sys.stderr = outstream_factory(self.session, self.iopub_socket,
                                           u'stderr')
        if self.displayhook_class:
            displayhook_factory = import_item(str(self.displayhook_class))
            sys.displayhook = displayhook_factory(self.session,
                                                  self.iopub_socket)

    def init_signal(self):
        signal.signal(signal.SIGINT, signal.SIG_IGN)

    def init_kernel(self):
        """Create the Kernel object itself"""
        shell_stream = ZMQStream(self.shell_socket)
        control_stream = ZMQStream(self.control_socket)

        kernel_factory = self.kernel_class.instance

        kernel = kernel_factory(
            parent=self,
            session=self.session,
            shell_streams=[shell_stream, control_stream],
            iopub_socket=self.iopub_socket,
            stdin_socket=self.stdin_socket,
            log=self.log,
            profile_dir=self.profile_dir,
            user_ns=self.user_ns,
        )
        kernel.record_ports(self.ports)
        self.kernel = kernel

    def init_gui_pylab(self):
        """Enable GUI event loop integration, taking pylab into account."""

        # Provide a wrapper for :meth:`InteractiveShellApp.init_gui_pylab`
        # to ensure that any exception is printed straight to stderr.
        # Normally _showtraceback associates the reply with an execution,
        # which means frontends will never draw it, as this exception
        # is not associated with any execute request.

        shell = self.shell
        _showtraceback = shell._showtraceback
        try:
            # replace error-sending traceback with stderr
            def print_tb(etype, evalue, stb):
                print("GUI event loop or pylab initialization failed",
                      file=io.stderr)
                print(shell.InteractiveTB.stb2text(stb), file=io.stderr)

            shell._showtraceback = print_tb
            InteractiveShellApp.init_gui_pylab(self)
        finally:
            shell._showtraceback = _showtraceback

    def init_shell(self):
        self.shell = getattr(self.kernel, 'shell', None)
        if self.shell:
            self.shell.configurables.append(self)

    @catch_config_error
    def initialize(self, argv=None):
        super(IPKernelApp, self).initialize(argv)
        self.init_blackhole()
        self.init_connection_file()
        self.init_poller()
        self.init_sockets()
        self.init_heartbeat()
        # writing/displaying connection info must be *after* init_sockets/heartbeat
        self.log_connection_info()
        self.write_connection_file()
        self.init_io()
        self.init_signal()
        self.init_kernel()
        # shell init steps
        self.init_path()
        self.init_shell()
        if self.shell:
            self.init_gui_pylab()
            self.init_extensions()
            self.init_code()
        # flush stdout/stderr, so that anything written to these streams during
        # initialization do not get associated with the first execution request
        sys.stdout.flush()
        sys.stderr.flush()

    def start(self):
        if self.poller is not None:
            self.poller.start()
        self.kernel.start()
        try:
            ioloop.IOLoop.instance().start()
        except KeyboardInterrupt:
            pass
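
The initialize/start split above follows the standard traitlets Application life cycle. As a rough usage sketch (the import path and the launch_instance entry point are assumptions based on that pattern, not taken from this example):

# A minimal launch sketch, assuming IPKernelApp follows the usual
# Application pattern where launch_instance() runs initialize() then start().
from IPython.kernel.zmq.kernelapp import IPKernelApp  # import path is an assumption

if __name__ == '__main__':
    IPKernelApp.launch_instance()
    # equivalently, step by step:
    #   app = IPKernelApp.instance()
    #   app.initialize()
    #   app.start()
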
Example #20
class HeartMonitor(LoggingConfigurable):
    """A basic HeartMonitor class
    pingstream: a PUB stream
    pongstream: an ROUTER stream
    period: the period of the heartbeat in milliseconds"""

    period = Integer(
        3000,
        config=True,
        help='The period (in ms) at which the Hub pings the engines for '
        'heartbeats',
    )

    pingstream = Instance('zmq.eventloop.zmqstream.ZMQStream')
    pongstream = Instance('zmq.eventloop.zmqstream.ZMQStream')
    loop = Instance('zmq.eventloop.ioloop.IOLoop')

    def _loop_default(self):
        return ioloop.IOLoop.instance()

    # not settable:
    hearts = Set()
    responses = Set()
    on_probation = Set()
    last_ping = CFloat(0)
    _new_handlers = Set()
    _failure_handlers = Set()
    lifetime = CFloat(0)
    tic = CFloat(0)

    def __init__(self, **kwargs):
        super(HeartMonitor, self).__init__(**kwargs)

        self.pongstream.on_recv(self.handle_pong)

    def start(self):
        self.tic = time.time()
        self.caller = ioloop.PeriodicCallback(self.beat, self.period,
                                              self.loop)
        self.caller.start()

    def add_new_heart_handler(self, handler):
        """add a new handler for new hearts"""
        self.log.debug("heartbeat::new_heart_handler: %s", handler)
        self._new_handlers.add(handler)

    def add_heart_failure_handler(self, handler):
        """add a new handler for heart failure"""
        self.log.debug("heartbeat::new heart failure handler: %s", handler)
        self._failure_handlers.add(handler)

    def beat(self):
        self.pongstream.flush()
        self.last_ping = self.lifetime

        toc = time.time()
        self.lifetime += toc - self.tic
        self.tic = toc
        self.log.debug("heartbeat::sending %s", self.lifetime)
        goodhearts = self.hearts.intersection(self.responses)
        missed_beats = self.hearts.difference(goodhearts)
        heartfailures = self.on_probation.intersection(missed_beats)
        newhearts = self.responses.difference(goodhearts)
        # use explicit loops rather than map(): under Python 3, map() is lazy,
        # so the handlers would never actually be called
        for heart in newhearts:
            self.handle_new_heart(heart)
        for heart in heartfailures:
            self.handle_heart_failure(heart)
        self.on_probation = missed_beats.intersection(self.hearts)
        self.responses = set()
        #print self.on_probation, self.hearts
        # self.log.debug("heartbeat::beat %.3f, %i beating hearts", self.lifetime, len(self.hearts))
        self.pingstream.send(str_to_bytes(str(self.lifetime)))
        # flush stream to force immediate socket send
        self.pingstream.flush()

    def handle_new_heart(self, heart):
        if self._new_handlers:
            for handler in self._new_handlers:
                handler(heart)
        else:
            self.log.info("heartbeat::yay, got new heart %s!", heart)
        self.hearts.add(heart)

    def handle_heart_failure(self, heart):
        if self._failure_handlers:
            for handler in self._failure_handlers:
                try:
                    handler(heart)
                except Exception:
                    self.log.error("heartbeat::Bad Handler! %s",
                                   handler,
                                   exc_info=True)
        else:
            self.log.info("heartbeat::Heart %s failed :(", heart)
        self.hearts.remove(heart)

    @log_errors
    def handle_pong(self, msg):
        "a heart just beat"
        current = str_to_bytes(str(self.lifetime))
        last = str_to_bytes(str(self.last_ping))
        if msg[1] == current:
            delta = time.time() - self.tic
            # self.log.debug("heartbeat::heart %r took %.2f ms to respond"%(msg[0], 1000*delta))
            self.responses.add(msg[0])
        elif msg[1] == last:
            delta = time.time() - self.tic + (self.lifetime - self.last_ping)
            self.log.warn(
                "heartbeat::heart %r missed a beat, and took %.2f ms to respond",
                msg[0], 1000 * delta)
            self.responses.add(msg[0])
        else:
            self.log.warn(
                "heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)",
                msg[1], self.lifetime)
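
To see how the pieces fit together, here is a hedged wiring sketch for HeartMonitor: a PUB socket broadcasts pings, a ROUTER socket collects pongs, and registered handlers react to new or failed hearts. The addresses and the failure handler are illustrative assumptions.

# Hedged wiring sketch: addresses and the failure handler are made up for
# illustration; only the HeartMonitor API is taken from the class above.
import zmq
from zmq.eventloop import ioloop, zmqstream

loop = ioloop.IOLoop.instance()
ctx = zmq.Context.instance()

ping_sock = ctx.socket(zmq.PUB)     # engines subscribe to this for pings
ping_sock.bind('tcp://127.0.0.1:5555')
pong_sock = ctx.socket(zmq.ROUTER)  # engines send pongs back here
pong_sock.bind('tcp://127.0.0.1:5556')

monitor = HeartMonitor(loop=loop,
                       pingstream=zmqstream.ZMQStream(ping_sock, loop),
                       pongstream=zmqstream.ZMQStream(pong_sock, loop),
                       period=3000)
monitor.add_heart_failure_handler(lambda heart: print("lost heart %r" % heart))
monitor.start()
loop.start()
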
Example #21
class IPClusterEngines(BaseParallelApplication):

    name = 'ipcluster'
    description = engines_help
    examples = _engines_examples
    usage = None
    config_file_name = Unicode(default_config_file_name)
    default_log_level = logging.INFO
    classes = List()

    def _classes_default(self):
        from IPython.parallel.apps import launcher
        launchers = launcher.all_launchers
        eslaunchers = [l for l in launchers if 'EngineSet' in l.__name__]
        return [ProfileDir] + eslaunchers

    n = Integer(
        num_cpus(),
        config=True,
        help=
        """The number of engines to start. The default is to use one for each
        CPU on your machine""")

    engine_launcher = Any(config=True,
                          help="Deprecated, use engine_launcher_class")

    def _engine_launcher_changed(self, name, old, new):
        if isinstance(new, str):
            self.log.warn(
                "WARNING: %s.engine_launcher is deprecated as of 0.12,"
                " use engine_launcher_class" % self.__class__.__name__)
            self.engine_launcher_class = new

    engine_launcher_class = DottedObjectName(
        'LocalEngineSetLauncher',
        config=True,
        help="""The class for launching a set of Engines. Change this value
        to use various batch systems to launch your engines, such as PBS, SGE, MPI, etc.
        Each launcher class has its own set of configuration options, for making sure
        it will work in your environment.

        You can also write your own launcher, and specify its absolute import path,
        as in 'mymodule.launcher.FTLEnginesLauncher'.

        IPython's bundled examples include:

            Local : start engines locally as subprocesses [default]
            MPI : use mpiexec to launch engines in an MPI environment
            PBS : use PBS (qsub) to submit engines to a batch queue
            SGE : use SGE (qsub) to submit engines to a batch queue
            LSF : use LSF (bsub) to submit engines to a batch queue
            SSH : use SSH to start the engines
                        Note that SSH does *not* move the connection files
                        around, so you will likely have to do this manually
                        unless the machines are on a shared file system.
            WindowsHPC : use Windows HPC

        If you are using one of IPython's builtin launchers, you can specify just the
        prefix, e.g.:

            c.IPClusterEngines.engine_launcher_class = 'SSH'

        or:

            ipcluster start --engines=MPI

        """)
    daemonize = Bool(
        False,
        config=True,
        help="""Daemonize the ipcluster program. This implies --log-to-file.
        Not available on Windows.
        """)

    def _daemonize_changed(self, name, old, new):
        if new:
            self.log_to_file = True

    early_shutdown = Integer(30, config=True,
        help="The timeout (in seconds) within which an engine exit is treated "
             "as a failure to start")
    _stopping = False

    aliases = Dict(engine_aliases)
    flags = Dict(engine_flags)

    @catch_config_error
    def initialize(self, argv=None):
        super(IPClusterEngines, self).initialize(argv)
        self.init_signal()
        self.init_launchers()

    def init_launchers(self):
        self.engine_launcher = self.build_launcher(self.engine_launcher_class,
                                                   'EngineSet')

    def init_signal(self):
        # Setup signals
        signal.signal(signal.SIGINT, self.sigint_handler)

    def build_launcher(self, clsname, kind=None):
        """import and instantiate a Launcher based on importstring"""
        try:
            klass = find_launcher_class(clsname, kind)
        except (ImportError, KeyError):
            self.log.fatal("Could not import launcher class: %r" % clsname)
            self.exit(1)

        launcher = klass(
            work_dir='.',
            config=self.config,
            log=self.log,
            profile_dir=self.profile_dir.location,
            cluster_id=self.cluster_id,
        )
        return launcher

    def engines_started_ok(self):
        self.log.info("Engines appear to have started successfully")
        self.early_shutdown = 0

    def start_engines(self):
        # Some EngineSetLaunchers ignore `n` and use their own engine count, such as SSH:
        n = getattr(self.engine_launcher, 'engine_count', self.n)
        self.log.info("Starting %s Engines with %s", n,
                      self.engine_launcher_class)
        self.engine_launcher.start(self.n)
        self.engine_launcher.on_stop(self.engines_stopped_early)
        if self.early_shutdown:
            ioloop.DelayedCallback(self.engines_started_ok,
                                   self.early_shutdown * 1000,
                                   self.loop).start()

    def engines_stopped_early(self, r):
        if self.early_shutdown and not self._stopping:
            self.log.error("""
            Engines shut down early; they probably failed to connect.
            
            Check the engine log files for output.
            
            If your controller and engines are not on the same machine, you probably
            have to instruct the controller to listen on an interface other than localhost.
            
            You can set this by adding "--ip='*'" to your ControllerLauncher.controller_args.
            
            Be sure to read our security docs before instructing your controller to listen on
            a public interface.
            """)
            self.stop_launchers()

        return self.engines_stopped(r)

    def engines_stopped(self, r):
        return self.loop.stop()

    def stop_engines(self):
        if self.engine_launcher.running:
            self.log.info("Stopping Engines...")
            d = self.engine_launcher.stop()
            return d
        else:
            return None

    def stop_launchers(self, r=None):
        if not self._stopping:
            self._stopping = True
            self.log.error("IPython cluster: stopping")
            self.stop_engines()
            # Wait a few seconds to let things shut down.
            dc = ioloop.DelayedCallback(self.loop.stop, 3000, self.loop)
            dc.start()

    def sigint_handler(self, signum, frame):
        self.log.debug("SIGINT received, stopping launchers...")
        self.stop_launchers()

    def start_logging(self):
        # Remove old log files of the controller and engine
        if self.clean_logs:
            log_dir = self.profile_dir.log_dir
            for f in os.listdir(log_dir):
                if re.match(r'ip(engine|controller)z-\d+\.(log|err|out)', f):
                    os.remove(os.path.join(log_dir, f))
        # This will remove old log files for ipcluster itself
        # super(IPBaseParallelApplication, self).start_logging()

    def start(self):
        """Start the app for the engines subcommand."""
        self.log.info("IPython cluster: started")
        # First see if the cluster is already running

        # Now log and daemonize
        self.log.info('Starting engines with [daemon=%r]' % self.daemonize)
        # TODO: Get daemonize working on Windows or as a Windows Server.
        if self.daemonize:
            if os.name == 'posix':
                daemonize()

        dc = ioloop.DelayedCallback(self.start_engines, 0, self.loop)
        dc.start()
        # Now write the new pid file AFTER our new forked pid is active.
        # self.write_pid_file()
        try:
            self.loop.start()
        except KeyboardInterrupt:
            pass
        except zmq.ZMQError as e:
            if e.errno == errno.EINTR:
                pass
            else:
                raise
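
The traits above are all plain config options, reachable from a profile's ipcluster_config.py. A hedged configuration sketch (the file location is the usual profile convention; the values are illustrative):

# in ipcluster_config.py -- a hedged sketch; values are illustrative
c = get_config()

c.IPClusterEngines.n = 8                          # number of engines to start
c.IPClusterEngines.engine_launcher_class = 'MPI'  # builtin launcher prefix
c.IPClusterEngines.daemonize = True               # implies --log-to-file
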
Example #22
class PDFExporter(LatexExporter):
    """Writer designed to write to PDF files"""

    latex_count = Integer(3, config=True,
        help="How many times latex will be called."
    )

    latex_command = List([u"pdflatex", u"{filename}"], config=True, 
        help="Shell command used to compile latex."
    )

    bib_command = List([u"bibtex", u"{filename}"], config=True,
        help="Shell command used to run bibtex."
    )

    verbose = Bool(False, config=True,
        help="Whether to display the output of latex commands."
    )

    temp_file_exts = List(['.aux', '.bbl', '.blg', '.idx', '.log', '.out'], config=True,
        help="File extensions of temp files to remove after running."
    )
    
    writer = Instance("IPython.nbconvert.writers.FilesWriter", args=())

    def run_command(self, command_list, filename, count, log_function):
        """Run command_list count times.
        
        Parameters
        ----------
        command_list : list
            A list of args to provide to Popen. Each element of this
            list will be interpolated with the filename to convert.
        filename : unicode
            The name of the file to convert.
        count : int
            How many times to run the command.
        log_function : callable
            Called as ``log_function(command, out)`` when the command fails.
        
        Returns
        -------
        success : bool
            A boolean indicating if the command was successful (True)
            or failed (False).
        """
        command = [c.format(filename=filename) for c in command_list]
        # On Windows with Python 2.x there is a bug in subprocess.Popen:
        # unicode commands are not supported
        if sys.platform == 'win32' and sys.version_info < (3, 0):
            # we must use cp1252 encoding when calling subprocess.Popen;
            # note that sys.stdin.encoding and encoding.DEFAULT_ENCODING
            # may differ (e.g. cp437 for a DOS console)
            command = [c.encode('cp1252') for c in command]
        times = 'time' if count == 1 else 'times'
        self.log.info("Running %s %i %s: %s", command_list[0], count, times, command)
        with open(os.devnull, 'rb') as null:
            stdout = subprocess.PIPE if not self.verbose else None
            for index in range(count):
                p = subprocess.Popen(command, stdout=stdout, stdin=null)
                out, err = p.communicate()
                if p.returncode:
                    if self.verbose:
                        # verbose means I didn't capture stdout with PIPE,
                        # so it's already been displayed and `out` is None.
                        out = u''
                    else:
                        out = out.decode('utf-8', 'replace')
                    log_function(command, out)
                    return False # failure
        return True # success

    def run_latex(self, filename):
        """Run pdflatex self.latex_count times."""

        def log_error(command, out):
            self.log.critical(u"%s failed: %s\n%s", command[0], command, out)

        return self.run_command(self.latex_command, filename, 
            self.latex_count, log_error)

    def run_bib(self, filename):
        """Run bibtex self.latex_count times."""
        filename = os.path.splitext(filename)[0]

        def log_error(command, out):
            self.log.warn('%s had problems, most likely because there were no citations',
                command[0])
            self.log.debug(u"%s output: %s\n%s", command[0], command, out)

        return self.run_command(self.bib_command, filename, 1, log_error)

    def clean_temp_files(self, filename):
        """Remove temporary files created by pdflatex/bibtex."""
        self.log.info("Removing temporary LaTeX files")
        filename = os.path.splitext(filename)[0]
        for ext in self.temp_file_exts:
            try:
                os.remove(filename+ext)
            except OSError:
                pass

    def from_notebook_node(self, nb, resources=None, **kw):
        latex, resources = super(PDFExporter, self).from_notebook_node(
            nb, resources=resources, **kw
        )
        with TemporaryWorkingDirectory() as td:
            notebook_name = "notebook"
            tex_file = self.writer.write(latex, resources, notebook_name=notebook_name)
            self.log.info("Building PDF")
            rc = self.run_latex(tex_file)
            if not rc:
                rc = self.run_bib(tex_file)
            if not rc:
                rc = self.run_latex(tex_file)
            
            pdf_file = notebook_name + '.pdf'
            if not os.path.isfile(pdf_file):
                raise RuntimeError("PDF creating failed")
            self.log.info('PDF successfully created')
            with open(pdf_file, 'rb') as f:
                pdf_data = f.read()
        
        # convert output extension to pdf
        # the writer above required it to be tex
        resources['output_extension'] = 'pdf'
        
        return pdf_data, resources
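
A hedged usage sketch for the exporter; the notebook-reading import reflects the nbformat API of this IPython era and should be treated as an assumption:

# Hedged usage sketch: the nbformat import path is an assumption.
from IPython.nbformat import current as nbformat

with open('example.ipynb') as f:
    nb = nbformat.read(f, 'json')

exporter = PDFExporter(latex_count=2)  # traits from the class above
pdf_data, resources = exporter.from_notebook_node(nb)

with open('example.pdf', 'wb') as f:
    f.write(pdf_data)
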
Example #23
class TaskScheduler(SessionFactory):
    """Python TaskScheduler object.

    This is the simplest object that supports msg_id based
    DAG dependencies. *Only* task msg_ids are checked, not
    msg_ids of jobs submitted via the MUX queue.

    """

    hwm = Integer(1,
                  config=True,
                  help="""specify the High Water Mark (HWM) for the downstream
        socket in the Task scheduler. This is the maximum number
        of allowed outstanding tasks on each engine.

        The default (1) means that only one task can be outstanding on each
        engine.  Setting TaskScheduler.hwm=0 means there is no limit, and the
        engines continue to be assigned tasks while they are working,
        effectively hiding network latency behind computation, but this can
        result in an imbalance of work when submitting many heterogeneous
        tasks all at once.  Any value greater than one is a compromise
        between the two.
        """)
    scheme_name = Enum(
        ('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
        'leastload',
        config=True,
        allow_none=False,
        help="""select the task scheduler scheme [default: leastload]
        Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin', 'leastload'"""
    )

    def _scheme_name_changed(self, old, new):
        self.log.debug("Using scheme %r" % new)
        self.scheme = globals()[new]

    # input arguments:
    scheme = Instance(FunctionType)  # function for determining the destination

    def _scheme_default(self):
        return leastload

    client_stream = Instance(zmqstream.ZMQStream)  # client-facing stream
    engine_stream = Instance(zmqstream.ZMQStream)  # engine-facing stream
    notifier_stream = Instance(zmqstream.ZMQStream)  # hub-facing sub stream
    mon_stream = Instance(zmqstream.ZMQStream)  # hub-facing pub stream
    query_stream = Instance(zmqstream.ZMQStream)  # hub-facing DEALER stream

    # internals:
    queue = Instance(deque)  # sorted list of Jobs

    def _queue_default(self):
        return deque()

    queue_map = Dict()  # dict by msg_id of Jobs (for O(1) access to the Queue)
    graph = Dict()  # dict by msg_id of [ msg_ids that depend on key ]
    retries = Dict()  # dict by msg_id of retries remaining (non-neg ints)
    # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
    pending = Dict()  # dict by engine_uuid of submitted tasks
    completed = Dict()  # dict by engine_uuid of completed tasks
    failed = Dict()  # dict by engine_uuid of failed tasks
    destinations = Dict()  # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
    clients = Dict()  # dict by msg_id for who submitted the task
    targets = List()  # list of target IDENTs
    loads = List()  # list of engine loads
    # full = Set() # set of IDENTs that have HWM outstanding tasks
    all_completed = Set()  # set of all completed tasks
    all_failed = Set()  # set of all failed tasks
    all_done = Set()  # set of all finished tasks=union(completed,failed)
    all_ids = Set()  # set of all submitted task IDs

    # ZMQ identity. This should just be self.session.session, but ensure Bytes.
    ident = CBytes()

    def _ident_default(self):
        return self.session.bsession

    def start(self):
        self.query_stream.on_recv(self.dispatch_query_reply)
        self.session.send(self.query_stream, "connection_request", {})

        self.engine_stream.on_recv(self.dispatch_result, copy=False)
        self.client_stream.on_recv(self.dispatch_submission, copy=False)

        self._notification_handlers = dict(
            registration_notification=self._register_engine,
            unregistration_notification=self._unregister_engine)
        self.notifier_stream.on_recv(self.dispatch_notification)
        self.log.info("Scheduler started [%s]" % self.scheme_name)

    def resume_receiving(self):
        """Resume accepting jobs."""
        self.client_stream.on_recv(self.dispatch_submission, copy=False)

    def stop_receiving(self):
        """Stop accepting jobs while there are no engines.
        Leave them in the ZMQ queue."""
        self.client_stream.on_recv(None)

    #-----------------------------------------------------------------------
    # [Un]Registration Handling
    #-----------------------------------------------------------------------

    def dispatch_query_reply(self, msg):
        """handle reply to our initial connection request"""
        try:
            idents, msg = self.session.feed_identities(msg)
        except ValueError:
            self.log.warn("task::Invalid Message: %r", msg)
            return
        try:
            msg = self.session.unserialize(msg)
        except ValueError:
            self.log.warn("task::Unauthorized message from: %r" % idents)
            return

        content = msg['content']
        for uuid in list(content.get('engines', {}).values()):
            self._register_engine(cast_bytes(uuid))

    @util.log_errors
    def dispatch_notification(self, msg):
        """dispatch register/unregister events."""
        try:
            idents, msg = self.session.feed_identities(msg)
        except ValueError:
            self.log.warn("task::Invalid Message: %r", msg)
            return
        try:
            msg = self.session.unserialize(msg)
        except ValueError:
            self.log.warn("task::Unauthorized message from: %r" % idents)
            return

        msg_type = msg['header']['msg_type']

        handler = self._notification_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("Unhandled message type: %r" % msg_type)
        else:
            try:
                handler(cast_bytes(msg['content']['uuid']))
            except Exception:
                self.log.error("task::Invalid notification msg: %r",
                               msg,
                               exc_info=True)

    def _register_engine(self, uid):
        """New engine with ident `uid` became available."""
        # head of the line:
        self.targets.insert(0, uid)
        self.loads.insert(0, 0)

        # initialize sets
        self.completed[uid] = set()
        self.failed[uid] = set()
        self.pending[uid] = {}

        # rescan the graph:
        self.update_graph(None)

    def _unregister_engine(self, uid):
        """Existing engine with ident `uid` became unavailable."""
        if len(self.targets) == 1:
            # this was our only engine
            pass

        # handle any potentially finished tasks:
        self.engine_stream.flush()

        # don't pop destinations, because they might be used later
        # map(self.destinations.pop, self.completed.pop(uid))
        # map(self.destinations.pop, self.failed.pop(uid))

        # prevent this engine from receiving work
        idx = self.targets.index(uid)
        self.targets.pop(idx)
        self.loads.pop(idx)

        # wait 5 seconds before cleaning up pending jobs, since the results might
        # still be incoming
        if self.pending[uid]:
            dc = ioloop.DelayedCallback(
                lambda: self.handle_stranded_tasks(uid), 5000, self.loop)
            dc.start()
        else:
            self.completed.pop(uid)
            self.failed.pop(uid)

    def handle_stranded_tasks(self, engine):
        """Deal with jobs resident in an engine that died."""
        lost = self.pending[engine]
        for msg_id in list(lost.keys()):
            if msg_id not in self.pending[engine]:
                # prevent double-handling of messages
                continue

            raw_msg = lost[msg_id].raw_msg
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            parent = self.session.unpack(msg[1].bytes)
            idents = [engine, idents[0]]

            # build fake error reply
            try:
                raise error.EngineError(
                    "Engine %r died while running task %r" % (engine, msg_id))
            except:
                content = error.wrap_exception()
            # build fake metadata
            md = dict(
                status='error',
                engine=engine.decode('ascii'),
                date=datetime.now(),
            )
            msg = self.session.msg('apply_reply',
                                   content,
                                   parent=parent,
                                   metadata=md)
            raw_reply = list(
                map(zmq.Message, self.session.serialize(msg, ident=idents)))
            # and dispatch it
            self.dispatch_result(raw_reply)

        # finally scrub completed/failed lists
        self.completed.pop(engine)
        self.failed.pop(engine)

    #-----------------------------------------------------------------------
    # Job Submission
    #-----------------------------------------------------------------------

    @util.log_errors
    def dispatch_submission(self, raw_msg):
        """Dispatch job submission to appropriate handlers."""
        # ensure targets up to date:
        self.notifier_stream.flush()
        try:
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.unserialize(msg, content=False, copy=False)
        except Exception:
            self.log.error("task::Invaid task msg: %r" % raw_msg,
                           exc_info=True)
            return

        # send to monitor
        self.mon_stream.send_multipart([b'intask'] + raw_msg, copy=False)

        header = msg['header']
        md = msg['metadata']
        msg_id = header['msg_id']
        self.all_ids.add(msg_id)

        # get targets as a set of bytes objects
        # from a list of unicode objects
        targets = md.get('targets', [])
        targets = list(map(cast_bytes, targets))
        targets = set(targets)

        retries = md.get('retries', 0)
        self.retries[msg_id] = retries

        # time dependencies
        after = md.get('after', None)
        if after:
            after = Dependency(after)
            if after.all:
                if after.success:
                    after = Dependency(
                        after.difference(self.all_completed),
                        success=after.success,
                        failure=after.failure,
                        all=after.all,
                    )
                if after.failure:
                    after = Dependency(
                        after.difference(self.all_failed),
                        success=after.success,
                        failure=after.failure,
                        all=after.all,
                    )
            if after.check(self.all_completed, self.all_failed):
                # recast as empty set, if `after` already met,
                # to prevent unnecessary set comparisons
                after = MET
        else:
            after = MET

        # location dependencies
        follow = Dependency(md.get('follow', []))

        # turn timeouts into datetime objects:
        timeout = md.get('timeout', None)
        if timeout:
            timeout = time.time() + float(timeout)

        job = Job(
            msg_id=msg_id,
            raw_msg=raw_msg,
            idents=idents,
            msg=msg,
            header=header,
            targets=targets,
            after=after,
            follow=follow,
            timeout=timeout,
            metadata=md,
        )
        if timeout:
            # schedule timeout callback
            self.loop.add_timeout(timeout, lambda: self.job_timeout(job))

        # validate and reduce dependencies:
        for dep in after, follow:
            if not dep:  # empty dependency
                continue
            # check valid:
            if msg_id in dep or dep.difference(self.all_ids):
                self.queue_map[msg_id] = job
                return self.fail_unreachable(msg_id, error.InvalidDependency)
            # check if unreachable:
            if dep.unreachable(self.all_completed, self.all_failed):
                self.queue_map[msg_id] = job
                return self.fail_unreachable(msg_id)

        if after.check(self.all_completed, self.all_failed):
            # time deps already met, try to run
            if not self.maybe_run(job):
                # can't run yet
                if msg_id not in self.all_failed:
                    # could have failed as unreachable
                    self.save_unmet(job)
        else:
            self.save_unmet(job)

    def job_timeout(self, job):
        """callback for a job's timeout.
        
        The job may or may not have been run at this point.
        """
        now = time.time()
        if job.timeout >= (now + 1):
            self.log.warn("task %s timeout fired prematurely: %s > %s",
                          job.msg_id, job.timeout, now)
        if job.msg_id in self.queue_map:
            # still waiting, but ran out of time
            self.log.info("task %r timed out", job.msg_id)
            self.fail_unreachable(job.msg_id, error.TaskTimeout)

    def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
        """a task has become unreachable, send a reply with an ImpossibleDependency
        error."""
        if msg_id not in self.queue_map:
            self.log.error("task %r already failed!", msg_id)
            return
        job = self.queue_map.pop(msg_id)
        # lazy-delete from the queue
        job.removed = True
        for mid in job.dependents:
            if mid in self.graph:
                self.graph[mid].remove(msg_id)

        try:
            raise why()
        except:
            content = error.wrap_exception()
        self.log.debug("task %r failing as unreachable with: %s", msg_id,
                       content['ename'])

        self.all_done.add(msg_id)
        self.all_failed.add(msg_id)

        msg = self.session.send(self.client_stream,
                                'apply_reply',
                                content,
                                parent=job.header,
                                ident=job.idents)
        self.session.send(self.mon_stream,
                          msg,
                          ident=[b'outtask'] + job.idents)

        self.update_graph(msg_id, success=False)

    def available_engines(self):
        """return a list of available engine indices based on HWM"""
        if not self.hwm:
            return list(range(len(self.targets)))
        available = []
        for idx in range(len(self.targets)):
            if self.loads[idx] < self.hwm:
                available.append(idx)
        return available

    def maybe_run(self, job):
        """check location dependencies, and run if they are met."""
        msg_id = job.msg_id
        self.log.debug("Attempting to assign task %s", msg_id)
        available = self.available_engines()
        if not available:
            # no engines, definitely can't run
            return False

        if job.follow or job.targets or job.blacklist or self.hwm:
            # we need a can_run filter
            def can_run(idx):
                # check hwm
                if self.hwm and self.loads[idx] == self.hwm:
                    return False
                target = self.targets[idx]
                # check blacklist
                if target in job.blacklist:
                    return False
                # check targets
                if job.targets and target not in job.targets:
                    return False
                # check follow
                return job.follow.check(self.completed[target],
                                        self.failed[target])

            indices = list(filter(can_run, available))

            if not indices:
                # couldn't run
                if job.follow.all:
                    # check follow for impossibility
                    dests = set()
                    relevant = set()
                    if job.follow.success:
                        relevant = self.all_completed
                    if job.follow.failure:
                        relevant = relevant.union(self.all_failed)
                    for m in job.follow.intersection(relevant):
                        dests.add(self.destinations[m])
                    if len(dests) > 1:
                        self.queue_map[msg_id] = job
                        self.fail_unreachable(msg_id)
                        return False
                if job.targets:
                    # check blacklist+targets for impossibility
                    job.targets.difference_update(job.blacklist)
                    if not job.targets or not job.targets.intersection(
                            self.targets):
                        self.queue_map[msg_id] = job
                        self.fail_unreachable(msg_id)
                        return False
                return False
        else:
            indices = None

        self.submit_task(job, indices)
        return True

    def save_unmet(self, job):
        """Save a message for later submission when its dependencies are met."""
        msg_id = job.msg_id
        self.log.debug("Adding task %s to the queue", msg_id)
        self.queue_map[msg_id] = job
        self.queue.append(job)
        # track the ids in follow or after, but not those already finished
        for dep_id in job.after.union(job.follow).difference(self.all_done):
            if dep_id not in self.graph:
                self.graph[dep_id] = set()
            self.graph[dep_id].add(msg_id)

    def submit_task(self, job, indices=None):
        """Submit a task to any of a subset of our targets."""
        if indices:
            loads = [self.loads[i] for i in indices]
        else:
            loads = self.loads
        idx = self.scheme(loads)
        if indices:
            idx = indices[idx]
        target = self.targets[idx]
        # print (target, map(str, msg[:3]))
        # send job to the engine
        self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
        self.engine_stream.send_multipart(job.raw_msg, copy=False)
        # update load
        self.add_job(idx)
        self.pending[target][job.msg_id] = job
        # notify Hub
        content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
        self.session.send(self.mon_stream,
                          'task_destination',
                          content=content,
                          ident=[b'tracktask', self.ident])

    #-----------------------------------------------------------------------
    # Result Handling
    #-----------------------------------------------------------------------

    @util.log_errors
    def dispatch_result(self, raw_msg):
        """dispatch method for result replies"""
        try:
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.unserialize(msg, content=False, copy=False)
            engine = idents[0]
            try:
                idx = self.targets.index(engine)
            except ValueError:
                pass  # skip load-update for dead engines
            else:
                self.finish_job(idx)
        except Exception:
            self.log.error("task::Invaid result: %r", raw_msg, exc_info=True)
            return

        md = msg['metadata']
        parent = msg['parent_header']
        if md.get('dependencies_met', True):
            success = (md['status'] == 'ok')
            msg_id = parent['msg_id']
            retries = self.retries[msg_id]
            if not success and retries > 0:
                # failed
                self.retries[msg_id] = retries - 1
                self.handle_unmet_dependency(idents, parent)
            else:
                del self.retries[msg_id]
                # relay to client and update graph
                self.handle_result(idents, parent, raw_msg, success)
                # send to Hub monitor
                self.mon_stream.send_multipart([b'outtask'] + raw_msg,
                                               copy=False)
        else:
            self.handle_unmet_dependency(idents, parent)

    def handle_result(self, idents, parent, raw_msg, success=True):
        """handle a real task result, either success or failure"""
        # first, relay result to client
        engine = idents[0]
        client = idents[1]
        # swap_ids for ROUTER-ROUTER mirror
        raw_msg[:2] = [client, engine]
        # print (map(str, raw_msg[:4]))
        self.client_stream.send_multipart(raw_msg, copy=False)
        # now, update our data structures
        msg_id = parent['msg_id']
        self.pending[engine].pop(msg_id)
        if success:
            self.completed[engine].add(msg_id)
            self.all_completed.add(msg_id)
        else:
            self.failed[engine].add(msg_id)
            self.all_failed.add(msg_id)
        self.all_done.add(msg_id)
        self.destinations[msg_id] = engine

        self.update_graph(msg_id, success)

    def handle_unmet_dependency(self, idents, parent):
        """handle an unmet dependency"""
        engine = idents[0]
        msg_id = parent['msg_id']

        job = self.pending[engine].pop(msg_id)
        job.blacklist.add(engine)

        if job.blacklist == job.targets:
            self.queue_map[msg_id] = job
            self.fail_unreachable(msg_id)
        elif not self.maybe_run(job):
            # resubmit failed
            if msg_id not in self.all_failed:
                # put it back in our dependency tree
                self.save_unmet(job)

        if self.hwm:
            try:
                idx = self.targets.index(engine)
            except ValueError:
                pass  # skip load-update for dead engines
            else:
                if self.loads[idx] == self.hwm - 1:
                    self.update_graph(None)

    def update_graph(self, dep_id=None, success=True):
        """dep_id just finished. Update our dependency
        graph and submit any jobs that just became runnable.

        Called with dep_id=None to update entire graph for hwm, but without finishing a task.
        """
        # print ("\n\n***********")
        # pprint (dep_id)
        # pprint (self.graph)
        # pprint (self.queue_map)
        # pprint (self.all_completed)
        # pprint (self.all_failed)
        # print ("\n\n***********\n\n")
        # update any jobs that depended on the dependency
        msg_ids = self.graph.pop(dep_id, [])

        # recheck *all* jobs if
        # a) we have HWM and an engine just become no longer full
        # or b) dep_id was given as None

        if dep_id is None or (self.hwm and any(
                load == self.hwm - 1 for load in self.loads)):
            jobs = self.queue
            using_queue = True
        else:
            using_queue = False
            jobs = deque(sorted(self.queue_map[msg_id] for msg_id in msg_ids))

        to_restore = []
        while jobs:
            job = jobs.popleft()
            if job.removed:
                continue
            msg_id = job.msg_id

            put_it_back = True

            if job.after.unreachable(self.all_completed, self.all_failed)\
                    or job.follow.unreachable(self.all_completed, self.all_failed):
                self.fail_unreachable(msg_id)
                put_it_back = False

            elif job.after.check(self.all_completed,
                                 self.all_failed):  # time deps met, maybe run
                if self.maybe_run(job):
                    put_it_back = False
                    self.queue_map.pop(msg_id)
                    for mid in job.dependents:
                        if mid in self.graph:
                            self.graph[mid].remove(msg_id)

                    # abort the loop if we just filled up all of our engines.
                    # avoids an O(N) operation in situation of full queue,
                    # where graph update is triggered as soon as an engine becomes
                    # non-full, and all tasks after the first are checked,
                    # even though they can't run.
                    if not self.available_engines():
                        break

            if using_queue and put_it_back:
                # popped a job from the queue but it neither ran nor failed,
                # so we need to put it back when we are done
                # make sure to_restore preserves the same ordering
                to_restore.append(job)

        # put back any tasks we popped but didn't run
        if using_queue:
            self.queue.extendleft(to_restore)

    #----------------------------------------------------------------------
    # methods to be overridden by subclasses
    #----------------------------------------------------------------------

    def add_job(self, idx):
        """Called after self.targets[idx] just got the job with header.
        Override with subclasses.  The default ordering is simple LRU.
        The default loads are the number of outstanding jobs."""
        self.loads[idx] += 1
        for lis in (self.targets, self.loads):
            lis.append(lis.pop(idx))

    def finish_job(self, idx):
        """Called after self.targets[idx] just finished a job.
        Override with subclasses."""
        self.loads[idx] -= 1
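
Both scheduler traits above are ordinary config options; a hedged sketch of the corresponding entries in a controller config file (the file name follows the usual profile convention; values are illustrative):

# in ipcontroller_config.py -- a hedged sketch; values are illustrative
c = get_config()

c.TaskScheduler.hwm = 0                 # no per-engine limit on outstanding tasks
c.TaskScheduler.scheme_name = 'twobin'  # any value from the Enum above
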
Example #24
class IntegerTrait(HasTraits):
    value = Integer(1)
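
This one-liner is enough to exercise the Integer trait's default and validation behaviour. A small hedged demonstration (the TraitError import path matches the traitlets of this era; treat it as an assumption):

# Hedged demonstration of Integer trait validation.
from IPython.utils.traitlets import TraitError  # import path is an assumption

t = IntegerTrait()
assert t.value == 1         # the declared default

t.value = 42                # integers are accepted
try:
    t.value = 'not an int'  # non-integers raise
except TraitError:
    print("rejected non-integer value")
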
Example #25
class ConnectionFileMixin(Configurable):
    """Mixin for configurable classes that work with connection files"""

    # The addresses for the communication channels
    connection_file = Unicode('')
    _connection_file_written = Bool(False)

    transport = CaselessStrEnum(['tcp', 'ipc'],
                                default_value='tcp',
                                config=True)

    ip = Unicode(LOCALHOST,
                 config=True,
                 help="""Set the kernel's IP address [default localhost].
        If the IP address is something other than localhost, then
        consoles on other machines will be able to connect
        to the kernel, so be careful!""")

    def _ip_default(self):
        if self.transport == 'ipc':
            if self.connection_file:
                return os.path.splitext(self.connection_file)[0] + '-ipc'
            else:
                return 'kernel-ipc'
        else:
            return LOCALHOST

    def _ip_changed(self, name, old, new):
        if new == '*':
            self.ip = '0.0.0.0'

    # protected traits

    shell_port = Integer(0)
    iopub_port = Integer(0)
    stdin_port = Integer(0)
    control_port = Integer(0)
    hb_port = Integer(0)

    @property
    def ports(self):
        return [getattr(self, name) for name in port_names]

    #--------------------------------------------------------------------------
    # Connection and ipc file management
    #--------------------------------------------------------------------------

    def get_connection_info(self):
        """return the connection info as a dict"""
        return dict(
            transport=self.transport,
            ip=self.ip,
            shell_port=self.shell_port,
            iopub_port=self.iopub_port,
            stdin_port=self.stdin_port,
            hb_port=self.hb_port,
            control_port=self.control_port,
            signature_scheme=self.session.signature_scheme,
            key=self.session.key,
        )

    def cleanup_connection_file(self):
        """Cleanup connection file *if we wrote it*

        Will not raise if the connection file was already removed somehow.
        """
        if self._connection_file_written:
            # cleanup connection files on full shutdown of kernel we started
            self._connection_file_written = False
            try:
                os.remove(self.connection_file)
            except (IOError, OSError, AttributeError):
                pass

    def cleanup_ipc_files(self):
        """Cleanup ipc files if we wrote them."""
        if self.transport != 'ipc':
            return
        for port in self.ports:
            ipcfile = "%s-%i" % (self.ip, port)
            try:
                os.remove(ipcfile)
            except (IOError, OSError):
                pass

    def write_connection_file(self):
        """Write connection info to JSON dict in self.connection_file."""
        if self._connection_file_written:
            return

        self.connection_file, cfg = write_connection_file(
            self.connection_file,
            transport=self.transport,
            ip=self.ip,
            key=self.session.key,
            stdin_port=self.stdin_port,
            iopub_port=self.iopub_port,
            shell_port=self.shell_port,
            hb_port=self.hb_port,
            control_port=self.control_port,
            signature_scheme=self.session.signature_scheme,
        )
        # write_connection_file also sets default ports:
        for name in port_names:
            setattr(self, name, cfg[name])

        self._connection_file_written = True

    def load_connection_file(self):
        """Load connection info from JSON dict in self.connection_file."""
        with open(self.connection_file) as f:
            cfg = json.loads(f.read())

        self.transport = cfg.get('transport', 'tcp')
        self.ip = cfg['ip']
        for name in port_names:
            setattr(self, name, cfg[name])
        if 'key' in cfg:
            self.session.key = str_to_bytes(cfg['key'])
        if cfg.get('signature_scheme'):
            self.session.signature_scheme = cfg['signature_scheme']

    #--------------------------------------------------------------------------
    # Creating connected sockets
    #--------------------------------------------------------------------------

    def _make_url(self, channel):
        """Make a ZeroMQ URL for a given channel."""
        transport = self.transport
        ip = self.ip
        port = getattr(self, '%s_port' % channel)

        if transport == 'tcp':
            return "tcp://%s:%i" % (ip, port)
        else:
            return "%s://%s-%s" % (transport, ip, port)

    def _create_connected_socket(self, channel, identity=None):
        """Create a zmq Socket and connect it to the kernel."""
        url = self._make_url(channel)
        socket_type = channel_socket_types[channel]
        self.log.info("Connecting to: %s" % url)
        sock = self.context.socket(socket_type)
        if identity:
            sock.identity = identity
        sock.connect(url)
        return sock

    def connect_iopub(self, identity=None):
        """return zmq Socket connected to the IOPub channel"""
        sock = self._create_connected_socket('iopub', identity=identity)
        sock.setsockopt(zmq.SUBSCRIBE, b'')
        return sock

    def connect_shell(self, identity=None):
        """return zmq Socket connected to the Shell channel"""
        return self._create_connected_socket('shell', identity=identity)

    def connect_stdin(self, identity=None):
        """return zmq Socket connected to the StdIn channel"""
        return self._create_connected_socket('stdin', identity=identity)

    def connect_hb(self, identity=None):
        """return zmq Socket connected to the Heartbeat channel"""
        return self._create_connected_socket('hb', identity=identity)

    def connect_control(self, identity=None):
        """return zmq Socket connected to the Heartbeat channel"""
        return self._create_connected_socket('control', identity=identity)
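
A hedged sketch of the client-side flow through this mixin: load an existing connection file, then open sockets to the kernel. The subclass, file name, and the context/session/log wiring are illustrative assumptions; the connect_* helpers require self.context, self.session, and self.log to exist (they are provided by subclasses in the real application).

# Hedged client-side sketch: MyClient and the wiring below are assumptions;
# only the mixin methods are taken from the class above.
import logging
import zmq
from IPython.kernel.zmq.session import Session  # import path is an assumption

class MyClient(ConnectionFileMixin):
    pass

client = MyClient(connection_file='kernel-1234.json')  # hypothetical file
client.context = zmq.Context.instance()
client.session = Session()
client.log = logging.getLogger('client')  # _create_connected_socket logs via self.log
client.load_connection_file()    # fills in ip, ports, key, signature scheme
shell = client.connect_shell()   # zmq socket connected to the shell channel
iopub = client.connect_iopub()   # SUB socket subscribed to everything
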
Example #26
class Kernel(Configurable):

    #---------------------------------------------------------------------------
    # Kernel interface
    #---------------------------------------------------------------------------

    # attribute to override with a GUI
    eventloop = Any(None)

    def _eventloop_changed(self, name, old, new):
        """schedule call to eventloop from IOLoop"""
        loop = ioloop.IOLoop.instance()
        loop.add_timeout(time.time() + 0.1, self.enter_eventloop)

    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
    shell_class = Type(ZMQInteractiveShell)

    session = Instance(Session)
    profile_dir = Instance('IPython.core.profiledir.ProfileDir')
    shell_streams = List()
    control_stream = Instance(ZMQStream)
    iopub_socket = Instance(zmq.Socket)
    stdin_socket = Instance(zmq.Socket)
    log = Instance(logging.Logger)

    user_module = Any()

    def _user_module_changed(self, name, old, new):
        if self.shell is not None:
            self.shell.user_module = new

    user_ns = Instance(dict, args=None, allow_none=True)

    def _user_ns_changed(self, name, old, new):
        if self.shell is not None:
            self.shell.user_ns = new
            self.shell.init_user_ns()

    # identities:
    int_id = Integer(-1)
    ident = Unicode()

    def _ident_default(self):
        return unicode_type(uuid.uuid4())

    # Private interface

    # Time to sleep after flushing the stdout/err buffers in each execute
    # cycle.  While this introduces a hard limit on the minimal latency of the
    # execute cycle, it helps prevent output synchronization problems for
    # clients.
    # Units are in seconds.  The minimum zmq latency on local host is probably
    # ~150 microseconds, set this to 500us for now.  We may need to increase it
    # a little if it's not enough after more interactive testing.
    _execute_sleep = Float(0.0005, config=True)

    # Frequency of the kernel's event loop.
    # Units are in seconds, kernel subclasses for GUI toolkits may need to
    # adapt to milliseconds.
    _poll_interval = Float(0.05, config=True)

    # If the shutdown was requested over the network, we leave here the
    # necessary reply message so it can be sent by our registered atexit
    # handler.  This ensures that the reply is only sent to clients truly at
    # the end of our shutdown process (which happens after the underlying
    # IPython shell's own shutdown).
    _shutdown_message = None

    # This is a dict of port number that the kernel is listening on. It is set
    # by record_ports and used by connect_request.
    _recorded_ports = Dict()

    # A reference to the Python builtin 'raw_input' function.
    # (i.e., __builtin__.raw_input for Python 2.7, builtins.input for Python 3)
    _sys_raw_input = Any()
    _sys_eval_input = Any()

    # set of aborted msg_ids
    aborted = Set()

    def __init__(self, **kwargs):
        super(Kernel, self).__init__(**kwargs)

        # Initialize the InteractiveShell subclass
        self.shell = self.shell_class.instance(
            parent=self,
            profile_dir=self.profile_dir,
            user_module=self.user_module,
            user_ns=self.user_ns,
            kernel=self,
        )
        self.shell.displayhook.session = self.session
        self.shell.displayhook.pub_socket = self.iopub_socket
        self.shell.displayhook.topic = self._topic('pyout')
        self.shell.display_pub.session = self.session
        self.shell.display_pub.pub_socket = self.iopub_socket
        self.shell.data_pub.session = self.session
        self.shell.data_pub.pub_socket = self.iopub_socket

        # TMP - hack while developing
        self.shell._reply_content = None

        # Build dict of handlers for message types
        msg_types = [
            'execute_request',
            'complete_request',
            'object_info_request',
            'history_request',
            'kernel_info_request',
            'connect_request',
            'shutdown_request',
            'apply_request',
        ]
        self.shell_handlers = {}
        for msg_type in msg_types:
            self.shell_handlers[msg_type] = getattr(self, msg_type)

        comm_msg_types = ['comm_open', 'comm_msg', 'comm_close']
        comm_manager = self.shell.comm_manager
        for msg_type in comm_msg_types:
            self.shell_handlers[msg_type] = getattr(comm_manager, msg_type)

        control_msg_types = msg_types + ['clear_request', 'abort_request']
        self.control_handlers = {}
        for msg_type in control_msg_types:
            self.control_handlers[msg_type] = getattr(self, msg_type)

    def dispatch_control(self, msg):
        """dispatch control requests"""
        idents, msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.unserialize(msg, content=True, copy=False)
        except:
            self.log.error("Invalid Control Message", exc_info=True)
            return

        self.log.debug("Control received: %s", msg)

        header = msg['header']
        msg_id = header['msg_id']
        msg_type = header['msg_type']

        handler = self.control_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r", msg_type)
        else:
            try:
                handler(self.control_stream, idents, msg)
            except Exception:
                self.log.error("Exception in control handler:", exc_info=True)

    def dispatch_shell(self, stream, msg):
        """dispatch shell requests"""
        # flush control requests first
        if self.control_stream:
            self.control_stream.flush()

        idents, msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.unserialize(msg, content=True, copy=False)
        except:
            self.log.error("Invalid Message", exc_info=True)
            return

        header = msg['header']
        msg_id = header['msg_id']
        msg_type = msg['header']['msg_type']

        # Print some info about this message and leave a '--->' marker, so it's
        # easier to trace visually the message chain when debugging.  Each
        # handler prints its message at the end.
        self.log.debug('\n*** MESSAGE TYPE:%s***', msg_type)
        self.log.debug('   Content: %s\n   --->\n   ', msg['content'])

        if msg_id in self.aborted:
            self.aborted.remove(msg_id)
            # is it safe to assume a msg_id will not be resubmitted?
            reply_type = msg_type.split('_')[0] + '_reply'
            status = {'status': 'aborted'}
            md = {'engine': self.ident}
            md.update(status)
            reply_msg = self.session.send(stream,
                                          reply_type,
                                          metadata=md,
                                          content=status,
                                          parent=msg,
                                          ident=idents)
            return

        handler = self.shell_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("UNKNOWN MESSAGE TYPE: %r", msg_type)
        else:
            # ensure default_int_handler during handler call
            sig = signal(SIGINT, default_int_handler)
            try:
                handler(stream, idents, msg)
            except Exception:
                self.log.error("Exception in message handler:", exc_info=True)
            finally:
                signal(SIGINT, sig)

    def enter_eventloop(self):
        """enter eventloop"""
        self.log.info("entering eventloop")
        # restore default_int_handler
        signal(SIGINT, default_int_handler)
        while self.eventloop is not None:
            try:
                self.eventloop(self)
            except KeyboardInterrupt:
                # Ctrl-C shouldn't crash the kernel
                self.log.error("KeyboardInterrupt caught in kernel")
                continue
            else:
                # eventloop exited cleanly, this means we should stop (right?)
                self.eventloop = None
                break
        self.log.info("exiting eventloop")

    def start(self):
        """register dispatchers for streams"""
        self.shell.exit_now = False
        if self.control_stream:
            self.control_stream.on_recv(self.dispatch_control, copy=False)

        def make_dispatcher(stream):
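            # Each stream gets its own closure here; a lambda defined directly
            # in the loop below would late-bind the loop variable and point
            # every dispatcher at the last stream.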
            def dispatcher(msg):
                return self.dispatch_shell(stream, msg)

            return dispatcher

        for s in self.shell_streams:
            s.on_recv(make_dispatcher(s), copy=False)

        # publish idle status
        self._publish_status('starting')

    def do_one_iteration(self):
        """step eventloop just once"""
        if self.control_stream:
            self.control_stream.flush()
        for stream in self.shell_streams:
            # handle at most one request per iteration
            stream.flush(zmq.POLLIN, 1)
            stream.flush(zmq.POLLOUT)

    def record_ports(self, ports):
        """Record the ports that this kernel is using.

        The creator of the Kernel instance must call this method if they
        want the :meth:`connect_request` method to return the port numbers.
        """
        self._recorded_ports = ports

    #---------------------------------------------------------------------------
    # Kernel request handlers
    #---------------------------------------------------------------------------

    def _make_metadata(self, other=None):
        """init metadata dict, for execute/apply_reply"""
        new_md = {
            'dependencies_met': True,
            'engine': self.ident,
            'started': datetime.now(),
        }
        if other:
            new_md.update(other)
        return new_md

    def _publish_pyin(self, code, parent, execution_count):
        """Publish the code request on the pyin stream."""

        self.session.send(self.iopub_socket,
                          u'pyin', {
                              u'code': code,
                              u'execution_count': execution_count
                          },
                          parent=parent,
                          ident=self._topic('pyin'))

    def _publish_status(self, status, parent=None):
        """send status (busy/idle) on IOPub"""
        self.session.send(
            self.iopub_socket,
            u'status',
            {u'execution_state': status},
            parent=parent,
            ident=self._topic('status'),
        )

    def execute_request(self, stream, ident, parent):
        """handle an execute_request"""

        self._publish_status(u'busy', parent)

        try:
            content = parent[u'content']
            code = content[u'code']
            silent = content[u'silent']
            store_history = content.get(u'store_history', not silent)
        except:
            self.log.error("Got bad msg: ")
            self.log.error("%s", parent)
            return

        md = self._make_metadata(parent['metadata'])

        shell = self.shell  # we'll need this a lot here

        # Replace raw_input. Note that it is not sufficient to replace
        # raw_input in the user namespace.
        if content.get('allow_stdin', False):
            raw_input = lambda prompt='': self._raw_input(
                prompt, ident, parent)
            input = lambda prompt='': eval(raw_input(prompt))
        else:
            raw_input = input = lambda prompt='': self._no_raw_input()

        if py3compat.PY3:
            self._sys_raw_input = builtin_mod.input
            builtin_mod.input = raw_input
        else:
            self._sys_raw_input = builtin_mod.raw_input
            self._sys_eval_input = builtin_mod.input
            builtin_mod.raw_input = raw_input
            builtin_mod.input = input

        # Set the parent message of the display hook and out streams.
        shell.set_parent(parent)

        # Re-broadcast our input for the benefit of listening clients, and
        # start computing output
        if not silent:
            self._publish_pyin(code, parent, shell.execution_count)

        reply_content = {}
        try:
            # FIXME: the shell calls the exception handler itself.
            shell.run_cell(code, store_history=store_history, silent=silent)
        except:
            status = u'error'
            # FIXME: this code right now isn't being used yet by default,
            # because the run_cell() call above directly fires off exception
            # reporting.  This code, therefore, is only active in the scenario
            # where runlines itself has an unhandled exception.  We need to
            # uniformize this, for all exception construction to come from a
            # single location in the codebase.
            etype, evalue, tb = sys.exc_info()
            tb_list = traceback.format_exception(etype, evalue, tb)
            reply_content.update(shell._showtraceback(etype, evalue, tb_list))
        else:
            status = u'ok'
        finally:
            # Restore raw_input.
            if py3compat.PY3:
                builtin_mod.input = self._sys_raw_input
            else:
                builtin_mod.raw_input = self._sys_raw_input
                builtin_mod.input = self._sys_eval_input

        reply_content[u'status'] = status

        # Return the execution counter so clients can display prompts
        reply_content['execution_count'] = shell.execution_count - 1

        # FIXME - fish exception info out of shell, possibly left there by
        # runlines.  We'll need to clean up this logic later.
        if shell._reply_content is not None:
            reply_content.update(shell._reply_content)
            e_info = dict(engine_uuid=self.ident,
                          engine_id=self.int_id,
                          method='execute')
            reply_content['engine_info'] = e_info
            # reset after use
            shell._reply_content = None

        if 'traceback' in reply_content:
            self.log.info("Exception in execute request:\n%s",
                          '\n'.join(reply_content['traceback']))

        # At this point, we can tell whether the main code execution succeeded
        # or not.  If it did, we proceed to evaluate user_variables/expressions
        if reply_content['status'] == 'ok':
            reply_content[u'user_variables'] = \
                         shell.user_variables(content.get(u'user_variables', []))
            reply_content[u'user_expressions'] = \
                         shell.user_expressions(content.get(u'user_expressions', {}))
        else:
            # If there was an error, don't even try to compute variables or
            # expressions
            reply_content[u'user_variables'] = {}
            reply_content[u'user_expressions'] = {}

        # Payloads should be retrieved regardless of outcome, so we can both
        # recover partial output (that could have been generated early in a
        # block, before an error) and clear the payload system always.
        reply_content[u'payload'] = shell.payload_manager.read_payload()
        # Be aggressive about clearing the payload because we don't want
        # it to sit in memory until the next execute_request comes in.
        shell.payload_manager.clear_payload()

        # Flush output before sending the reply.
        sys.stdout.flush()
        sys.stderr.flush()
        # FIXME: on rare occasions, the flush doesn't seem to make it to the
        # clients... This seems to mitigate the problem, but we definitely need
        # to better understand what's going on.
        if self._execute_sleep:
            time.sleep(self._execute_sleep)

        # Send the reply.
        reply_content = json_clean(reply_content)

        md['status'] = reply_content['status']
        if reply_content['status'] == 'error' and \
                        reply_content['ename'] == 'UnmetDependency':
            md['dependencies_met'] = False

        reply_msg = self.session.send(stream,
                                      u'execute_reply',
                                      reply_content,
                                      parent,
                                      metadata=md,
                                      ident=ident)

        self.log.debug("%s", reply_msg)

        if not silent and reply_msg['content']['status'] == u'error':
            self._abort_queues()

        self._publish_status(u'idle', parent)

    def complete_request(self, stream, ident, parent):
        txt, matches = self._complete(parent)
        matches = {'matches': matches, 'matched_text': txt, 'status': 'ok'}
        matches = json_clean(matches)
        completion_msg = self.session.send(stream, 'complete_reply', matches,
                                           parent, ident)
        self.log.debug("%s", completion_msg)

    def object_info_request(self, stream, ident, parent):
        content = parent['content']
        object_info = self.shell.object_inspect(content['oname'],
                                                detail_level=content.get(
                                                    'detail_level', 0))
        # Before we send this object over, we scrub it for JSON usage
        oinfo = json_clean(object_info)
        msg = self.session.send(stream, 'object_info_reply', oinfo, parent,
                                ident)
        self.log.debug("%s", msg)

    def history_request(self, stream, ident, parent):
        # We need to pull these out, as passing **kwargs doesn't work with
        # unicode keys before Python 2.6.5.
        hist_access_type = parent['content']['hist_access_type']
        raw = parent['content']['raw']
        output = parent['content']['output']
        if hist_access_type == 'tail':
            n = parent['content']['n']
            hist = self.shell.history_manager.get_tail(n,
                                                       raw=raw,
                                                       output=output,
                                                       include_latest=True)

        elif hist_access_type == 'range':
            session = parent['content']['session']
            start = parent['content']['start']
            stop = parent['content']['stop']
            hist = self.shell.history_manager.get_range(session,
                                                        start,
                                                        stop,
                                                        raw=raw,
                                                        output=output)

        elif hist_access_type == 'search':
            n = parent['content'].get('n')
            unique = parent['content'].get('unique', False)
            pattern = parent['content']['pattern']
            hist = self.shell.history_manager.search(pattern,
                                                     raw=raw,
                                                     output=output,
                                                     n=n,
                                                     unique=unique)

        else:
            hist = []
        hist = list(hist)
        content = {'history': hist}
        content = json_clean(content)
        msg = self.session.send(stream, 'history_reply', content, parent,
                                ident)
        self.log.debug("Sending history reply with %i entries", len(hist))

    def connect_request(self, stream, ident, parent):
        if self._recorded_ports is not None:
            content = self._recorded_ports.copy()
        else:
            content = {}
        msg = self.session.send(stream, 'connect_reply', content, parent,
                                ident)
        self.log.debug("%s", msg)

    def kernel_info_request(self, stream, ident, parent):
        vinfo = {
            'protocol_version': protocol_version,
            'ipython_version': ipython_version,
            'language_version': language_version,
            'language': 'python',
        }
        msg = self.session.send(stream, 'kernel_info_reply', vinfo, parent,
                                ident)
        self.log.debug("%s", msg)

    def shutdown_request(self, stream, ident, parent):
        self.shell.exit_now = True
        content = dict(status='ok')
        content.update(parent['content'])
        self.session.send(stream,
                          u'shutdown_reply',
                          content,
                          parent,
                          ident=ident)
        # same content, but different msg_id for broadcasting on IOPub
        self._shutdown_message = self.session.msg(u'shutdown_reply', content,
                                                  parent)

        self._at_shutdown()
        # call sys.exit after a short delay
        loop = ioloop.IOLoop.instance()
        loop.add_timeout(time.time() + 0.1, loop.stop)

    #---------------------------------------------------------------------------
    # Engine methods
    #---------------------------------------------------------------------------

    def apply_request(self, stream, ident, parent):
        try:
            content = parent[u'content']
            bufs = parent[u'buffers']
            msg_id = parent['header']['msg_id']
        except:
            self.log.error("Got bad msg: %s", parent, exc_info=True)
            return

        self._publish_status(u'busy', parent)

        # Set the parent message of the display hook and out streams.
        shell = self.shell
        shell.set_parent(parent)

        # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
        # self.iopub_socket.send(pyin_msg)
        # self.session.send(self.iopub_socket, u'pyin', {u'code':code},parent=parent)
        md = self._make_metadata(parent['metadata'])
        try:
            working = shell.user_ns

            prefix = "_" + str(msg_id).replace("-", "") + "_"

            f, args, kwargs = unpack_apply_message(bufs, working, copy=False)

            # Bind f/args/kwargs into the working namespace under
            # collision-safe names built from the msg_id prefix (the
            # function's own __name__ is not used).
            fname = prefix + "f"
            argname = prefix + "args"
            kwargname = prefix + "kwargs"
            resultname = prefix + "result"

            ns = {fname: f, argname: args, kwargname: kwargs, resultname: None}
            # print ns
            working.update(ns)
            code = "%s = %s(*%s,**%s)" % (resultname, fname, argname,
                                          kwargname)
            try:
                exec(code, shell.user_global_ns, shell.user_ns)
                result = working.get(resultname)
            finally:
                for key in ns:
                    working.pop(key)

            result_buf = serialize_object(
                result,
                buffer_threshold=self.session.buffer_threshold,
                item_threshold=self.session.item_threshold,
            )

        except:
            # invoke IPython traceback formatting
            shell.showtraceback()
            # FIXME - fish exception info out of shell, possibly left there by
            # run_code.  We'll need to clean up this logic later.
            reply_content = {}
            if shell._reply_content is not None:
                reply_content.update(shell._reply_content)
                e_info = dict(engine_uuid=self.ident,
                              engine_id=self.int_id,
                              method='apply')
                reply_content['engine_info'] = e_info
                # reset after use
                shell._reply_content = None

            self.session.send(self.iopub_socket,
                              u'pyerr',
                              reply_content,
                              parent=parent,
                              ident=self._topic('pyerr'))
            self.log.info("Exception in apply request:\n%s",
                          '\n'.join(reply_content['traceback']))
            result_buf = []

            if reply_content['ename'] == 'UnmetDependency':
                md['dependencies_met'] = False
        else:
            reply_content = {'status': 'ok'}

        # put 'ok'/'error' status in metadata, for scheduler introspection:
        md['status'] = reply_content['status']

        # flush i/o
        sys.stdout.flush()
        sys.stderr.flush()

        reply_msg = self.session.send(stream,
                                      u'apply_reply',
                                      reply_content,
                                      parent=parent,
                                      ident=ident,
                                      buffers=result_buf,
                                      metadata=md)

        self._publish_status(u'idle', parent)

    #---------------------------------------------------------------------------
    # Control messages
    #---------------------------------------------------------------------------

    def abort_request(self, stream, ident, parent):
        """abort a specifig msg by id"""
        msg_ids = parent['content'].get('msg_ids', None)
        if isinstance(msg_ids, string_types):
            msg_ids = [msg_ids]
        if not msg_ids:
            self._abort_queues()
        for mid in msg_ids or []:
            self.aborted.add(str(mid))

        content = dict(status='ok')
        reply_msg = self.session.send(stream,
                                      'abort_reply',
                                      content=content,
                                      parent=parent,
                                      ident=ident)
        self.log.debug("%s", reply_msg)

    def clear_request(self, stream, idents, parent):
        """Clear our namespace."""
        self.shell.reset(False)
        msg = self.session.send(stream,
                                'clear_reply',
                                ident=idents,
                                parent=parent,
                                content=dict(status='ok'))

    #---------------------------------------------------------------------------
    # Protected interface
    #---------------------------------------------------------------------------

    def _wrap_exception(self, method=None):
        # import here, because _wrap_exception is only used in parallel,
        # and parallel has higher min pyzmq version
        from IPython.parallel.error import wrap_exception
        e_info = dict(engine_uuid=self.ident,
                      engine_id=self.int_id,
                      method=method)
        content = wrap_exception(e_info)
        return content

    def _topic(self, topic):
        """prefixed topic for IOPub messages"""
        if self.int_id >= 0:
            base = "engine.%i" % self.int_id
        else:
            base = "kernel.%s" % self.ident

        return py3compat.cast_bytes("%s.%s" % (base, topic))
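        # e.g. b"engine.3.status" for engine 3, or b"kernel.<uuid>.status"
        # for a standalone kernel.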

    def _abort_queues(self):
        for stream in self.shell_streams:
            if stream:
                self._abort_queue(stream)

    def _abort_queue(self, stream):
        poller = zmq.Poller()
        poller.register(stream.socket, zmq.POLLIN)
        while True:
            idents, msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
            if msg is None:
                return

            self.log.info("Aborting:")
            self.log.info("%s", msg)
            msg_type = msg['header']['msg_type']
            reply_type = msg_type.split('_')[0] + '_reply'

            status = {'status': 'aborted'}
            md = {'engine': self.ident}
            md.update(status)
            reply_msg = self.session.send(stream,
                                          reply_type,
                                          metadata=md,
                                          content=status,
                                          parent=msg,
                                          ident=idents)
            self.log.debug("%s", reply_msg)
            # We need to wait a bit for requests to come in. This can probably
            # be set shorter for true asynchronous clients.
            poller.poll(50)

    def _no_raw_input(self):
        """Raise StdinNotImplentedError if active frontend doesn't support
        stdin."""
        raise StdinNotImplementedError("raw_input was called, but this "
                                       "frontend does not support stdin.")

    def _raw_input(self, prompt, ident, parent):
        # Flush output before making the request.
        sys.stderr.flush()
        sys.stdout.flush()
        # flush the stdin socket, to purge stale replies
        while True:
            try:
                self.stdin_socket.recv_multipart(zmq.NOBLOCK)
            except zmq.ZMQError as e:
                if e.errno == zmq.EAGAIN:
                    break
                else:
                    raise

        # Send the input request.
        content = json_clean(dict(prompt=prompt))
        self.session.send(self.stdin_socket,
                          u'input_request',
                          content,
                          parent,
                          ident=ident)

        # Await a response.
        while True:
            try:
                ident, reply = self.session.recv(self.stdin_socket, 0)
            except Exception:
                self.log.warn("Invalid Message:", exc_info=True)
            except KeyboardInterrupt:
                # re-raise KeyboardInterrupt, to truncate traceback
                raise KeyboardInterrupt
            else:
                break
        try:
            value = py3compat.unicode_to_str(reply['content']['value'])
        except:
            self.log.error("Got bad raw_input reply: ")
            self.log.error("%s", parent)
            value = ''
        if value == '\x04':
            # EOF
            raise EOFError
        return value

    def _complete(self, msg):
        c = msg['content']
        try:
            cpos = int(c['cursor_pos'])
        except:
            # If we don't get something that we can convert to an integer, at
            # least attempt the completion, guessing that the cursor is at the
            # end of the text if there is any, and otherwise at the end of the
            # line.
            cpos = len(c['text'])
            if cpos == 0:
                cpos = len(c['line'])
        return self.shell.complete(c['text'], c['line'], cpos)

    def _at_shutdown(self):
        """Actions taken at shutdown by the kernel, called by python's atexit.
        """
        # io.rprint("Kernel at_shutdown") # dbg
        if self._shutdown_message is not None:
            self.session.send(self.iopub_socket,
                              self._shutdown_message,
                              ident=self._topic('shutdown'))
            self.log.debug("%s", self._shutdown_message)
        for s in self.shell_streams:
            s.flush(zmq.POLLOUT)
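
# Rough shape of the message flow implemented above (a sketch, not an API):
#   ZMQStream.on_recv -> dispatch_shell(stream, msg)
#     -> session.feed_identities / session.unserialize
#     -> shell_handlers[msg_type](stream, idents, msg)
# Handlers like execute_request publish 'busy' on IOPub, do their work,
# send a *_reply on the shell stream, and finally publish 'idle'.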
Example #27
class PlainTextFormatter(BaseFormatter):
    """The default pretty-printer.

    This uses :mod:`IPython.lib.pretty` to compute the format data of
    the object. If the object cannot be pretty printed, :func:`repr` is used.
    See the documentation of :mod:`IPython.lib.pretty` for details on
    how to write pretty printers.  Here is a simple example::

        def dtype_pprinter(obj, p, cycle):
            if cycle:
                return p.text('dtype(...)')
            if hasattr(obj, 'fields'):
                if obj.fields is None:
                    p.text(repr(obj))
                else:
                    p.begin_group(7, 'dtype([')
                    for i, field in enumerate(obj.descr):
                        if i > 0:
                            p.text(',')
                            p.breakable()
                        p.pretty(field)
                    p.end_group(7, '])')
    """

    # The format type of data returned.
    format_type = Unicode('text/plain')

    # This subclass ignores this attribute as it always needs to return
    # something.
    enabled = Bool(True, config=False)

    # Look for a _repr_pretty_ method to use for pretty printing.
    print_method = ObjectName('_repr_pretty_')

    # Whether to pretty-print or not.
    pprint = Bool(True, config=True)

    # Whether to be verbose or not.
    verbose = Bool(False, config=True)

    # The maximum width.
    max_width = Integer(79, config=True)

    # The newline character.
    newline = Unicode('\n', config=True)

    # format-string for pprinting floats
    float_format = Unicode('%r')
    # setter for float precision, either int or direct format-string
    float_precision = CUnicode('', config=True)

    def _float_precision_changed(self, name, old, new):
        """float_precision changed, set float_format accordingly.

        float_precision can be set by int or str.
        This will set float_format, after interpreting input.
        If numpy has been imported, numpy print precision will also be set.

        An integer `n` sets the format to '%.nf'; otherwise, the format
        string is set directly.

        An empty string resets to the defaults (repr for float, 8 for numpy).

        This parameter can be set via the '%precision' magic.
        """

        if '%' in new:
            # got explicit format string
            fmt = new
            try:
                fmt % 3.14159
            except Exception:
                raise ValueError(
                    "Precision must be int or format string, not %r" % new)
        elif new:
            # otherwise, should be an int
            try:
                i = int(new)
            except ValueError:
                raise ValueError(
                    "Precision must be int or format string, not %r" % new)
            if i < 0:
                # an assert here would be stripped under `python -O`,
                # so check explicitly
                raise ValueError(
                    "int precision must be non-negative, not %r" % i)

            fmt = '%%.%if' % i
            if 'numpy' in sys.modules:
                # set numpy precision if it has been imported
                import numpy
                numpy.set_printoptions(precision=i)
        else:
            # default back to repr
            fmt = '%r'
            if 'numpy' in sys.modules:
                import numpy
                # numpy default is 8
                numpy.set_printoptions(precision=8)
        self.float_format = fmt
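
    # Illustrative mapping (per the docstring above):
    #   float_precision = '3'    -> float_format = '%.3f'
    #   float_precision = '%.4g' -> float_format = '%.4g'
    #   float_precision = ''     -> float_format = '%r'  (plain repr)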

    # Use the default pretty printers from IPython.lib.pretty.
    def _singleton_printers_default(self):
        return pretty._singleton_pprinters.copy()

    def _type_printers_default(self):
        d = pretty._type_pprinters.copy()
        d[float] = lambda obj, p, cycle: p.text(self.float_format % obj)
        return d

    def _deferred_printers_default(self):
        return pretty._deferred_type_pprinters.copy()

    #### FormatterABC interface ####

    @warn_format_error
    def __call__(self, obj):
        """Compute the pretty representation of the object."""
        if not self.pprint:
            return pretty._safe_repr(obj)
        else:
            # This uses StringIO, as cStringIO doesn't handle unicode.
            stream = StringIO()
            # self.newline.encode() is a quick fix for issue gh-597. We need to
            # ensure that stream does not get a mix of unicode and bytestrings,
            # or it will cause trouble.
            printer = pretty.RepresentationPrinter(
                stream,
                self.verbose,
                self.max_width,
                unicode_to_str(self.newline),
                singleton_pprinters=self.singleton_printers,
                type_pprinters=self.type_printers,
                deferred_pprinters=self.deferred_printers)
            printer.pretty(obj)
            printer.flush()
            return stream.getvalue()
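
# A brief usage sketch (hedged: for_type comes from the BaseFormatter API,
# which is not shown in this snippet):
#
#     f = PlainTextFormatter()
#     f.for_type(float, lambda obj, p, cycle: p.text('%.2f' % obj))
#     f(3.14159)   # -> '3.14'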
Example #28
class ConnectionFileMixin(LoggingConfigurable):
    """Mixin for configurable classes that work with connection files"""

    # The addresses for the communication channels
    connection_file = Unicode(
        '',
        config=True,
        help=
        """JSON file in which to store connection info [default: kernel-<pid>.json]

    This file will contain the IP, ports, and authentication key needed to connect
    clients to this kernel. By default, this file will be created in the security dir
    of the current profile, but can be specified by absolute path.
    """)
    _connection_file_written = Bool(False)

    transport = CaselessStrEnum(['tcp', 'ipc'],
                                default_value='tcp',
                                config=True)

    ip = Unicode(config=True,
                 help="""Set the kernel\'s IP address [default localhost].
        If the IP address is something other than localhost, then
        Consoles on other machines will be able to connect
        to the Kernel, so be careful!""")

    def _ip_default(self):
        if self.transport == 'ipc':
            if self.connection_file:
                return os.path.splitext(self.connection_file)[0] + '-ipc'
            else:
                return 'kernel-ipc'
        else:
            return localhost()

    def _ip_changed(self, name, old, new):
        if new == '*':
            self.ip = '0.0.0.0'

    # protected traits

    hb_port = Integer(0,
                      config=True,
                      help="set the heartbeat port [default: random]")
    shell_port = Integer(0,
                         config=True,
                         help="set the shell (ROUTER) port [default: random]")
    iopub_port = Integer(0,
                         config=True,
                         help="set the iopub (PUB) port [default: random]")
    stdin_port = Integer(0,
                         config=True,
                         help="set the stdin (ROUTER) port [default: random]")
    control_port = Integer(
        0, config=True, help="set the control (ROUTER) port [default: random]")

    @property
    def ports(self):
        return [getattr(self, name) for name in port_names]

    # The Session to use for communication with the kernel.
    session = Instance('jupyter_client.session.Session')

    def _session_default(self):
        from jupyter_client.session import Session
        return Session(parent=self)

    #--------------------------------------------------------------------------
    # Connection and ipc file management
    #--------------------------------------------------------------------------

    def get_connection_info(self):
        """return the connection info as a dict"""
        return dict(
            transport=self.transport,
            ip=self.ip,
            shell_port=self.shell_port,
            iopub_port=self.iopub_port,
            stdin_port=self.stdin_port,
            hb_port=self.hb_port,
            control_port=self.control_port,
            signature_scheme=self.session.signature_scheme,
            key=self.session.key,
        )
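
    # For illustration, the returned dict (and the JSON connection file built
    # from it) looks roughly like this (values here are made up):
    #   {"transport": "tcp", "ip": "127.0.0.1", "shell_port": 53794,
    #    "iopub_port": 53795, "stdin_port": 53796, "hb_port": 53797,
    #    "control_port": 53798, "signature_scheme": "hmac-sha256",
    #    "key": "a0436f6c-..."}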

    def cleanup_connection_file(self):
        """Cleanup connection file *if we wrote it*

        Will not raise if the connection file was already removed somehow.
        """
        if self._connection_file_written:
            # cleanup connection files on full shutdown of kernel we started
            self._connection_file_written = False
            try:
                os.remove(self.connection_file)
            except (IOError, OSError, AttributeError):
                pass

    def cleanup_ipc_files(self):
        """Cleanup ipc files if we wrote them."""
        if self.transport != 'ipc':
            return
        for port in self.ports:
            ipcfile = "%s-%i" % (self.ip, port)
            try:
                os.remove(ipcfile)
            except (IOError, OSError):
                pass

    def write_connection_file(self):
        """Write connection info to JSON dict in self.connection_file."""
        if self._connection_file_written and os.path.exists(
                self.connection_file):
            return

        self.connection_file, cfg = write_connection_file(
            self.connection_file,
            transport=self.transport,
            ip=self.ip,
            key=self.session.key,
            stdin_port=self.stdin_port,
            iopub_port=self.iopub_port,
            shell_port=self.shell_port,
            hb_port=self.hb_port,
            control_port=self.control_port,
            signature_scheme=self.session.signature_scheme,
        )
        # write_connection_file also sets default ports:
        for name in port_names:
            setattr(self, name, cfg[name])

        self._connection_file_written = True

    def load_connection_file(self):
        """Load connection info from JSON dict in self.connection_file."""
        self.log.debug(u"Loading connection file %s", self.connection_file)
        with open(self.connection_file) as f:
            cfg = json.load(f)
        self.transport = cfg.get('transport', self.transport)
        self.ip = cfg.get('ip', self._ip_default())

        for name in port_names:
            if getattr(self, name) == 0 and name in cfg:
                # not overridden by config or cl_args
                setattr(self, name, cfg[name])

        if 'key' in cfg:
            self.session.key = str_to_bytes(cfg['key'])
        if 'signature_scheme' in cfg:
            self.session.signature_scheme = cfg['signature_scheme']

    #--------------------------------------------------------------------------
    # Creating connected sockets
    #--------------------------------------------------------------------------

    def _make_url(self, channel):
        """Make a ZeroMQ URL for a given channel."""
        transport = self.transport
        ip = self.ip
        port = getattr(self, '%s_port' % channel)

        if transport == 'tcp':
            return "tcp://%s:%i" % (ip, port)
        else:
            return "%s://%s-%s" % (transport, ip, port)

    def _create_connected_socket(self, channel, identity=None):
        """Create a zmq Socket and connect it to the kernel."""
        url = self._make_url(channel)
        socket_type = channel_socket_types[channel]
        self.log.debug("Connecting to: %s" % url)
        sock = self.context.socket(socket_type)
        # set linger to 1s to prevent hangs at exit
        sock.linger = 1000
        if identity:
            sock.identity = identity
        sock.connect(url)
        return sock

    def connect_iopub(self, identity=None):
        """return zmq Socket connected to the IOPub channel"""
        sock = self._create_connected_socket('iopub', identity=identity)
        sock.setsockopt(zmq.SUBSCRIBE, b'')
        return sock

    def connect_shell(self, identity=None):
        """return zmq Socket connected to the Shell channel"""
        return self._create_connected_socket('shell', identity=identity)

    def connect_stdin(self, identity=None):
        """return zmq Socket connected to the StdIn channel"""
        return self._create_connected_socket('stdin', identity=identity)

    def connect_hb(self, identity=None):
        """return zmq Socket connected to the Heartbeat channel"""
        return self._create_connected_socket('hb', identity=identity)

    def connect_control(self, identity=None):
        """return zmq Socket connected to the Control channel"""
        return self._create_connected_socket('control', identity=identity)
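
# Typical lifecycle sketch (illustrative; classes such as KernelManager mix
# this class in, and the names below are assumptions):
#
#     km.write_connection_file()        # creates kernel-<pid>.json, fills ports
#     info = km.get_connection_info()   # dict with transport/ip/ports/key
#     shell = km.connect_shell()        # zmq socket on the shell channel
#     km.cleanup_connection_file()      # removed only if we wrote it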
Example #29
class Foo(Configurable):
    a = Integer(0, config=True, help="The integer a.")
    b = Unicode('nope', config=True)
Example #30
class Foo(Configurable):

    i = Integer(0, config=True, help="The integer i.")
    j = Integer(1, config=True, help="The integer j.")
    name = Unicode(u'Brian', config=True, help="First name.")
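
# A minimal usage sketch for these Configurable examples (an assumption:
# the standard traitlets-style Config mechanism; the import path varies
# by version):
if __name__ == '__main__':
    try:
        from traitlets.config import Config       # newer releases
    except ImportError:
        from IPython.config.loader import Config  # older IPython
    c = Config()
    c.Foo.i = 10          # any trait declared with config=True is settable
    foo = Foo(config=c)
    print(foo.i, foo.j, foo.name)                 # -> 10 1 Brian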