Example #1
class InteractiveShellApp(Configurable):
    """A Mixin for applications that start InteractiveShell instances.
    
    Provides configurables for loading extensions and executing files
    as part of configuring a Shell environment.

    The following methods should be called by the :meth:`initialize` method
    of the subclass:

      - :meth:`init_path`
      - :meth:`init_shell` (to be implemented by the subclass)
      - :meth:`init_gui_pylab`
      - :meth:`init_extensions`
      - :meth:`init_code`
    """
    extensions = List(
        Unicode(),
        help="A list of dotted module names of IPython extensions to load."
    ).tag(config=True)
    extra_extension = Unicode(
        '', help="dotted module name of an IPython extension to load.").tag(
            config=True)

    reraise_ipython_extension_failures = Bool(
        False,
        help="Reraise exceptions encountered loading IPython extensions?",
    ).tag(config=True)

    # Extensions that are always loaded (not configurable)
    default_extensions = List(Unicode(), [u'storemagic']).tag(config=False)

    hide_initial_ns = Bool(
        True,
        help=
        """Should variables loaded at startup (by startup files, exec_lines, etc.)
        be hidden from tools like %who?""").tag(config=True)

    exec_files = List(
        Unicode(),
        help="""List of files to run at IPython startup.""").tag(config=True)
    exec_PYTHONSTARTUP = Bool(
        True,
        help="""Run the file referenced by the PYTHONSTARTUP environment
        variable at IPython startup.""").tag(config=True)
    file_to_run = Unicode('', help="""A file to be run""").tag(config=True)

    exec_lines = List(
        Unicode(),
        help="""lines of code to run at IPython startup.""").tag(config=True)
    code_to_run = Unicode(
        '', help="Execute the given command string.").tag(config=True)
    module_to_run = Unicode(
        '', help="Run the module as a script.").tag(config=True)
    gui = CaselessStrEnum(
        gui_keys,
        allow_none=True,
        help="Enable GUI event loop integration with any of {0}.".format(
            gui_keys)).tag(config=True)
    matplotlib = CaselessStrEnum(
        backend_keys,
        allow_none=True,
        help="""Configure matplotlib for interactive use with
        the default matplotlib backend.""").tag(config=True)
    pylab = CaselessStrEnum(
        backend_keys,
        allow_none=True,
        help="""Pre-load matplotlib and numpy for interactive use,
        selecting a particular matplotlib backend and loop integration.
        """).tag(config=True)
    pylab_import_all = Bool(
        True,
        help=
        """If true, IPython will populate the user namespace with numpy, pylab, etc.
        and an ``import *`` is done from numpy and pylab, when using pylab mode.
        
        When False, pylab mode should not import any names into the user namespace.
        """).tag(config=True)
    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC',
                     allow_none=True)
    # whether interact-loop should start
    interact = Bool(True)

    user_ns = Instance(dict, args=None, allow_none=True)

    @observe('user_ns')
    def _user_ns_changed(self, change):
        if self.shell is not None:
            self.shell.user_ns = change['new']
            self.shell.init_user_ns()

    def init_path(self):
        """Add current working directory, '', to sys.path"""
        if sys.path[0] != '':
            sys.path.insert(0, '')

    def init_shell(self):
        raise NotImplementedError("Override in subclasses")

    def init_gui_pylab(self):
        """Enable GUI event loop integration, taking pylab into account."""
        enable = False
        shell = self.shell
        if self.pylab:
            enable = lambda key: shell.enable_pylab(
                key, import_all=self.pylab_import_all)
            key = self.pylab
        elif self.matplotlib:
            enable = shell.enable_matplotlib
            key = self.matplotlib
        elif self.gui:
            enable = shell.enable_gui
            key = self.gui

        if not enable:
            return

        try:
            r = enable(key)
        except ImportError:
            self.log.warning(
                "Eventloop or matplotlib integration failed. Is matplotlib installed?"
            )
            self.shell.showtraceback()
            return
        except Exception:
            self.log.warning("GUI event loop or pylab initialization failed")
            self.shell.showtraceback()
            return

        if isinstance(r, tuple):
            gui, backend = r[:2]
            self.log.info(
                "Enabling GUI event loop integration, "
                "eventloop=%s, matplotlib=%s", gui, backend)
            if key == "auto":
                print("Using matplotlib backend: %s" % backend)
        else:
            gui = r
            self.log.info(
                "Enabling GUI event loop integration, "
                "eventloop=%s", gui)

    def init_extensions(self):
        """Load all IPython extensions in IPythonApp.extensions.

        This uses :meth:`ExtensionManager.load_extension` to load each of
        the extensions listed in ``self.extensions``.
        """
        try:
            self.log.debug("Loading IPython extensions...")
            extensions = self.default_extensions + self.extensions
            if self.extra_extension:
                extensions.append(self.extra_extension)
            for ext in extensions:
                try:
                    self.log.info("Loading IPython extension: %s" % ext)
                    self.shell.extension_manager.load_extension(ext)
                except:
                    if self.reraise_ipython_extension_failures:
                        raise
                    msg = ("Error in loading extension: {ext}\n"
                           "Check your config files in {location}".format(
                               ext=ext, location=self.profile_dir.location))
                    self.log.warning(msg, exc_info=True)
        except:
            if self.reraise_ipython_extension_failures:
                raise
            self.log.warning("Unknown error in loading extensions:",
                             exc_info=True)

    def init_code(self):
        """run the pre-flight code, specified via exec_lines"""
        self._run_startup_files()
        self._run_exec_lines()
        self._run_exec_files()

        # Hide variables defined here from %who etc.
        if self.hide_initial_ns:
            self.shell.user_ns_hidden.update(self.shell.user_ns)

        # command-line execution (ipython -i script.py, ipython -m module)
        # should *not* be excluded from %whos
        self._run_cmd_line_code()
        self._run_module()

        # flush output, so it won't be attached to the first cell
        sys.stdout.flush()
        sys.stderr.flush()

    def _run_exec_lines(self):
        """Run lines of code in IPythonApp.exec_lines in the user's namespace."""
        if not self.exec_lines:
            return
        try:
            self.log.debug("Running code from IPythonApp.exec_lines...")
            for line in self.exec_lines:
                try:
                    self.log.info("Running code in user namespace: %s" % line)
                    self.shell.run_cell(line, store_history=False)
                except:
                    self.log.warning("Error in executing line in user "
                                     "namespace: %s" % line)
                    self.shell.showtraceback()
        except:
            self.log.warning(
                "Unknown error in handling IPythonApp.exec_lines:")
            self.shell.showtraceback()

    def _exec_file(self, fname, shell_futures=False):
        try:
            full_filename = filefind(fname, [u'.', self.ipython_dir])
        except IOError:
            self.log.warning("File not found: %r" % fname)
            return
        # Make sure that the running script gets a proper sys.argv as if it
        # were run from a system shell.
        save_argv = sys.argv
        sys.argv = [full_filename] + self.extra_args[1:]
        # protect sys.argv from potential unicode strings on Python 2:
        if not py3compat.PY3:
            sys.argv = [py3compat.cast_bytes(a) for a in sys.argv]
        try:
            if os.path.isfile(full_filename):
                self.log.info("Running file in user namespace: %s" %
                              full_filename)
                # Ensure that __file__ is always defined to match Python
                # behavior.
                with preserve_keys(self.shell.user_ns, '__file__'):
                    self.shell.user_ns['__file__'] = fname
                    if full_filename.endswith('.ipy'):
                        self.shell.safe_execfile_ipy(
                            full_filename, shell_futures=shell_futures)
                    else:
                        # default to python, even without extension
                        self.shell.safe_execfile(full_filename,
                                                 self.shell.user_ns,
                                                 shell_futures=shell_futures,
                                                 raise_exceptions=True)
        finally:
            sys.argv = save_argv

    def _run_startup_files(self):
        """Run files from profile startup directory"""
        startup_dir = self.profile_dir.startup_dir
        startup_files = []

        if self.exec_PYTHONSTARTUP and os.environ.get('PYTHONSTARTUP', False) and \
                not (self.file_to_run or self.code_to_run or self.module_to_run):
            python_startup = os.environ['PYTHONSTARTUP']
            self.log.debug("Running PYTHONSTARTUP file %s...", python_startup)
            try:
                self._exec_file(python_startup)
            except:
                self.log.warning(
                    "Unknown error in handling PYTHONSTARTUP file %s:",
                    python_startup)
                self.shell.showtraceback()

        startup_files += glob.glob(os.path.join(startup_dir, '*.py'))
        startup_files += glob.glob(os.path.join(startup_dir, '*.ipy'))
        if not startup_files:
            return

        self.log.debug("Running startup files from %s...", startup_dir)
        try:
            for fname in sorted(startup_files):
                self._exec_file(fname)
        except:
            self.log.warning("Unknown error in handling startup files:")
            self.shell.showtraceback()

    def _run_exec_files(self):
        """Run files from IPythonApp.exec_files"""
        if not self.exec_files:
            return

        self.log.debug("Running files in IPythonApp.exec_files...")
        try:
            for fname in self.exec_files:
                self._exec_file(fname)
        except:
            self.log.warning(
                "Unknown error in handling IPythonApp.exec_files:")
            self.shell.showtraceback()

    def _run_cmd_line_code(self):
        """Run code or file specified at the command-line"""
        if self.code_to_run:
            line = self.code_to_run
            try:
                self.log.info("Running code given at command line (c=): %s" %
                              line)
                self.shell.run_cell(line, store_history=False)
            except:
                self.log.warning(
                    "Error in executing line in user namespace: %s" % line)
                self.shell.showtraceback()
                if not self.interact:
                    self.exit(1)

        # Like Python itself, ignore the second if the first of these is present
        elif self.file_to_run:
            fname = self.file_to_run
            try:
                self._exec_file(fname, shell_futures=True)
            except:
                self.shell.showtraceback(tb_offset=4)
                if not self.interact:
                    self.exit(1)

    def _run_module(self):
        """Run module specified at the command-line."""
        if self.module_to_run:
            # Make sure that the module gets a proper sys.argv as if it were
            # run using `python -m`.
            save_argv = sys.argv
            sys.argv = [sys.executable] + self.extra_args
            try:
                self.shell.safe_run_module(self.module_to_run,
                                           self.shell.user_ns)
            finally:
                sys.argv = save_argv
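
A minimal sketch of how a subclass might wire these hooks together, following
the order given in the class docstring (the MyShellApp name and the
TerminalInteractiveShell import are illustrative assumptions, not part of the
source above):

class MyShellApp(InteractiveShellApp):
    def init_shell(self):
        # Hypothetical: create whichever InteractiveShell subclass the app uses.
        from IPython.terminal.interactiveshell import TerminalInteractiveShell
        self.shell = TerminalInteractiveShell.instance(parent=self)

    def initialize(self):
        # Call the mixin hooks in the documented order.
        self.init_path()
        self.init_shell()
        self.init_gui_pylab()
        self.init_extensions()
        self.init_code()
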
Example #2
class HistoryManager(HistoryAccessor):
    """A class to organize all history-related functionality in one place.
    """
    # Public interface

    # An instance of the IPython shell we are attached to
    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC',
                     allow_none=True)
    # Lists to hold processed and raw history. These start with a blank entry
    # so that we can index them starting from 1
    input_hist_parsed = List([""])
    input_hist_raw = List([""])
    # A list of directories visited during session
    dir_hist = List()

    def _dir_hist_default(self):
        try:
            return [py3compat.getcwd()]
        except OSError:
            return []

    # A dict of output history, keyed with ints from the shell's
    # execution count.
    output_hist = Dict()
    # The text/plain repr of outputs.
    output_hist_reprs = Dict()

    # The number of the current session in the history database
    session_number = Integer()

    db_log_output = Bool(
        False,
        config=True,
        help="Should the history database include output? (default: no)")
    db_cache_size = Integer(
        0,
        config=True,
        help=
        "Write to database every x commands (higher values save disk access & power).\n"
        "Values of 1 or less effectively disable caching.")
    # The input and output caches
    db_input_cache = List()
    db_output_cache = List()

    # History saving in separate thread
    save_thread = Instance('IPython.core.history.HistorySavingThread',
                           allow_none=True)
    try:  # Event is a function returning an instance of _Event...
        save_flag = Instance(threading._Event, allow_none=True)
    except AttributeError:  # ...until Python 3.3, when it's a class.
        save_flag = Instance(threading.Event, allow_none=True)

    # Private interface
    # Variables used to store the three last inputs from the user.  On each new
    # history update, we populate the user's namespace with these, shifted as
    # necessary.
    _i00 = Unicode(u'')
    _i = Unicode(u'')
    _ii = Unicode(u'')
    _iii = Unicode(u'')

    # A regex matching all forms of the exit command, so that we don't store
    # them in the history (it's annoying to rewind the first entry and land on
    # an exit call).
    _exit_re = re.compile(r"(exit|quit)(\s*\(.*\))?$")

    def __init__(self, shell=None, config=None, **traits):
        """Create a new history manager associated with a shell instance.
        """
        # We need a pointer back to the shell for various tasks.
        super(HistoryManager, self).__init__(shell=shell,
                                             config=config,
                                             **traits)
        self.save_flag = threading.Event()
        self.db_input_cache_lock = threading.Lock()
        self.db_output_cache_lock = threading.Lock()

        try:
            self.new_session()
        except OperationalError:
            self.log.error(
                "Failed to create history session in %s. History will not be saved.",
                self.hist_file,
                exc_info=True)
            self.hist_file = ':memory:'

        if self.enabled and self.hist_file != ':memory:':
            self.save_thread = HistorySavingThread(self)
            self.save_thread.start()

    def _get_hist_file_name(self, profile=None):
        """Get default history file name based on the Shell's profile.
        
        The profile parameter is ignored, but must exist for compatibility with
        the parent class."""
        profile_dir = self.shell.profile_dir.location
        return os.path.join(profile_dir, 'history.sqlite')

    @needs_sqlite
    def new_session(self, conn=None):
        """Get a new session number."""
        if conn is None:
            conn = self.db

        with conn:
            cur = conn.execute(
                """INSERT INTO sessions VALUES (NULL, ?, NULL,
                            NULL, "") """, (datetime.datetime.now(), ))
            self.session_number = cur.lastrowid

    def end_session(self):
        """Close the database session, filling in the end time and line count."""
        self.writeout_cache()
        with self.db:
            self.db.execute(
                """UPDATE sessions SET end=?, num_cmds=? WHERE
                            session==?""",
                (datetime.datetime.now(), len(self.input_hist_parsed) - 1,
                 self.session_number))
        self.session_number = 0

    def name_session(self, name):
        """Give the current session a name in the history database."""
        with self.db:
            self.db.execute("UPDATE sessions SET remark=? WHERE session==?",
                            (name, self.session_number))

    def reset(self, new_session=True):
        """Clear the session history, releasing all object references, and
        optionally open a new session."""
        self.output_hist.clear()
        # The directory history can't be completely empty
        self.dir_hist[:] = [py3compat.getcwd()]

        if new_session:
            if self.session_number:
                self.end_session()
            self.input_hist_parsed[:] = [""]
            self.input_hist_raw[:] = [""]
            self.new_session()

    # ------------------------------
    # Methods for retrieving history
    # ------------------------------
    def get_session_info(self, session=0):
        """Get info about a session.

        Parameters
        ----------

        session : int
            Session number to retrieve. The current session is 0, and negative
            numbers count back from current session, so -1 is the previous session.

        Returns
        -------
        
        session_id : int
           Session ID number
        start : datetime
           Timestamp for the start of the session.
        end : datetime
           Timestamp for the end of the session, or None if IPython crashed.
        num_cmds : int
           Number of commands run, or None if IPython crashed.
        remark : unicode
           A manually set description.
        """
        if session <= 0:
            session += self.session_number

        return super(HistoryManager, self).get_session_info(session=session)

    def _get_range_session(self, start=1, stop=None, raw=True, output=False):
        """Get input and output history from the current session. Called by
        get_range, and takes similar parameters."""
        input_hist = self.input_hist_raw if raw else self.input_hist_parsed

        n = len(input_hist)
        if start < 0:
            start += n
        if not stop or (stop > n):
            stop = n
        elif stop < 0:
            stop += n

        for i in range(start, stop):
            if output:
                line = (input_hist[i], self.output_hist_reprs.get(i))
            else:
                line = input_hist[i]
            yield (0, i, line)

    def get_range(self, session=0, start=1, stop=None, raw=True, output=False):
        """Retrieve input by session.
        
        Parameters
        ----------
        session : int
            Session number to retrieve. The current session is 0, and negative
            numbers count back from current session, so -1 is previous session.
        start : int
            First line to retrieve.
        stop : int
            End of line range (excluded from output itself). If None, retrieve
            to the end of the session.
        raw : bool
            If True, return untranslated input
        output : bool
            If True, attempt to include output. This will be 'real' Python
            objects for the current session, or text reprs from previous
            sessions if db_log_output was enabled at the time. Where no output
            is found, None is used.
            
        Returns
        -------
        entries
          An iterator over the desired lines. Each line is a 3-tuple, either
          (session, line, input) if output is False, or
          (session, line, (input, output)) if output is True.
        """
        if session <= 0:
            session += self.session_number
        if session == self.session_number:  # Current session
            return self._get_range_session(start, stop, raw, output)
        return super(HistoryManager, self).get_range(session, start, stop, raw,
                                                     output)

    # ---------------------------
    # Methods for storing history
    # ---------------------------
    def store_inputs(self, line_num, source, source_raw=None):
        """Store source and raw input in history and create input cache
        variables ``_i*``.

        Parameters
        ----------
        line_num : int
          The prompt number of this input.

        source : str
          Python input.

        source_raw : str, optional
          If given, this is the raw input without any IPython transformations
          applied to it.  If not given, ``source`` is used.
        """
        if source_raw is None:
            source_raw = source
        source = source.rstrip('\n')
        source_raw = source_raw.rstrip('\n')

        # do not store exit/quit commands
        if self._exit_re.match(source_raw.strip()):
            return

        self.input_hist_parsed.append(source)
        self.input_hist_raw.append(source_raw)

        with self.db_input_cache_lock:
            self.db_input_cache.append((line_num, source, source_raw))
            # Trigger to flush cache and write to DB.
            if len(self.db_input_cache) >= self.db_cache_size:
                self.save_flag.set()

        # update the auto _i variables
        self._iii = self._ii
        self._ii = self._i
        self._i = self._i00
        self._i00 = source_raw

        # hackish access to user namespace to create _i1,_i2... dynamically
        new_i = '_i%s' % line_num
        to_main = {
            '_i': self._i,
            '_ii': self._ii,
            '_iii': self._iii,
            new_i: self._i00
        }

        if self.shell is not None:
            self.shell.push(to_main, interactive=False)

    def store_output(self, line_num):
        """If database output logging is enabled, this saves all the
        outputs from the indicated prompt number to the database. It's
        called by run_cell after code has been executed.

        Parameters
        ----------
        line_num : int
          The line number from which to save outputs
        """
        if (not self.db_log_output) or (line_num
                                        not in self.output_hist_reprs):
            return
        output = self.output_hist_reprs[line_num]

        with self.db_output_cache_lock:
            self.db_output_cache.append((line_num, output))
        if self.db_cache_size <= 1:
            self.save_flag.set()

    def _writeout_input_cache(self, conn):
        with conn:
            for line in self.db_input_cache:
                conn.execute("INSERT INTO history VALUES (?, ?, ?, ?)",
                             (self.session_number, ) + line)

    def _writeout_output_cache(self, conn):
        with conn:
            for line in self.db_output_cache:
                conn.execute("INSERT INTO output_history VALUES (?, ?, ?)",
                             (self.session_number, ) + line)

    @needs_sqlite
    def writeout_cache(self, conn=None):
        """Write any entries in the cache to the database."""
        if conn is None:
            conn = self.db

        with self.db_input_cache_lock:
            try:
                self._writeout_input_cache(conn)
            except sqlite3.IntegrityError:
                self.new_session(conn)
                print("ERROR! Session/line number was not unique in",
                      "database. History logging moved to new session",
                      self.session_number)
                try:
                    # Try writing to the new session. If this fails, don't
                    # recurse
                    self._writeout_input_cache(conn)
                except sqlite3.IntegrityError:
                    pass
            finally:
                self.db_input_cache = []

        with self.db_output_cache_lock:
            try:
                self._writeout_output_cache(conn)
            except sqlite3.IntegrityError:
                print("!! Session/line number for output was not unique",
                      "in database. Output will not be stored.")
            finally:
                self.db_output_cache = []
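
A short usage sketch of the retrieval API (assumes a live IPython session;
get_ipython() is the standard accessor, the loop bodies are illustrative):

ip = get_ipython()
hm = ip.history_manager

# Current session (session=0), from line 1 to the end, raw input only.
for session, line_num, source in hm.get_range(session=0, start=1, stop=None):
    print(line_num, source)

# With output=True each entry is (session, line, (input, output));
# output is None where nothing was recorded.
for _, line_num, (source, output) in hm.get_range(output=True):
    print(line_num, source, output)
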
Example #3
class GeoJSON(FeatureGroup):
    _view_name = Unicode('LeafletGeoJSONView').tag(sync=True)
    _model_name = Unicode('LeafletGeoJSONModel').tag(sync=True)

    data = Dict().tag(sync=True)
    style = Dict().tag(sync=True)
    hover_style = Dict().tag(sync=True)
    point_style = Dict().tag(sync=True)
    style_callback = Any()

    _click_callbacks = Instance(CallbackDispatcher, ())
    _hover_callbacks = Instance(CallbackDispatcher, ())

    @validate('style_callback')
    def _validate_style_callback(self, proposal):
        if not callable(proposal.value):
            raise TraitError(
                'style_callback should be callable (functor/function/lambda)')
        return proposal.value

    @observe('data', 'style', 'style_callback')
    def _update_data(self, change):
        self.data = self._get_data()

    def _get_data(self):
        if 'type' not in self.data:
            # We can't apply a style if we don't know what the data looks like
            return self.data

        datatype = self.data['type']

        style_callback = None
        if self.style_callback:
            style_callback = self.style_callback
        elif self.style:
            style_callback = lambda feature: self.style
        else:
            # No style to apply
            return self.data

        if datatype == 'Feature':
            self._apply_style(self.data, style_callback)
        elif datatype == 'FeatureCollection':
            for feature in self.data['features']:
                self._apply_style(feature, style_callback)

        return self.data

    def _apply_style(self, feature, style_callback):
        if 'properties' not in feature:
            feature['properties'] = {}

        properties = feature['properties']
        if 'style' in properties:
            properties['style'].update(style_callback(feature))
        else:
            properties['style'] = style_callback(feature)

    def __init__(self, **kwargs):
        super(GeoJSON, self).__init__(**kwargs)
        self.on_msg(self._handle_m_msg)
        self.data = self._get_data()

    def _handle_m_msg(self, _, content, buffers):
        if content.get('event', '') == 'click':
            self._click_callbacks(**content)
        if content.get('event', '') == 'mouseover':
            self._hover_callbacks(**content)

    def on_click(self, callback, remove=False):
        '''
        The click callback takes an unpacked set of keyword arguments.
        '''
        self._click_callbacks.register_callback(callback, remove=remove)

    def on_hover(self, callback, remove=False):
        '''
        The hover callback takes an unpacked set of keyword arguments.
        '''
        self._hover_callbacks.register_callback(callback, remove=remove)
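
A hedged usage sketch (the GeoJSON payload and the handler are illustrative;
in ipyleaflet the layer would then be added to a Map):

geo = GeoJSON(
    data={'type': 'Feature',
          'geometry': {'type': 'Point', 'coordinates': [0.0, 0.0]},
          'properties': {}},
    style={'color': 'blue'},
)

def handle_click(**kwargs):
    # Receives the unpacked message content, e.g. kwargs['event'] == 'click'.
    print(kwargs)

geo.on_click(handle_click)
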
Example #4
class IPythonKernel(KernelBase):
    shell = Instance("IPython.core.interactiveshell.InteractiveShellABC",
                     allow_none=True)
    shell_class = Type(ZMQInteractiveShell)

    use_experimental_completions = Bool(
        True,
        help=
        "Set this flag to False to deactivate the use of experimental IPython completion APIs.",
    ).tag(config=True)

    debugpy_stream = Instance(
        ZMQStream, allow_none=True) if _is_debugpy_available else None

    user_module = Any()

    @observe("user_module")
    @observe_compat
    def _user_module_changed(self, change):
        if self.shell is not None:
            self.shell.user_module = change["new"]

    user_ns = Instance(dict, args=None, allow_none=True)

    @observe("user_ns")
    @observe_compat
    def _user_ns_changed(self, change):
        if self.shell is not None:
            self.shell.user_ns = change["new"]
            self.shell.init_user_ns()

    # A reference to the Python builtin 'raw_input' function.
    # (i.e., __builtin__.raw_input for Python 2.7, builtins.input for Python 3)
    _sys_raw_input = Any()
    _sys_eval_input = Any()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Initialize the Debugger
        if _is_debugpy_available:
            self.debugger = Debugger(
                self.log,
                self.debugpy_stream,
                self._publish_debug_event,
                self.debug_shell_socket,
                self.session,
                self.debug_just_my_code,
            )

        # Initialize the InteractiveShell subclass
        self.shell = self.shell_class.instance(
            parent=self,
            profile_dir=self.profile_dir,
            user_module=self.user_module,
            user_ns=self.user_ns,
            kernel=self,
            compiler_class=XCachingCompiler,
        )
        self.shell.displayhook.session = self.session
        self.shell.displayhook.pub_socket = self.iopub_socket
        self.shell.displayhook.topic = self._topic("execute_result")
        self.shell.display_pub.session = self.session
        self.shell.display_pub.pub_socket = self.iopub_socket

        self.comm_manager = CommManager(parent=self, kernel=self)

        self.shell.configurables.append(self.comm_manager)
        comm_msg_types = ["comm_open", "comm_msg", "comm_close"]
        for msg_type in comm_msg_types:
            self.shell_handlers[msg_type] = getattr(self.comm_manager,
                                                    msg_type)

        if _use_appnope() and self._darwin_app_nap:
            # Disable App Nap, as the kernel is not a GUI itself but may host GUIs
            import appnope

            appnope.nope()

    help_links = List([
        {
            "text": "Python Reference",
            "url": "https://docs.python.org/%i.%i" % sys.version_info[:2],
        },
        {
            "text": "IPython Reference",
            "url": "https://ipython.org/documentation.html",
        },
        {
            "text": "NumPy Reference",
            "url": "https://docs.scipy.org/doc/numpy/reference/",
        },
        {
            "text": "SciPy Reference",
            "url": "https://docs.scipy.org/doc/scipy/reference/",
        },
        {
            "text": "Matplotlib Reference",
            "url": "https://matplotlib.org/contents.html",
        },
        {
            "text": "SymPy Reference",
            "url": "http://docs.sympy.org/latest/index.html",
        },
        {
            "text": "pandas Reference",
            "url": "https://pandas.pydata.org/pandas-docs/stable/",
        },
    ]).tag(config=True)

    # Kernel info fields
    implementation = "ipython"
    implementation_version = release.version
    language_info = {
        "name": "python",
        "version": sys.version.split()[0],
        "mimetype": "text/x-python",
        "codemirror_mode": {
            "name": "ipython",
            "version": sys.version_info[0]
        },
        "pygments_lexer": "ipython%d" % 3,
        "nbconvert_exporter": "python",
        "file_extension": ".py",
    }

    def dispatch_debugpy(self, msg):
        if _is_debugpy_available:
            # The first frame is the socket id, we can drop it
            frame = msg[1].bytes.decode("utf-8")
            self.log.debug("Debugpy received: %s", frame)
            self.debugger.tcp_client.receive_dap_frame(frame)

    @property
    def banner(self):
        return self.shell.banner

    async def poll_stopped_queue(self):
        while True:
            await self.debugger.handle_stopped_event()

    def start(self):
        self.shell.exit_now = False
        if self.debugpy_stream is None:
            self.log.warning(
                "debugpy_stream undefined, debugging will not be enabled")
        else:
            self.debugpy_stream.on_recv(self.dispatch_debugpy, copy=False)
        super().start()
        if self.debugpy_stream:
            asyncio.run_coroutine_threadsafe(
                self.poll_stopped_queue(),
                self.control_thread.io_loop.asyncio_loop)

    def set_parent(self, ident, parent, channel="shell"):
        """Overridden from parent to tell the display hook and output streams
        about the parent message.
        """
        super().set_parent(ident, parent, channel)
        if channel == "shell":
            self.shell.set_parent(parent)

    def init_metadata(self, parent):
        """Initialize metadata.

        Run at the beginning of each execution request.
        """
        md = super().init_metadata(parent)
        # FIXME: remove deprecated ipyparallel-specific code
        # This is required for ipyparallel < 5.0
        md.update({
            "dependencies_met": True,
            "engine": self.ident,
        })
        return md

    def finish_metadata(self, parent, metadata, reply_content):
        """Finish populating metadata.

        Run after completing an execution request.
        """
        # FIXME: remove deprecated ipyparallel-specific code
        # This is required by ipyparallel < 5.0
        metadata["status"] = reply_content["status"]
        if reply_content["status"] == "error" and reply_content[
                "ename"] == "UnmetDependency":
            metadata["dependencies_met"] = False

        return metadata

    def _forward_input(self, allow_stdin=False):
        """Forward raw_input and getpass to the current frontend.

        via input_request
        """
        self._allow_stdin = allow_stdin

        self._sys_raw_input = builtins.input
        builtins.input = self.raw_input

        self._save_getpass = getpass.getpass
        getpass.getpass = self.getpass

    def _restore_input(self):
        """Restore raw_input, getpass"""
        builtins.input = self._sys_raw_input

        getpass.getpass = self._save_getpass

    @property
    def execution_count(self):
        return self.shell.execution_count

    @execution_count.setter
    def execution_count(self, value):
        # Ignore the incrementing done by KernelBase, in favour of our shell's
        # execution counter.
        pass

    @contextmanager
    def _cancel_on_sigint(self, future):
        """ContextManager for capturing SIGINT and cancelling a future

        SIGINT raises in the event loop when running async code,
        but we want it to halt a coroutine.

        Ideally, it would raise KeyboardInterrupt,
        but this turns it into a CancelledError.
        At least it gets a decent traceback to the user.
        """
        sigint_future = asyncio.Future()

        # whichever future finishes first,
        # cancel the other one
        def cancel_unless_done(f, _ignored):
            if f.cancelled() or f.done():
                return
            f.cancel()

        # when sigint finishes,
        # abort the coroutine with CancelledError
        sigint_future.add_done_callback(partial(cancel_unless_done, future))
        # when the main future finishes,
        # stop watching for SIGINT events
        future.add_done_callback(partial(cancel_unless_done, sigint_future))

        def handle_sigint(*args):
            def set_sigint_result():
                if sigint_future.cancelled() or sigint_future.done():
                    return
                sigint_future.set_result(1)

            # use add_callback for thread safety
            self.io_loop.add_callback(set_sigint_result)

        # set the custom sigint handler during this context
        save_sigint = signal.signal(signal.SIGINT, handle_sigint)
        try:
            yield
        finally:
            # restore the previous sigint handler
            signal.signal(signal.SIGINT, save_sigint)

    async def do_execute(
        self,
        code,
        silent,
        store_history=True,
        user_expressions=None,
        allow_stdin=False,
        *,
        cell_id=None,
    ):
        shell = self.shell  # we'll need this a lot here

        self._forward_input(allow_stdin)

        reply_content = {}
        if hasattr(shell, "run_cell_async") and hasattr(
                shell, "should_run_async"):
            run_cell = shell.run_cell_async
            should_run_async = shell.should_run_async
            with_cell_id = _accepts_cell_id(run_cell)
        else:
            should_run_async = lambda cell: False  # noqa

            # older IPython,
            # use blocking run_cell and wrap it in coroutine
            async def run_cell(*args, **kwargs):
                return shell.run_cell(*args, **kwargs)

            with_cell_id = _accepts_cell_id(shell.run_cell)
        try:

            # default case: runner is asyncio and asyncio is already running
            # TODO: this should check every case for "are we inside the runner",
            # not just asyncio
            preprocessing_exc_tuple = None
            try:
                transformed_cell = self.shell.transform_cell(code)
            except Exception:
                transformed_cell = code
                preprocessing_exc_tuple = sys.exc_info()

            if (_asyncio_runner and shell.loop_runner is _asyncio_runner
                    and asyncio.get_event_loop().is_running()
                    and should_run_async(
                        code,
                        transformed_cell=transformed_cell,
                        preprocessing_exc_tuple=preprocessing_exc_tuple,
                    )):
                if with_cell_id:
                    coro = run_cell(
                        code,
                        store_history=store_history,
                        silent=silent,
                        transformed_cell=transformed_cell,
                        preprocessing_exc_tuple=preprocessing_exc_tuple,
                        cell_id=cell_id,
                    )
                else:
                    coro = run_cell(
                        code,
                        store_history=store_history,
                        silent=silent,
                        transformed_cell=transformed_cell,
                        preprocessing_exc_tuple=preprocessing_exc_tuple,
                    )

                coro_future = asyncio.ensure_future(coro)

                with self._cancel_on_sigint(coro_future):
                    res = None
                    try:
                        res = await coro_future
                    finally:
                        shell.events.trigger("post_execute")
                        if not silent:
                            shell.events.trigger("post_run_cell", res)
            else:
                # runner isn't already running,
                # make synchronous call,
                # letting shell dispatch to loop runners
                if with_cell_id:
                    res = shell.run_cell(
                        code,
                        store_history=store_history,
                        silent=silent,
                        cell_id=cell_id,
                    )
                else:
                    res = shell.run_cell(code,
                                         store_history=store_history,
                                         silent=silent)
        finally:
            self._restore_input()

        if res.error_before_exec is not None:
            err = res.error_before_exec
        else:
            err = res.error_in_exec

        if res.success:
            reply_content["status"] = "ok"
        else:
            reply_content["status"] = "error"

            reply_content.update({
                "traceback": shell._last_traceback or [],
                "ename": str(type(err).__name__),
                "evalue": str(err),
            })

            # FIXME: deprecated piece for ipyparallel (remove in 5.0):
            e_info = dict(engine_uuid=self.ident,
                          engine_id=self.int_id,
                          method="execute")
            reply_content["engine_info"] = e_info

        # Return the execution counter so clients can display prompts
        reply_content["execution_count"] = shell.execution_count - 1

        if "traceback" in reply_content:
            self.log.info(
                "Exception in execute request:\n%s",
                "\n".join(reply_content["traceback"]),
            )

        # At this point, we can tell whether the main code execution succeeded
        # or not.  If it did, we proceed to evaluate user_expressions
        if reply_content["status"] == "ok":
            reply_content["user_expressions"] = shell.user_expressions(
                user_expressions or {})
        else:
            # If there was an error, don't even try to compute expressions
            reply_content["user_expressions"] = {}

        # Payloads should be retrieved regardless of outcome, so we can both
        # recover partial output (that could have been generated early in a
        # block, before an error) and always clear the payload system.
        reply_content["payload"] = shell.payload_manager.read_payload()
        # Be aggressive about clearing the payload because we don't want
        # it to sit in memory until the next execute_request comes in.
        shell.payload_manager.clear_payload()

        return reply_content

    def do_complete(self, code, cursor_pos):
        if _use_experimental_60_completion and self.use_experimental_completions:
            return self._experimental_do_complete(code, cursor_pos)

        # FIXME: IPython completers currently assume single line,
        # but completion messages give multi-line context
        # For now, extract line from cell, based on cursor_pos:
        if cursor_pos is None:
            cursor_pos = len(code)
        line, offset = line_at_cursor(code, cursor_pos)
        line_cursor = cursor_pos - offset

        txt, matches = self.shell.complete("", line, line_cursor)
        return {
            "matches": matches,
            "cursor_end": cursor_pos,
            "cursor_start": cursor_pos - len(txt),
            "metadata": {},
            "status": "ok",
        }

    async def do_debug_request(self, msg):
        if _is_debugpy_available:
            return await self.debugger.process_request(msg)

    def _experimental_do_complete(self, code, cursor_pos):
        """
        Experimental completions from IPython, using Jedi.
        """
        if cursor_pos is None:
            cursor_pos = len(code)
        with _provisionalcompleter():
            raw_completions = self.shell.Completer.completions(
                code, cursor_pos)
            completions = list(_rectify_completions(code, raw_completions))

            comps = []
            for comp in completions:
                comps.append(
                    dict(
                        start=comp.start,
                        end=comp.end,
                        text=comp.text,
                        type=comp.type,
                        signature=comp.signature,
                    ))

        if completions:
            s = completions[0].start
            e = completions[0].end
            matches = [c.text for c in completions]
        else:
            s = cursor_pos
            e = cursor_pos
            matches = []

        return {
            "matches": matches,
            "cursor_end": e,
            "cursor_start": s,
            "metadata": {
                _EXPERIMENTAL_KEY_NAME: comps
            },
            "status": "ok",
        }

    def do_inspect(self, code, cursor_pos, detail_level=0, omit_sections=()):
        name = token_at_cursor(code, cursor_pos)

        reply_content = {"status": "ok"}
        reply_content["data"] = {}
        reply_content["metadata"] = {}
        try:
            if release.version_info >= (8, ):
                # `omit_sections` keyword will be available in IPython 8, see
                # https://github.com/ipython/ipython/pull/13343
                bundle = self.shell.object_inspect_mime(
                    name,
                    detail_level=detail_level,
                    omit_sections=omit_sections,
                )
            else:
                bundle = self.shell.object_inspect_mime(
                    name, detail_level=detail_level)
            reply_content["data"].update(bundle)
            if not self.shell.enable_html_pager:
                reply_content["data"].pop("text/html")
            reply_content["found"] = True
        except KeyError:
            reply_content["found"] = False

        return reply_content

    def do_history(
        self,
        hist_access_type,
        output,
        raw,
        session=0,
        start=0,
        stop=None,
        n=None,
        pattern=None,
        unique=False,
    ):
        if hist_access_type == "tail":
            hist = self.shell.history_manager.get_tail(n,
                                                       raw=raw,
                                                       output=output,
                                                       include_latest=True)

        elif hist_access_type == "range":
            hist = self.shell.history_manager.get_range(session,
                                                        start,
                                                        stop,
                                                        raw=raw,
                                                        output=output)

        elif hist_access_type == "search":
            hist = self.shell.history_manager.search(pattern,
                                                     raw=raw,
                                                     output=output,
                                                     n=n,
                                                     unique=unique)
        else:
            hist = []

        return {
            "status": "ok",
            "history": list(hist),
        }

    def do_shutdown(self, restart):
        self.shell.exit_now = True
        return dict(status="ok", restart=restart)

    def do_is_complete(self, code):
        transformer_manager = getattr(self.shell, "input_transformer_manager",
                                      None)
        if transformer_manager is None:
            # input_splitter attribute is deprecated
            transformer_manager = self.shell.input_splitter
        status, indent_spaces = transformer_manager.check_complete(code)
        r = {"status": status}
        if status == "incomplete":
            r["indent"] = " " * indent_spaces
        return r

    def do_apply(self, content, bufs, msg_id, reply_metadata):
        from .serialize import serialize_object, unpack_apply_message

        shell = self.shell
        try:
            working = shell.user_ns

            prefix = "_" + str(msg_id).replace("-", "") + "_"

            f, args, kwargs = unpack_apply_message(bufs, working, copy=False)

            fname = getattr(f, "__name__", "f")

            fname = prefix + "f"
            argname = prefix + "args"
            kwargname = prefix + "kwargs"
            resultname = prefix + "result"

            ns = {fname: f, argname: args, kwargname: kwargs, resultname: None}
            # print ns
            working.update(ns)
            code = "%s = %s(*%s,**%s)" % (resultname, fname, argname,
                                          kwargname)
            try:
                exec(code, shell.user_global_ns, shell.user_ns)
                result = working.get(resultname)
            finally:
                for key in ns:
                    working.pop(key)

            result_buf = serialize_object(
                result,
                buffer_threshold=self.session.buffer_threshold,
                item_threshold=self.session.item_threshold,
            )

        except BaseException as e:
            # invoke IPython traceback formatting
            shell.showtraceback()
            reply_content = {
                "traceback": shell._last_traceback or [],
                "ename": str(type(e).__name__),
                "evalue": str(e),
            }
            # FIXME: deprecated piece for ipyparallel (remove in 5.0):
            e_info = dict(engine_uuid=self.ident,
                          engine_id=self.int_id,
                          method="apply")
            reply_content["engine_info"] = e_info

            self.send_response(
                self.iopub_socket,
                "error",
                reply_content,
                ident=self._topic("error"),
                channel="shell",
            )
            self.log.info("Exception in apply request:\n%s",
                          "\n".join(reply_content["traceback"]))
            result_buf = []
            reply_content["status"] = "error"
        else:
            reply_content = {"status": "ok"}

        return reply_content, result_buf

    def do_clear(self):
        self.shell.reset(False)
        return dict(status="ok")
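
The _cancel_on_sigint pattern above generalizes beyond the kernel. A
standalone sketch of the same idea in plain asyncio (the sleeping coroutine
stands in for the executed cell; restoring the previous handler is omitted
for brevity):

import asyncio
import signal
from functools import partial

async def main():
    work = asyncio.ensure_future(asyncio.sleep(10))  # stand-in for run_cell
    loop = asyncio.get_running_loop()
    sigint_future = loop.create_future()

    def cancel_unless_done(f, _ignored):
        # Whichever future finishes first cancels the other.
        if not (f.cancelled() or f.done()):
            f.cancel()

    sigint_future.add_done_callback(partial(cancel_unless_done, work))
    work.add_done_callback(partial(cancel_unless_done, sigint_future))

    def handle_sigint(*args):
        # Resolve the sentinel future from the signal handler, thread-safely.
        loop.call_soon_threadsafe(
            lambda: sigint_future.done() or sigint_future.set_result(1))

    signal.signal(signal.SIGINT, handle_sigint)
    try:
        await work
    except asyncio.CancelledError:
        print("interrupted")

asyncio.run(main())
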
Example #5
class TruncateDataWidget(SimpleWidget):
    d = Instance(DataInstance).tag(sync=True,
                                   to_json=bytes_serializer,
                                   from_json=truncate_deserializer)
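
bytes_serializer and truncate_deserializer are defined outside this snippet;
a plausible shape for such a to_json/from_json pair (purely illustrative, not
the library's actual implementation) is:

def bytes_serializer(instance, widget):
    # to_json hook: send the raw payload held by the DataInstance.
    return instance.data

def truncate_deserializer(json_value, widget):
    # from_json hook: rebuild a DataInstance, keeping only a prefix.
    return DataInstance(json_value[:20])
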
Example #6
class NbConvertApp(JupyterApp):
    """Application used to convert from notebook file type (``*.ipynb``)"""
    
    version = __version__
    name = 'jupyter-nbconvert'
    aliases = nbconvert_aliases
    flags = nbconvert_flags
    
    def _log_level_default(self):
        return logging.INFO
    
    classes = List()
    def _classes_default(self):
        classes = [NbConvertBase]
        for pkg in (exporters, preprocessors, writers, postprocessors):
            for name in dir(pkg):
                cls = getattr(pkg, name)
                if isinstance(cls, type) and issubclass(cls, Configurable):
                    classes.append(cls)
        
        return classes

    description = Unicode(
        u"""This application is used to convert notebook files (*.ipynb)
        to various other formats.

        WARNING: THE COMMANDLINE INTERFACE MAY CHANGE IN FUTURE RELEASES.""")

    output_base = Unicode('', config=True, help='''Overwrite base name used for output files.
            Can only be used when converting one notebook at a time.
            ''')

    use_output_suffix = Bool(
        True, 
        config=True,
        help="""Whether to apply a suffix prior to the extension (only relevant
            when converting to notebook format). The suffix is determined by
            the exporter, and is usually '.nbconvert'.""")

    examples = Unicode(u"""
        The simplest way to use nbconvert is
        
        > jupyter nbconvert mynotebook.ipynb
        
        which will convert mynotebook.ipynb to the default format (probably HTML).
        
        You can specify the export format with `--to`.
        Options include {0}
        
        > jupyter nbconvert --to latex mynotebook.ipynb

        Both HTML and LaTeX support multiple output templates. LaTeX includes
        'base', 'article' and 'report'.  HTML includes 'basic' and 'full'. You
        can specify the flavor of the format used.

        > jupyter nbconvert --to html --template basic mynotebook.ipynb
        
        You can also pipe the output to stdout, rather than a file
        
        > jupyter nbconvert mynotebook.ipynb --stdout

        PDF is generated via LaTeX

        > jupyter nbconvert mynotebook.ipynb --to pdf
        
        You can get (and serve) a Reveal.js-powered slideshow
        
        > jupyter nbconvert myslides.ipynb --to slides --post serve
        
        Multiple notebooks can be given at the command line in a couple of 
        different ways:
  
        > jupyter nbconvert notebook*.ipynb
        > jupyter nbconvert notebook1.ipynb notebook2.ipynb
        
        or you can specify the notebooks list in a config file, containing::
        
            c.NbConvertApp.notebooks = ["my_notebook.ipynb"]
        
        > jupyter nbconvert --config mycfg.py
        """.format(get_export_names()))

    # Writer specific variables
    writer = Instance('nbconvert.writers.base.WriterBase',  
                      help="""Instance of the writer class used to write the 
                      results of the conversion.""", allow_none=True)
    writer_class = DottedObjectName('FilesWriter', config=True, 
                                    help="""Writer class used to write the 
                                    results of the conversion""")
    writer_aliases = {'fileswriter': 'nbconvert.writers.files.FilesWriter',
                      'debugwriter': 'nbconvert.writers.debug.DebugWriter',
                      'stdoutwriter': 'nbconvert.writers.stdout.StdoutWriter'}
    writer_factory = Type(allow_none=True)

    def _writer_class_changed(self, name, old, new):
        if new.lower() in self.writer_aliases:
            new = self.writer_aliases[new.lower()]
        self.writer_factory = import_item(new)

    # Post-processor specific variables
    postprocessor = Instance('nbconvert.postprocessors.base.PostProcessorBase',  
                      help="""Instance of the PostProcessor class used to write the 
                      results of the conversion.""", allow_none=True)

    postprocessor_class = DottedOrNone(config=True, 
                                    help="""PostProcessor class used to write the 
                                    results of the conversion""")
    postprocessor_aliases = {'serve': 'nbconvert.postprocessors.serve.ServePostProcessor'}
    postprocessor_factory = Type(None, allow_none=True)

    def _postprocessor_class_changed(self, name, old, new):
        if new.lower() in self.postprocessor_aliases:
            new = self.postprocessor_aliases[new.lower()]
        if new:
            self.postprocessor_factory = import_item(new)


    # Other configurable variables
    export_format = CaselessStrEnum(get_export_names(),
        default_value="html",
        config=True,
        help="""The export format to be used."""
    )

    notebooks = List([], config=True, help="""List of notebooks to convert.
                     Wildcards are supported.
                     Filenames passed positionally will be added to the list.
                     """)

    @catch_config_error
    def initialize(self, argv=None):
        self.init_syspath()
        super(NbConvertApp, self).initialize(argv)
        self.init_notebooks()
        self.init_writer()
        self.init_postprocessor()



    def init_syspath(self):
        """
        Add the cwd to sys.path ($PYTHONPATH)
        """
        sys.path.insert(0, os.getcwd())
        

    def init_notebooks(self):
        """Construct the list of notebooks.
        If notebooks are passed on the command-line,
        they override notebooks specified in config files.
        Glob each notebook to replace notebook patterns with filenames.
        """

        # Specifying notebooks on the command-line overrides (rather than adds)
        # the notebook list
        if self.extra_args:
            patterns = self.extra_args
        else:
            patterns = self.notebooks

        # Use glob to replace all the notebook patterns with filenames.
        filenames = []
        for pattern in patterns:
            
            # Use glob to find matching filenames.  Allow the user to convert 
            # notebooks without having to type the extension.
            globbed_files = glob.glob(pattern)
            globbed_files.extend(glob.glob(pattern + '.ipynb'))
            if not globbed_files:
                self.log.warning("pattern %r matched no files", pattern)

            for filename in globbed_files:
                if filename not in filenames:
                    filenames.append(filename)
        self.notebooks = filenames

    def init_writer(self):
        """
        Initialize the writer (which is stateless)
        """
        self._writer_class_changed(None, self.writer_class, self.writer_class)
        self.writer = self.writer_factory(parent=self)
        if hasattr(self.writer, 'build_directory') and self.writer.build_directory != '':
            self.use_output_suffix = False

    def init_postprocessor(self):
        """
        Initialize the postprocessor (which is stateless)
        """
        self._postprocessor_class_changed(None, self.postprocessor_class, 
            self.postprocessor_class)
        if self.postprocessor_factory:
            self.postprocessor = self.postprocessor_factory(parent=self)

    def start(self):
        """
        Run after initialization is complete.
        """
        super(NbConvertApp, self).start()
        self.convert_notebooks()

    def init_single_notebook_resources(self, notebook_filename):
        """Step 1: Initialize resources

        This initializes the resources dictionary for a single notebook. This
        method should return the resources dictionary, and MUST include the
        following keys:

            - config_dir: the location of the Jupyter config directory
            - unique_key: the notebook name
            - output_files_dir: a directory where output files (not including
              the notebook itself) should be saved

        """

        # Get a unique key for the notebook and set it in the resources object.
        basename = os.path.basename(notebook_filename)
        notebook_name = os.path.splitext(basename)[0]
        if self.output_base:
            # strip duplicate extension from output_base, to avoid basename.ext.ext
            if getattr(self.exporter, 'file_extension', False):
                base, ext = os.path.splitext(self.output_base)
                if ext == self.exporter.file_extension:
                    self.output_base = base
            notebook_name = self.output_base

        self.log.debug("Notebook name is '%s'", notebook_name)

        # first initialize the resources we want to use
        resources = {}
        resources['config_dir'] = self.config_dir
        resources['unique_key'] = notebook_name
        resources['output_files_dir'] = '%s_files' % notebook_name

        return resources

    def export_single_notebook(self, notebook_filename, resources):
        """Step 2: Export the notebook

        Exports the notebook to a particular format according to the specified
        exporter. This function returns the output and (possibly modified)
        resources from the exporter.

        """
        try:
            output, resources = self.exporter.from_filename(notebook_filename, resources=resources)
        except ConversionException:
            self.log.error("Error while converting '%s'", notebook_filename, exc_info=True)
            self.exit(1)

        return output, resources

    def write_single_notebook(self, output, resources):
        """Step 3: Write the notebook to file

        This writes output from the exporter to file using the specified writer.
        It returns the results from the writer.

        """
        if 'unique_key' not in resources:
            raise KeyError("unique_key MUST be specified in the resources, but it is not")

        notebook_name = resources['unique_key']
        if self.use_output_suffix and not self.output_base:
            notebook_name += resources.get('output_suffix', '')

        write_results = self.writer.write(
            output, resources, notebook_name=notebook_name)
        return write_results

    def postprocess_single_notebook(self, write_results):
        """Step 4: Postprocess the notebook

        This postprocesses the notebook after it has been written, taking as an
        argument the results of writing the notebook to file. It only does
        anything if a postprocessor has been specified.

        """
        # Post-process if post processor has been defined.
        if hasattr(self, 'postprocessor') and self.postprocessor:
            self.postprocessor(write_results)

    def convert_single_notebook(self, notebook_filename):
        """Convert a single notebook. Performs the following steps:

            1. Initialize notebook resources
            2. Export the notebook to a particular format
            3. Write the exported notebook to file
            4. (Maybe) postprocess the written file
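
        A standalone usage sketch of these steps follows this class.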

        """
        self.log.info("Converting notebook %s to %s", notebook_filename, self.export_format)
        resources = self.init_single_notebook_resources(notebook_filename)
        output, resources = self.export_single_notebook(notebook_filename, resources)
        write_results = self.write_single_notebook(output, resources)
        self.postprocess_single_notebook(write_results)

    def convert_notebooks(self):
        """
        Convert the notebooks in the self.notebook traitlet
        """
        # check that the output base isn't specified if there is more than
        # one notebook to convert
        if self.output_base != '' and len(self.notebooks) > 1:
            self.log.error(
                """
                UsageError: --output flag or `NbConvertApp.output_base` config option
                cannot be used when converting multiple notebooks.
                """
            )
            self.exit(1)
        
        # initialize the exporter
        self.exporter = exporter_map[self.export_format](config=self.config)

        # no notebooks to convert!
        if len(self.notebooks) == 0:
            self.print_help()
            sys.exit(-1)

        # convert each notebook
        for notebook_filename in self.notebooks:
            self.convert_single_notebook(notebook_filename)
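
The four steps above can also be exercised directly against nbconvert's public
API, outside of the application class. A minimal sketch, assuming an
`example.ipynb` notebook exists in the current directory (the filename,
pattern, and output directory are placeholders, and HTMLExporter stands in
for the configurable export_format):

import glob
import os

from nbconvert import HTMLExporter
from nbconvert.writers import FilesWriter

# Expand patterns the way init_notebooks() does, with and without .ipynb.
pattern = 'example'
filenames = glob.glob(pattern) + glob.glob(pattern + '.ipynb')

for notebook_filename in filenames:
    # Step 1: resources dict with the keys the exporter and writer expect.
    basename = os.path.basename(notebook_filename)
    notebook_name = os.path.splitext(basename)[0]
    resources = {
        'unique_key': notebook_name,
        'output_files_dir': '%s_files' % notebook_name,
    }

    # Step 2: export to HTML.
    exporter = HTMLExporter()
    output, resources = exporter.from_filename(notebook_filename,
                                               resources=resources)

    # Step 3: write the result; FilesWriter also saves any extracted outputs.
    writer = FilesWriter(build_directory='.')
    writer.write(output, resources, notebook_name=notebook_name)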
Example no. 7
class IPKernelApp(BaseIPythonApplication, InteractiveShellApp,
        ConnectionFileMixin):
    name = 'ipython-kernel'
    aliases = Dict(kernel_aliases)
    flags = Dict(kernel_flags)
    classes = [IPythonKernel, ZMQInteractiveShell, ProfileDir, Session]
    # the kernel class, as an importstring
    kernel_class = Type('ipykernel.ipkernel.IPythonKernel',
                        klass='ipykernel.kernelbase.Kernel',
                        help="""The Kernel subclass to be used.

    This should allow easy re-use of the IPKernelApp entry point
    to configure and launch kernels other than IPython's own.
    """).tag(config=True)
    kernel = Any()
    poller = Any() # don't restrict this even though current pollers are all Threads
    heartbeat = Instance(Heartbeat, allow_none=True)
    ports = Dict()

    subcommands = {
        'install': (
            'ipykernel.kernelspec.InstallIPythonKernelSpecApp',
            'Install the IPython kernel'
        ),
    }

    # connection info:
    connection_dir = Unicode()

    @default('connection_dir')
    def _default_connection_dir(self):
        return jupyter_runtime_dir()

    @property
    def abs_connection_file(self):
        if os.path.basename(self.connection_file) == self.connection_file:
            return os.path.join(self.connection_dir, self.connection_file)
        else:
            return self.connection_file

    # streams, etc.
    no_stdout = Bool(False, help="redirect stdout to the null device").tag(config=True)
    no_stderr = Bool(False, help="redirect stderr to the null device").tag(config=True)
    outstream_class = DottedObjectName('ipykernel.iostream.OutStream',
        help="The importstring for the OutStream factory").tag(config=True)
    displayhook_class = DottedObjectName('ipykernel.displayhook.ZMQDisplayHook',
        help="The importstring for the DisplayHook factory").tag(config=True)

    # polling
    parent_handle = Integer(int(os.environ.get('JPY_PARENT_PID') or 0),
        help="""kill this process if its parent dies.  On Windows, the argument
        specifies the HANDLE of the parent process, otherwise it is simply boolean.
        """).tag(config=True)
    interrupt = Integer(int(os.environ.get('JPY_INTERRUPT_EVENT') or 0),
        help="""ONLY USED ON WINDOWS
        Interrupt this process when the parent is signaled.
        """).tag(config=True)

    def init_crash_handler(self):
        sys.excepthook = self.excepthook

    def excepthook(self, etype, evalue, tb):
        # write uncaught traceback to 'real' stderr, not zmq-forwarder
        traceback.print_exception(etype, evalue, tb, file=sys.__stderr__)

    def init_poller(self):
        if sys.platform == 'win32':
            if self.interrupt or self.parent_handle:
                self.poller = ParentPollerWindows(self.interrupt, self.parent_handle)
        elif self.parent_handle:
            self.poller = ParentPollerUnix()

    def _bind_socket(self, s, port):
        iface = '%s://%s' % (self.transport, self.ip)
        if self.transport == 'tcp':
            if port <= 0:
                port = s.bind_to_random_port(iface)
            else:
                s.bind("tcp://%s:%i" % (self.ip, port))
        elif self.transport == 'ipc':
            if port <= 0:
                port = 1
                path = "%s-%i" % (self.ip, port)
                while os.path.exists(path):
                    port = port + 1
                    path = "%s-%i" % (self.ip, port)
            else:
                path = "%s-%i" % (self.ip, port)
            s.bind("ipc://%s" % path)
        return port

    def write_connection_file(self):
        """write connection info to JSON file"""
        cf = self.abs_connection_file
        self.log.debug("Writing connection file: %s", cf)
        write_connection_file(cf, ip=self.ip, key=self.session.key, transport=self.transport,
                              shell_port=self.shell_port, stdin_port=self.stdin_port,
                              hb_port=self.hb_port, iopub_port=self.iopub_port,
                              control_port=self.control_port)

    def cleanup_connection_file(self):
        cf = self.abs_connection_file
        self.log.debug("Cleaning up connection file: %s", cf)
        try:
            os.remove(cf)
        except (IOError, OSError):
            pass

        self.cleanup_ipc_files()

    def init_connection_file(self):
        if not self.connection_file:
            self.connection_file = "kernel-%s.json"%os.getpid()
        try:
            self.connection_file = filefind(self.connection_file, ['.', self.connection_dir])
        except IOError:
            self.log.debug("Connection file not found: %s", self.connection_file)
            # This means I own it, and I'll create it in this directory:
            ensure_dir_exists(os.path.dirname(self.abs_connection_file), 0o700)
            # Also, I will clean it up:
            atexit.register(self.cleanup_connection_file)
            return
        try:
            self.load_connection_file()
        except Exception:
            self.log.error("Failed to load connection file: %r", self.connection_file, exc_info=True)
            self.exit(1)

    def init_sockets(self):
        # Create a context, a session, and the kernel sockets.
        self.log.info("Starting the kernel at pid: %i", os.getpid())
        context = zmq.Context.instance()
        # Uncomment this to try closing the context.
        # atexit.register(context.term)

        self.shell_socket = context.socket(zmq.ROUTER)
        self.shell_socket.linger = 1000
        self.shell_port = self._bind_socket(self.shell_socket, self.shell_port)
        self.log.debug("shell ROUTER Channel on port: %i", self.shell_port)

        self.stdin_socket = context.socket(zmq.ROUTER)
        self.stdin_socket.linger = 1000
        self.stdin_port = self._bind_socket(self.stdin_socket, self.stdin_port)
        self.log.debug("stdin ROUTER Channel on port: %i", self.stdin_port)

        self.control_socket = context.socket(zmq.ROUTER)
        self.control_socket.linger = 1000
        self.control_port = self._bind_socket(self.control_socket, self.control_port)
        self.log.debug("control ROUTER Channel on port: %i", self.control_port)

        self.init_iopub(context)

    def init_iopub(self, context):
        self.iopub_socket = context.socket(zmq.PUB)
        self.iopub_socket.linger = 1000
        self.iopub_port = self._bind_socket(self.iopub_socket, self.iopub_port)
        self.log.debug("iopub PUB Channel on port: %i" % self.iopub_port)
        self.configure_tornado_logger()
        self.iopub_thread = IOPubThread(self.iopub_socket, pipe=True)
        self.iopub_thread.start()
        # backward-compat: wrap iopub socket API in background thread
        self.iopub_socket = self.iopub_thread.background_socket

    def init_heartbeat(self):
        """start the heart beating"""
        # heartbeat doesn't share context, because it mustn't be blocked
        # by the GIL, which is accessed by libzmq when freeing zero-copy messages
        hb_ctx = zmq.Context()
        self.heartbeat = Heartbeat(hb_ctx, (self.transport, self.ip, self.hb_port))
        self.hb_port = self.heartbeat.port
        self.log.debug("Heartbeat REP Channel on port: %i" % self.hb_port)
        self.heartbeat.start()

    def log_connection_info(self):
        """display connection info, and store ports"""
        basename = os.path.basename(self.connection_file)
        if basename == self.connection_file or \
            os.path.dirname(self.connection_file) == self.connection_dir:
            # use shortname
            tail = basename
        else:
            tail = self.connection_file
        lines = [
            "To connect another client to this kernel, use:",
            "    --existing %s" % tail,
        ]
        # log connection info
        # info-level, so often not shown.
        # frontends should use the %connect_info magic
        # to see the connection info
        for line in lines:
            self.log.info(line)
        # also raw print to the terminal if no parent_handle (`ipython kernel`)
        # unless log-level is CRITICAL (--quiet)
        if not self.parent_handle and self.log_level < logging.CRITICAL:
            io.rprint(_ctrl_c_message)
            for line in lines:
                io.rprint(line)

        self.ports = dict(shell=self.shell_port, iopub=self.iopub_port,
                          stdin=self.stdin_port, hb=self.hb_port,
                          control=self.control_port)

    def init_blackhole(self):
        """redirects stdout/stderr to devnull if necessary"""
        if self.no_stdout or self.no_stderr:
            blackhole = open(os.devnull, 'w')
            if self.no_stdout:
                sys.stdout = sys.__stdout__ = blackhole
            if self.no_stderr:
                sys.stderr = sys.__stderr__ = blackhole

    def init_io(self):
        """Redirect input streams and set a display hook."""
        if self.outstream_class:
            outstream_factory = import_item(str(self.outstream_class))
            sys.stdout = outstream_factory(self.session, self.iopub_thread, u'stdout')
            sys.stderr = outstream_factory(self.session, self.iopub_thread, u'stderr')
        if self.displayhook_class:
            displayhook_factory = import_item(str(self.displayhook_class))
            self.displayhook = displayhook_factory(self.session, self.iopub_socket)
            sys.displayhook = self.displayhook

        self.patch_io()

    def patch_io(self):
        """Patch important libraries that can't handle sys.stdout forwarding"""
        try:
            import faulthandler
        except ImportError:
            pass
        else:
            # Warning: this is a monkeypatch of `faulthandler.enable`, watch for possible
            # updates to the upstream API and update accordingly (up-to-date as of Python 3.5):
            # https://docs.python.org/3/library/faulthandler.html#faulthandler.enable

            # change default file to __stderr__ from forwarded stderr
            faulthandler_enable = faulthandler.enable
            def enable(file=sys.__stderr__, all_threads=True, **kwargs):
                return faulthandler_enable(file=file, all_threads=all_threads, **kwargs)

            faulthandler.enable = enable

            if hasattr(faulthandler, 'register'):
                faulthandler_register = faulthandler.register
                def register(signum, file=sys.__stderr__, all_threads=True, chain=False, **kwargs):
                    return faulthandler_register(signum, file=file, all_threads=all_threads,
                                                 chain=chain, **kwargs)
                faulthandler.register = register

    def init_signal(self):
        signal.signal(signal.SIGINT, signal.SIG_IGN)

    def init_kernel(self):
        """Create the Kernel object itself"""
        shell_stream = ZMQStream(self.shell_socket)
        control_stream = ZMQStream(self.control_socket)

        kernel_factory = self.kernel_class.instance

        kernel = kernel_factory(parent=self, session=self.session,
                                shell_streams=[shell_stream, control_stream],
                                iopub_thread=self.iopub_thread,
                                iopub_socket=self.iopub_socket,
                                stdin_socket=self.stdin_socket,
                                log=self.log,
                                profile_dir=self.profile_dir,
                                user_ns=self.user_ns,
        )
        kernel.record_ports({
            name + '_port': port for name, port in self.ports.items()
        })
        self.kernel = kernel

        # Allow the displayhook to get the execution count
        self.displayhook.get_execution_count = lambda: kernel.execution_count

    def init_gui_pylab(self):
        """Enable GUI event loop integration, taking pylab into account."""

        # Register inline backend as default
        # this is higher priority than matplotlibrc,
        # but lower priority than anything else (mpl.use() for instance).
        # This only affects matplotlib >= 1.5
        if not os.environ.get('MPLBACKEND'):
            os.environ['MPLBACKEND'] = 'module://ipykernel.pylab.backend_inline'

        # Provide a wrapper for :meth:`InteractiveShellApp.init_gui_pylab`
        # to ensure that any exception is printed straight to stderr.
        # Normally _showtraceback associates the reply with an execution,
        # which means frontends will never draw it, as this exception
        # is not associated with any execute request.

        shell = self.shell
        _showtraceback = shell._showtraceback
        try:
            # replace error-sending traceback with stderr
            def print_tb(etype, evalue, stb):
                print("GUI event loop or pylab initialization failed",
                      file=sys.stderr)
                print(shell.InteractiveTB.stb2text(stb), file=sys.stderr)
            shell._showtraceback = print_tb
            InteractiveShellApp.init_gui_pylab(self)
        finally:
            shell._showtraceback = _showtraceback

    def init_shell(self):
        self.shell = getattr(self.kernel, 'shell', None)
        if self.shell:
            self.shell.configurables.append(self)

    def init_extensions(self):
        super(IPKernelApp, self).init_extensions()
        # BEGIN HARDCODED WIDGETS HACK
        # Ensure ipywidgets extension is loaded if available
        extension_man = self.shell.extension_manager
        if 'ipywidgets' not in extension_man.loaded:
            try:
                extension_man.load_extension('ipywidgets')
            except ImportError:
                self.log.debug('ipywidgets package not installed. Widgets will not be available.')
        # END HARDCODED WIDGETS HACK

    def configure_tornado_logger(self):
        """ Configure the tornado logging.Logger.

            Must set up the tornado logger or else tornado will call
            basicConfig for the root logger which makes the root logger
            go to the real sys.stderr instead of the capture streams.
            This function mimics the setup of logging.basicConfig.
        """
        logger = logging.getLogger('tornado')
        handler = logging.StreamHandler()
        formatter = logging.Formatter(logging.BASIC_FORMAT)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    @catch_config_error
    def initialize(self, argv=None):
        super(IPKernelApp, self).initialize(argv)
        if self.subapp is not None:
            return
        # register zmq IOLoop with tornado
        zmq_ioloop.install()
        self.init_blackhole()
        self.init_connection_file()
        self.init_poller()
        self.init_sockets()
        self.init_heartbeat()
        # writing/displaying connection info must be *after* init_sockets/heartbeat
        self.write_connection_file()
        # Log connection info after writing connection file, so that the connection
        # file is definitely available at the time someone reads the log.
        self.log_connection_info()
        self.init_io()
        self.init_signal()
        self.init_kernel()
        # shell init steps
        self.init_path()
        self.init_shell()
        if self.shell:
            self.init_gui_pylab()
            self.init_extensions()
            self.init_code()
        # flush stdout/stderr, so that anything written to these streams during
        # initialization does not get associated with the first execution request
        sys.stdout.flush()
        sys.stderr.flush()

    def start(self):
        if self.subapp is not None:
            return self.subapp.start()
        if self.poller is not None:
            self.poller.start()
        self.kernel.start()
        try:
            ioloop.IOLoop.instance().start()
        except KeyboardInterrupt:
            pass
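
The core of IPKernelApp's socket setup is the random-port fallback in
_bind_socket(): with TCP transport, a non-positive port number asks libzmq
to pick a free port. A self-contained sketch of that TCP branch (the IP,
socket type, and port 0 are arbitrary choices for illustration):

import zmq

def bind_socket(s, ip, port):
    """Bind s to tcp://ip:port, picking a random free port if port <= 0."""
    if port <= 0:
        port = s.bind_to_random_port("tcp://%s" % ip)
    else:
        s.bind("tcp://%s:%i" % (ip, port))
    return port

context = zmq.Context.instance()
shell_socket = context.socket(zmq.ROUTER)
shell_socket.linger = 1000
shell_port = bind_socket(shell_socket, "127.0.0.1", 0)
print("shell ROUTER bound to port %i" % shell_port)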
Example no. 8
class ModelInspector(HasTraits):
    """Main entry point to a Kepler session.

    This class orchestrates model logging, inspection, running experiments and
    saving results.
    """

    # Kepler config home
    home = Directory()

    project = Unicode('default')

    config = Instance(ConfigParser)

    # the keras / sklearn model
    model = Union([Instance(Model), Instance(BaseEstimator)])

    # last git commit associated with the model
    commit = Unicode()

    # path to the saved weights of the model
    weights_path = KerasModelWeights()

    # Yaml config of the model, written to a file
    model_config = File()

    # Type of model, keras.engine.training.{Model, Sequential}, etc
    keras_type = Unicode()

    # Index of the archmat corresponding to this model.
    archmat_index = Integer()

    enable_model_search = Bool(True)

    checks = List()

    model_checks = List()

    name = Unicode()

    @default('checks')
    def _default_checks(self):
        subcls = C.BaseStartTrainingCheck.__subclasses__()
        return [c() for c in subcls if c.enabled]

    @default('home')
    def _default_home(self):
        return os.environ.get('KEPLER_HOME', op.expanduser('~/.kepler'))

    @default('config')
    def _default_config(self):
        cfg = ConfigParser(interpolation=ExtendedInterpolation())
        cfg.read(op.join(self.home, 'config.ini'))
        return cfg

    @default('name')
    def _default_name(self):
        if isinstance(self.model, BaseEstimator):
            name = type(self.model).__name__
        elif isinstance(self.model, Model):
            name = name_keras_model(self.model)
        else:
            name = ''
        return name

    @property
    def db_engine(self):
        engine = getattr(self, '_db_engine', False)
        if not engine:
            dbpath = self.config.get('default', 'db')
            self._db_engine = get_engine(dbpath)
            engine = self._db_engine
        return engine

    @property
    def model_vectorizer(self):
        vect = getattr(self, '_model_vectorizer', False)
        if not vect:
            self._model_vectorizer = get_model_vectorizer(
                self.config.get('models', 'vectorizer'))
            vect = self._model_vectorizer
        return vect

    def __init__(self, *args, **kwargs):
        """
        Overridden from the parent class to include the created timestamp.
        """
        super(ModelInspector, self).__init__(*args, **kwargs)
        # overwrite the model.save method to be able to save the model
        # weights path here.
        self.created = datetime.now()

    @property
    def keras_type(self):
        """Type of keras model.

        Note: this property shadows the ``keras_type`` Unicode trait
        defined above.
        """
        return self.model.__class__.__name__

    @property
    def n_params(self):
        """Count number of trainable parameters."""
        return count_params(self.model)

    @property
    def n_layers(self):
        """Number of trainable layers."""
        return sum([c.trainable for c in self.model.layers])

    def write_model_config(self):
        """Write the model's config to a yaml file.

        The default location is ~/.kepler/models/specs, which is controlled
        by the ('models', 'spec_dir') config option."""
        if not self.model_config:
            uid = uuid4()
            specs_dir = op.expanduser(self.config.get('models', 'spec_dir'))
            if not op.isdir(specs_dir):
                os.makedirs(specs_dir)
            outpath = op.join(specs_dir, str(uid) + '.txt')
            with open(outpath, 'w') as fout:
                fout.write(self.model.to_yaml())
            self.model_config = outpath

    def write_model_arch_vector(self, x=None):
        """
        Write the vectorized representation of the model architecture to the
        archmat file.

        Parameters
        ----------
        x : sparse vector, optional
            The sparse vector representing a model. If not specified, it is
            calculated for the current model.
        """
        if x is None:
            x = model_representation(self.model, self.model_vectorizer)
        matpath = self.config.get('models', 'model_archs')
        X = load_model_arch_mat(matpath)
        if X is None:
            X = x
        else:
            X = vstack((X, x))
        self.archmat_index = X.shape[0] - 1
        write_model_arch_mat(X, matpath)

    def __enter__(self):
        """Setup a Kepler session by:

        1. adding the current model to the db
        2. saving the model config
        3. searching for similar models
        4. creating a modelproxy for the user to work with
        """
        self.instance = ModelDBModel()
        self.session = sessionmaker(bind=self.db_engine)()
        self.session.add(self.instance)

        try:
            self.session.commit()
        except OperationalError:
            raise RuntimeError('Kepler may not have initialized properly. '
                               'Please run `kepler init` and try again.')
        self.write_model_config()
        if self.enable_model_search:
            if self.config.getboolean('models', 'enable_model_search'):
                x = model_representation(self.model, self.model_vectorizer)
                self.search(x)
        self.run_model_checks()
        self.model_proxy = ModelProxy(self.model, self, self.checks)
        self.model_proxy.setUp()
        return self.model_proxy

    def __exit__(self, _type, value, traceback):
        """Teardown the Kepler session by:

        1. undoing the modelproxy
        2. writing the model architecture to the archmat
        3. saving the model metadata to the db.
        """
        self.model_proxy.tearDown()
        self.write_model_arch_vector()
        self.save()

    def save(self):
        """Save the model details to the Kepler db."""
        table_columns = ModelDBModel.__table__.columns
        attrs = [k.name for k in table_columns if not k.primary_key]
        for attr in attrs:
            setattr(self.instance, attr, getattr(self, attr))
        self.session.add(self.instance)
        # add current model to project
        p2model = ProjectModel(project_id=self.project, model=self.instance)
        self.session.add(p2model)
        self.session.commit()
        self.session.close()

    def search(self, x=None, prompt=True):
        """Search the archmat for similar models.

        Parameters
        ----------

        x : sparse vector, optional
            The sparse vector to search. If not specified, computed for the
            current model.
        prompt : bool, optional
            Whether to prompt the user if similar models are found.
        """
        if x is None:
            x = model_representation(self.model, self.model_vectorizer)
        X = load_model_arch_mat(self.config.get('models', 'model_archs'))
        if X is None:  # nothing to search against
            return
        d = cosine_similarity(x, X).ravel()
        thresh = self.config.get('misc', 'model_sim_tol')
        d = d > float(thresh)
        if np.any(d):
            indices, = np.where(d)
            if prompt:
                n_similar = d.sum()
                print('There are {} models similar to this one.'.format(
                    n_similar))
                see_archs = binary_prompt(
                    'Would you like to see their graphs?')
                if see_archs:
                    tf_logdir = self.config.get('models', 'tensorflow_logdir')
                    print('Enter location for saving graphs [{}]: '.format(
                        tf_logdir))
                    user_choice = input('>>> ')
                    if user_choice:
                        tf_logdir = user_choice
                    tf_logdir = op.expanduser(tf_logdir)
                    with GraphWriter(logdir=tf_logdir) as gw:
                        gw.write_graphs(self.get_model_configs(indices))
                    print('Graphs written to ' + tf_logdir)
                    print('Please point Tensorboard to ' + tf_logdir)
            continue_training = binary_prompt('Continue training?')
            if not continue_training:
                import sys
                sys.exit()
            return indices

    def get_model_configs(self, indices):
        """Iterate over model config files.

        For models specified in `indices`, iterate over the corresponding
        `model_config` column values, which are paths to files containing the
        model summaries.

        Parameters
        ----------

        indices : sequence
            Sequence of DB indices over which to iterate and find the model
            summaries.

        Yields
        ------
        str
            path to a yaml file containing the config of a model
        """
        klass = self.instance.__class__
        q = self.session.query(klass)
        for inst in q.filter(
                klass.archmat_index.in_(map(lambda x: x.item(), indices))):
            yield inst.model_config

    def run_model_checks(self):
        """Run all checks enabled at the model level."""
        if not self.model_checks:
            checks = [
                c for c in C.BaseModelCheck.__subclasses__() if c.enabled
            ]
            for check in checks:
                check()(self.model)
        else:
            for check in self.model_checks:
                check(self.model)
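
The similarity search in ModelInspector.search() reduces to a single
cosine_similarity call between the current model's sparse vector and the
stacked archmat, followed by a threshold. A hedged sketch with random
stand-in vectors (the threshold and vector size are arbitrary here; in the
class they come from model_representation() and the config):

import numpy as np
from scipy.sparse import csr_matrix, vstack
from sklearn.metrics.pairwise import cosine_similarity

# Stand-in archmat: three previously stored model vectors, stacked row-wise.
X = vstack([csr_matrix(np.random.rand(1, 16)) for _ in range(3)])

# Stand-in vector for the current model.
x = csr_matrix(np.random.rand(1, 16))

# One similarity per stored row, thresholded to flag near-duplicates.
thresh = 0.9  # plays the role of the ('misc', 'model_sim_tol') option
d = cosine_similarity(x, X).ravel() > thresh
if np.any(d):
    indices, = np.where(d)
    print("%d similar models at archmat rows %s" % (d.sum(), indices))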
Example no. 9
class Kernel(SingletonConfigurable):

    #---------------------------------------------------------------------------
    # Kernel interface
    #---------------------------------------------------------------------------

    # attribute to override with a GUI
    eventloop = Any(None)

    @observe('eventloop')
    def _update_eventloop(self, change):
        """schedule call to eventloop from IOLoop"""
        loop = ioloop.IOLoop.instance()
        loop.add_callback(self.enter_eventloop)

    session = Instance(Session, allow_none=True)
    profile_dir = Instance('IPython.core.profiledir.ProfileDir', allow_none=True)
    shell_streams = List()
    control_stream = Instance(ZMQStream, allow_none=True)
    iopub_socket = Any()
    iopub_thread = Any()
    stdin_socket = Any()
    log = Instance(logging.Logger, allow_none=True)

    # identities:
    int_id = Integer(-1)
    ident = Unicode()

    @default('ident')
    def _default_ident(self):
        return unicode_type(uuid.uuid4())

    # This should be overridden by wrapper kernels that implement any real
    # language.
    language_info = {}

    # any links that should go in the help menu
    help_links = List()

    # Private interface

    _darwin_app_nap = Bool(True,
        help="""Whether to use appnope for compatiblity with OS X App Nap.

        Only affects OS X >= 10.9.
        """
    ).tag(config=True)

    # track associations with current request
    _allow_stdin = Bool(False)
    _parent_header = Dict()
    _parent_ident = Any(b'')
    # Time to sleep after flushing the stdout/err buffers in each execute
    # cycle.  While this introduces a hard limit on the minimal latency of the
    # execute cycle, it helps prevent output synchronization problems for
    # clients.
    # Units are in seconds.  The minimum zmq latency on local host is probably
    # ~150 microseconds, set this to 500us for now.  We may need to increase it
    # a little if it's not enough after more interactive testing.
    _execute_sleep = Float(0.0005).tag(config=True)

    # Frequency of the kernel's event loop.
    # Units are in seconds, kernel subclasses for GUI toolkits may need to
    # adapt to milliseconds.
    _poll_interval = Float(0.05).tag(config=True)

    # If the shutdown was requested over the network, we leave here the
    # necessary reply message so it can be sent by our registered atexit
    # handler.  This ensures that the reply is only sent to clients truly at
    # the end of our shutdown process (which happens after the underlying
    # IPython shell's own shutdown).
    _shutdown_message = None

    # This is a dict of the port numbers that the kernel is listening on. It is
    # set by record_ports and used by connect_request.
    _recorded_ports = Dict()

    # set of aborted msg_ids
    aborted = Set()

    # Track execution count here. For IPython, we override this to use the
    # execution count we store in the shell.
    execution_count = 0

    msg_types = [
        'execute_request', 'complete_request',
        'inspect_request', 'history_request',
        'comm_info_request', 'kernel_info_request',
        'connect_request', 'shutdown_request',
        'is_complete_request',
        # deprecated:
        'apply_request',
    ]
    # add deprecated ipyparallel control messages
    control_msg_types = msg_types + ['clear_request', 'abort_request']

    def __init__(self, **kwargs):
        super(Kernel, self).__init__(**kwargs)

        # Build dict of handlers for message types
        self.shell_handlers = {}
        for msg_type in self.msg_types:
            self.shell_handlers[msg_type] = getattr(self, msg_type)

        self.control_handlers = {}
        for msg_type in self.control_msg_types:
            self.control_handlers[msg_type] = getattr(self, msg_type)

    def dispatch_control(self, msg):
        """dispatch control requests"""
        idents, msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.deserialize(msg, content=True, copy=False)
        except Exception:
            self.log.error("Invalid Control Message", exc_info=True)
            return

        self.log.debug("Control received: %s", msg)

        # Set the parent message for side effects.
        self.set_parent(idents, msg)
        self._publish_status(u'busy')

        header = msg['header']
        msg_type = header['msg_type']

        handler = self.control_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r", msg_type)
        else:
            try:
                handler(self.control_stream, idents, msg)
            except Exception:
                self.log.error("Exception in control handler:", exc_info=True)

        sys.stdout.flush()
        sys.stderr.flush()
        self._publish_status(u'idle')

    def should_handle(self, stream, msg, idents):
        """Check whether a shell-channel message should be handled

        Allows subclasses to prevent handling of certain messages (e.g. aborted requests).
        """
        msg_id = msg['header']['msg_id']
        if msg_id in self.aborted:
            msg_type = msg['header']['msg_type']
            # is it safe to assume a msg_id will not be resubmitted?
            self.aborted.remove(msg_id)
            reply_type = msg_type.split('_')[0] + '_reply'
            status = {'status' : 'aborted'}
            md = {'engine' : self.ident}
            md.update(status)
            self.session.send(stream, reply_type, metadata=md,
                        content=status, parent=msg, ident=idents)
            return False
        return True

    def dispatch_shell(self, stream, msg):
        """dispatch shell requests"""
        # flush control requests first
        if self.control_stream:
            self.control_stream.flush()

        idents, msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.deserialize(msg, content=True, copy=False)
        except Exception:
            self.log.error("Invalid Message", exc_info=True)
            return

        # Set the parent message for side effects.
        self.set_parent(idents, msg)
        self._publish_status(u'busy')

        header = msg['header']
        msg_id = header['msg_id']
        msg_type = msg['header']['msg_type']

        # Print some info about this message and leave a '--->' marker, so it's
        # easier to trace visually the message chain when debugging.  Each
        # handler prints its message at the end.
        self.log.debug('\n*** MESSAGE TYPE:%s***', msg_type)
        self.log.debug('   Content: %s\n   --->\n   ', msg['content'])

        if not self.should_handle(stream, msg, idents):
            return

        handler = self.shell_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("UNKNOWN MESSAGE TYPE: %r", msg_type)
        else:
            self.log.debug("%s: %s", msg_type, msg)
            self.pre_handler_hook()
            try:
                handler(stream, idents, msg)
            except Exception:
                self.log.error("Exception in message handler:", exc_info=True)
            finally:
                self.post_handler_hook()

        sys.stdout.flush()
        sys.stderr.flush()
        self._publish_status(u'idle')

    def pre_handler_hook(self):
        """Hook to execute before calling message handler"""
        # ensure default_int_handler during handler call
        self.saved_sigint_handler = signal(SIGINT, default_int_handler)

    def post_handler_hook(self):
        """Hook to execute after calling message handler"""
        signal(SIGINT, self.saved_sigint_handler)

    def enter_eventloop(self):
        """enter eventloop"""
        self.log.info("entering eventloop %s", self.eventloop)
        for stream in self.shell_streams:
            # flush any pending replies,
            # which may be skipped by entering the eventloop
            stream.flush(zmq.POLLOUT)
        # restore default_int_handler
        signal(SIGINT, default_int_handler)
        while self.eventloop is not None:
            try:
                self.eventloop(self)
            except KeyboardInterrupt:
                # Ctrl-C shouldn't crash the kernel
                self.log.error("KeyboardInterrupt caught in kernel")
                continue
            else:
                # eventloop exited cleanly, this means we should stop (right?)
                self.eventloop = None
                break
        self.log.info("exiting eventloop")

    def start(self):
        """register dispatchers for streams"""
        if self.control_stream:
            self.control_stream.on_recv(self.dispatch_control, copy=False)

        def make_dispatcher(stream):
            def dispatcher(msg):
                return self.dispatch_shell(stream, msg)
            return dispatcher

        for s in self.shell_streams:
            s.on_recv(make_dispatcher(s), copy=False)

        # publish starting status
        self._publish_status('starting')

    def do_one_iteration(self):
        """step eventloop just once"""
        if self.control_stream:
            self.control_stream.flush()
        for stream in self.shell_streams:
            # handle at most one request per iteration
            stream.flush(zmq.POLLIN, 1)
            stream.flush(zmq.POLLOUT)

    def record_ports(self, ports):
        """Record the ports that this kernel is using.

        The creator of the Kernel instance must call this method if they
        want the :meth:`connect_request` method to return the port numbers.
        """
        self._recorded_ports = ports

    #---------------------------------------------------------------------------
    # Kernel request handlers
    #---------------------------------------------------------------------------

    def _publish_execute_input(self, code, parent, execution_count):
        """Publish the code request on the iopub stream."""

        self.session.send(self.iopub_socket, u'execute_input',
                            {u'code':code, u'execution_count': execution_count},
                            parent=parent, ident=self._topic('execute_input')
        )

    def _publish_status(self, status, parent=None):
        """send status (busy/idle) on IOPub"""
        self.session.send(self.iopub_socket,
                          u'status',
                          {u'execution_state': status},
                          parent=parent or self._parent_header,
                          ident=self._topic('status'),
                          )

    def set_parent(self, ident, parent):
        """Set the current parent_header

        Side effects (IOPub messages) and replies are associated with
        the request that caused them via the parent_header.

        The parent identity is used to route input_request messages
        on the stdin channel.
        """
        self._parent_ident = ident
        self._parent_header = parent

    def send_response(self, stream, msg_or_type, content=None, ident=None,
             buffers=None, track=False, header=None, metadata=None):
        """Send a response to the message we're currently processing.

        This accepts all the parameters of :meth:`jupyter_client.session.Session.send`
        except ``parent``.

        This relies on :meth:`set_parent` having been called for the current
        message.
        """
        return self.session.send(stream, msg_or_type, content, self._parent_header,
                                 ident, buffers, track, header, metadata)

    def init_metadata(self, parent):
        """Initialize metadata.

        Run at the beginning of execution requests.
        """
        return {
            'started': datetime.now(),
        }

    def finish_metadata(self, parent, metadata, reply_content):
        """Finish populating metadata.

        Run after completing an execution request.
        """
        return metadata

    def execute_request(self, stream, ident, parent):
        """handle an execute_request"""

        try:
            content = parent[u'content']
            code = py3compat.cast_unicode_py2(content[u'code'])
            silent = content[u'silent']
            store_history = content.get(u'store_history', not silent)
            user_expressions = content.get('user_expressions', {})
            allow_stdin = content.get('allow_stdin', False)
        except Exception:
            self.log.error("Got bad msg: ")
            self.log.error("%s", parent)
            return

        stop_on_error = content.get('stop_on_error', True)

        metadata = self.init_metadata(parent)

        # Re-broadcast our input for the benefit of listening clients, and
        # start computing output
        if not silent:
            self.execution_count += 1
            self._publish_execute_input(code, parent, self.execution_count)

        reply_content = self.do_execute(code, silent, store_history,
                                        user_expressions, allow_stdin)

        # Flush output before sending the reply.
        sys.stdout.flush()
        sys.stderr.flush()
        # FIXME: on rare occasions, the flush doesn't seem to make it to the
        # clients... This seems to mitigate the problem, but we definitely need
        # to better understand what's going on.
        if self._execute_sleep:
            time.sleep(self._execute_sleep)

        # Send the reply.
        reply_content = json_clean(reply_content)
        metadata = self.finish_metadata(parent, metadata, reply_content)

        reply_msg = self.session.send(stream, u'execute_reply',
                                      reply_content, parent, metadata=metadata,
                                      ident=ident)

        self.log.debug("%s", reply_msg)

        if not silent and reply_msg['content']['status'] == u'error' and stop_on_error:
            self._abort_queues()

    def do_execute(self, code, silent, store_history=True,
                   user_expressions=None, allow_stdin=False):
        """Execute user code. Must be overridden by subclasses.
        """
        raise NotImplementedError

    def complete_request(self, stream, ident, parent):
        content = parent['content']
        code = content['code']
        cursor_pos = content['cursor_pos']

        matches = self.do_complete(code, cursor_pos)
        matches = json_clean(matches)
        completion_msg = self.session.send(stream, 'complete_reply',
                                           matches, parent, ident)
        self.log.debug("%s", completion_msg)

    def do_complete(self, code, cursor_pos):
        """Override in subclasses to find completions.
        """
        return {'matches' : [],
                'cursor_end' : cursor_pos,
                'cursor_start' : cursor_pos,
                'metadata' : {},
                'status' : 'ok'}

    def inspect_request(self, stream, ident, parent):
        content = parent['content']

        reply_content = self.do_inspect(content['code'], content['cursor_pos'],
                                        content.get('detail_level', 0))
        # Before we send this object over, we scrub it for JSON usage
        reply_content = json_clean(reply_content)
        msg = self.session.send(stream, 'inspect_reply',
                                reply_content, parent, ident)
        self.log.debug("%s", msg)

    def do_inspect(self, code, cursor_pos, detail_level=0):
        """Override in subclasses to allow introspection.
        """
        return {'status': 'ok', 'data': {}, 'metadata': {}, 'found': False}

    def history_request(self, stream, ident, parent):
        content = parent['content']

        reply_content = self.do_history(**content)

        reply_content = json_clean(reply_content)
        msg = self.session.send(stream, 'history_reply',
                                reply_content, parent, ident)
        self.log.debug("%s", msg)

    def do_history(self, hist_access_type, output, raw, session=None, start=None,
                   stop=None, n=None, pattern=None, unique=False):
        """Override in subclasses to access history.
        """
        return {'history': []}

    def connect_request(self, stream, ident, parent):
        if self._recorded_ports is not None:
            content = self._recorded_ports.copy()
        else:
            content = {}
        msg = self.session.send(stream, 'connect_reply',
                                content, parent, ident)
        self.log.debug("%s", msg)

    @property
    def kernel_info(self):
        return {
            'protocol_version': kernel_protocol_version,
            'implementation': self.implementation,
            'implementation_version': self.implementation_version,
            'language_info': self.language_info,
            'banner': self.banner,
            'help_links': self.help_links,
        }

    def kernel_info_request(self, stream, ident, parent):
        msg = self.session.send(stream, 'kernel_info_reply',
                                self.kernel_info, parent, ident)
        self.log.debug("%s", msg)

    def comm_info_request(self, stream, ident, parent):
        content = parent['content']
        target_name = content.get('target_name', None)

        # Should this be moved to ipkernel?
        if hasattr(self, 'comm_manager'):
            comms = {
                k: dict(target_name=v.target_name)
                for (k, v) in self.comm_manager.comms.items()
                if v.target_name == target_name or target_name is None
            }
        else:
            comms = {}
        reply_content = dict(comms=comms)
        msg = self.session.send(stream, 'comm_info_reply',
                                reply_content, parent, ident)
        self.log.debug("%s", msg)

    def shutdown_request(self, stream, ident, parent):
        content = self.do_shutdown(parent['content']['restart'])
        self.session.send(stream, u'shutdown_reply', content, parent, ident=ident)
        # same content, but different msg_id for broadcasting on IOPub
        self._shutdown_message = self.session.msg(u'shutdown_reply',
                                                  content, parent
        )

        self._at_shutdown()
        # call sys.exit after a short delay
        loop = ioloop.IOLoop.instance()
        loop.add_timeout(time.time()+0.1, loop.stop)

    def do_shutdown(self, restart):
        """Override in subclasses to do things when the frontend shuts down the
        kernel.
        """
        return {'status': 'ok', 'restart': restart}

    def is_complete_request(self, stream, ident, parent):
        content = parent['content']
        code = content['code']

        reply_content = self.do_is_complete(code)
        reply_content = json_clean(reply_content)
        reply_msg = self.session.send(stream, 'is_complete_reply',
                                           reply_content, parent, ident)
        self.log.debug("%s", reply_msg)

    def do_is_complete(self, code):
        """Override in subclasses to check whether the code is complete.
        """
        return {'status': 'unknown'}

    #---------------------------------------------------------------------------
    # Engine methods (DEPRECATED)
    #---------------------------------------------------------------------------

    def apply_request(self, stream, ident, parent):
        self.log.warn("""apply_request is deprecated in kernel_base, moving to ipyparallel.""")
        try:
            content = parent[u'content']
            bufs = parent[u'buffers']
            msg_id = parent['header']['msg_id']
        except Exception:
            self.log.error("Got bad msg: %s", parent, exc_info=True)
            return

        md = self.init_metadata(parent)

        reply_content, result_buf = self.do_apply(content, bufs, msg_id, md)

        # flush i/o
        sys.stdout.flush()
        sys.stderr.flush()

        md = self.finish_metadata(parent, md, reply_content)

        self.session.send(stream, u'apply_reply', reply_content,
                    parent=parent, ident=ident, buffers=result_buf, metadata=md)

    def do_apply(self, content, bufs, msg_id, reply_metadata):
        """DEPRECATED"""
        raise NotImplementedError

    #---------------------------------------------------------------------------
    # Control messages (DEPRECATED)
    #---------------------------------------------------------------------------

    def abort_request(self, stream, ident, parent):
        """abort a specific msg by id"""
        self.log.warn("abort_request is deprecated in kernel_base. It os only part of IPython parallel")
        msg_ids = parent['content'].get('msg_ids', None)
        if isinstance(msg_ids, string_types):
            msg_ids = [msg_ids]
        if not msg_ids:
            self._abort_queues()
        else:
            for mid in msg_ids:
                self.aborted.add(str(mid))

        content = dict(status='ok')
        reply_msg = self.session.send(stream, 'abort_reply', content=content,
                parent=parent, ident=ident)
        self.log.debug("%s", reply_msg)

    def clear_request(self, stream, idents, parent):
        """Clear our namespace."""
        self.log.warn("clear_request is deprecated in kernel_base. It os only part of IPython parallel")
        content = self.do_clear()
        self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
                content=content)

    def do_clear(self):
        """DEPRECATED"""
        raise NotImplementedError

    #---------------------------------------------------------------------------
    # Protected interface
    #---------------------------------------------------------------------------

    def _topic(self, topic):
        """prefixed topic for IOPub messages"""
        base = "kernel.%s" % self.ident

        return py3compat.cast_bytes("%s.%s" % (base, topic))

    def _abort_queues(self):
        for stream in self.shell_streams:
            if stream:
                self._abort_queue(stream)

    def _abort_queue(self, stream):
        poller = zmq.Poller()
        poller.register(stream.socket, zmq.POLLIN)
        while True:
            idents, msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
            if msg is None:
                return

            self.log.info("Aborting:")
            self.log.info("%s", msg)
            msg_type = msg['header']['msg_type']
            reply_type = msg_type.split('_')[0] + '_reply'

            status = {'status' : 'aborted'}
            md = {'engine' : self.ident}
            md.update(status)
            reply_msg = self.session.send(stream, reply_type, metadata=md,
                        content=status, parent=msg, ident=idents)
            self.log.debug("%s", reply_msg)
            # We need to wait a bit for requests to come in. This can probably
            # be set shorter for true asynchronous clients.
            poller.poll(50)

    def _no_raw_input(self):
        """Raise StdinNotImplementedError if active frontend doesn't support
        stdin."""
        raise StdinNotImplementedError("raw_input was called, but this "
                                       "frontend does not support stdin.")

    def getpass(self, prompt=''):
        """Forward getpass to frontends

        Raises
        ------
        StdinNotImplementedError if active frontend doesn't support stdin.
        """
        if not self._allow_stdin:
            raise StdinNotImplementedError(
                "getpass was called, but this frontend does not support input requests."
            )
        return self._input_request(prompt,
            self._parent_ident,
            self._parent_header,
            password=True,
        )

    def raw_input(self, prompt=''):
        """Forward raw_input to frontends

        Raises
        ------
        StdinNotImplementedError if active frontend doesn't support stdin.
        """
        if not self._allow_stdin:
            raise StdinNotImplementedError(
                "raw_input was called, but this frontend does not support input requests."
            )
        return self._input_request(prompt,
            self._parent_ident,
            self._parent_header,
            password=False,
        )

    def _input_request(self, prompt, ident, parent, password=False):
        # Flush output before making the request.
        sys.stderr.flush()
        sys.stdout.flush()
        # flush the stdin socket, to purge stale replies
        while True:
            try:
                self.stdin_socket.recv_multipart(zmq.NOBLOCK)
            except zmq.ZMQError as e:
                if e.errno == zmq.EAGAIN:
                    break
                else:
                    raise

        # Send the input request.
        content = json_clean(dict(prompt=prompt, password=password))
        self.session.send(self.stdin_socket, u'input_request', content, parent,
                          ident=ident)

        # Await a response.
        while True:
            try:
                ident, reply = self.session.recv(self.stdin_socket, 0)
            except Exception:
                self.log.warn("Invalid Message:", exc_info=True)
            except KeyboardInterrupt:
                # re-raise KeyboardInterrupt, to truncate traceback
                raise KeyboardInterrupt
            else:
                break
        try:
            value = py3compat.unicode_to_str(reply['content']['value'])
        except Exception:
            self.log.error("Bad input_reply: %s", parent)
            value = ''
        if value == '\x04':
            # EOF
            raise EOFError
        return value

    def _at_shutdown(self):
        """Actions taken at shutdown by the kernel, called by python's atexit.
        """
        # io.rprint("Kernel at_shutdown") # dbg
        if self._shutdown_message is not None:
            self.session.send(self.iopub_socket, self._shutdown_message, ident=self._topic('shutdown'))
            self.log.debug("%s", self._shutdown_message)
        for stream in self.shell_streams:
            stream.flush(zmq.POLLOUT)
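
# --- Illustrative sketch (added; not part of the original example) ---
# The stdin round-trip in _input_request follows the Jupyter messaging
# protocol: the kernel sends an ``input_request`` and waits for an
# ``input_reply`` whose content carries the typed value.  A minimal sketch
# of the two payloads involved, assuming standard message content:
input_request_content = {"prompt": "Password: ", "password": True}
input_reply_content = {"value": "secret"}  # a value of '\x04' signals EOF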
Exemplo n.º 10
0
class HubAuth(SingletonConfigurable):
    """A class for authenticating with JupyterHub

    This can be used by any application.

    Use this base class only for direct, token-authenticated applications
    (web APIs).
    For applications that support direct visits from browsers,
    use HubOAuth to enable OAuth redirect-based authentication.


    If using tornado, use it via the :class:`HubAuthenticated` mixin.
    If using it manually, call ``.user_for_token(token_value)`` to
    identify the user owning a given token.

    The following config must be set:

    - api_token (token for authenticating with the JupyterHub API),
      fetched from the JUPYTERHUB_API_TOKEN env variable by default.

    The following config may be set:

    - api_url: the base URL of the Hub's internal API,
      fetched from JUPYTERHUB_API_URL by default.
    - cache_max_age: the number of seconds to cache
      responses from the Hub (cookie_cache_max_age is deprecated).
    - login_url (the *public* ``/hub/login`` URL of the Hub).
    """

    hub_host = Unicode(
        '',
        help="""The public host of JupyterHub

        Only used if JupyterHub is spreading servers across subdomains.
        """,
    ).tag(config=True)

    @default('hub_host')
    def _default_hub_host(self):
        return os.getenv('JUPYTERHUB_HOST', '')

    base_url = Unicode(
        os.getenv('JUPYTERHUB_SERVICE_PREFIX') or '/',
        help="""The base URL prefix of this application

        e.g. /services/service-name/ or /user/name/

        Default: get from JUPYTERHUB_SERVICE_PREFIX
        """,
    ).tag(config=True)

    @validate('base_url')
    def _add_slash(self, proposal):
        """Ensure base_url starts and ends with /"""
        value = proposal['value']
        if not value.startswith('/'):
            value = '/' + value
        if not value.endswith('/'):
            value = value + '/'
        return value
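
    # e.g. a proposed base_url of 'services/my-service' is normalized to
    # '/services/my-service/' by the validator above.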

    # where is the hub
    api_url = Unicode(
        os.getenv('JUPYTERHUB_API_URL') or 'http://127.0.0.1:8081/hub/api',
        help="""The base API URL of the Hub.

        Typically `http://hub-ip:hub-port/hub/api`
        Default: $JUPYTERHUB_API_URL
        """,
    ).tag(config=True)

    @default('api_url')
    def _api_url(self):
        env_url = os.getenv('JUPYTERHUB_API_URL')
        if env_url:
            return env_url
        else:
            return 'http://127.0.0.1:8081' + url_path_join(
                self.hub_prefix, 'api')

    api_token = Unicode(
        os.getenv('JUPYTERHUB_API_TOKEN', ''),
        help="""API key for accessing Hub API.

        Default: $JUPYTERHUB_API_TOKEN

        Loaded from services configuration in jupyterhub_config.
        Will be auto-generated for hub-managed services.
        """,
    ).tag(config=True)

    hub_prefix = Unicode(
        '/hub/',
        help="""The URL prefix for the Hub itself.

        Typically /hub/
        Default: $JUPYTERHUB_BASE_URL
        """,
    ).tag(config=True)

    @default('hub_prefix')
    def _default_hub_prefix(self):
        return url_path_join(os.getenv('JUPYTERHUB_BASE_URL') or '/',
                             'hub') + '/'

    login_url = Unicode(
        '/hub/login',
        help="""The login URL to use

        Typically /hub/login
        """,
    ).tag(config=True)

    @default('login_url')
    def _default_login_url(self):
        return self.hub_host + url_path_join(self.hub_prefix, 'login')

    keyfile = Unicode(
        os.getenv('JUPYTERHUB_SSL_KEYFILE', ''),
        help="""The ssl key to use for requests

        Use with certfile
        """,
    ).tag(config=True)

    certfile = Unicode(
        os.getenv('JUPYTERHUB_SSL_CERTFILE', ''),
        help="""The ssl cert to use for requests

        Use with keyfile
        """,
    ).tag(config=True)

    client_ca = Unicode(
        os.getenv('JUPYTERHUB_SSL_CLIENT_CA', ''),
        help="""The ssl certificate authority to use to verify requests

        Use with keyfile and certfile
        """,
    ).tag(config=True)

    cookie_options = Dict(
        help="""Additional options to pass when setting cookies.

        Can include things like `expires_days=None` for session-expiry
        or `secure=True` if served on HTTPS and default HTTPS discovery fails
        (e.g. behind some proxies).
        """).tag(config=True)

    @default('cookie_options')
    def _default_cookie_options(self):
        # load default from env
        options_env = os.environ.get('JUPYTERHUB_COOKIE_OPTIONS')
        if options_env:
            return json.loads(options_env)
        else:
            return {}
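
    # e.g. JUPYTERHUB_COOKIE_OPTIONS='{"expires_days": null, "secure": true}'
    # (a JSON object, per the json.loads above; these keys are illustrative)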

    cookie_cache_max_age = Integer(help="DEPRECATED. Use cache_max_age")

    @observe('cookie_cache_max_age')
    def _deprecated_cookie_cache(self, change):
        warnings.warn(
            "cookie_cache_max_age is deprecated in JupyterHub 0.8. Use cache_max_age instead."
        )
        self.cache_max_age = change.new

    cache_max_age = Integer(
        300,
        help=
        """The maximum time (in seconds) to cache the Hub's responses for authentication.

        A larger value reduces load on the Hub and occasional response lag.
        A smaller value reduces propagation time of changes on the Hub (rare).

        Default: 300 (five minutes)
        """,
    ).tag(config=True)
    cache = Instance(_ExpiringDict, allow_none=False)

    @default('cache')
    def _default_cache(self):
        return _ExpiringDict(self.cache_max_age)

    oauth_scopes = Set(
        Unicode(),
        help="""OAuth scopes to use for allowing access.

        Get from $JUPYTERHUB_OAUTH_SCOPES by default.
        """,
    ).tag(config=True)

    @default('oauth_scopes')
    def _default_scopes(self):
        env_scopes = os.getenv('JUPYTERHUB_OAUTH_SCOPES')
        if env_scopes:
            return set(json.loads(env_scopes))
        service_name = os.getenv("JUPYTERHUB_SERVICE_NAME")
        if service_name:
            return {f'access:services!service={service_name}'}
        return set()
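
    # e.g. JUPYTERHUB_OAUTH_SCOPES='["access:services!service=my-service"]'
    # (a JSON list, per the json.loads above; the service name is illustrative)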

    def _check_hub_authorization(self,
                                 url,
                                 api_token,
                                 cache_key=None,
                                 use_cache=True):
        """Identify a user with the Hub

        Args:
            url (str): The API URL to check the Hub for authorization
                       (e.g. http://127.0.0.1:8081/hub/api/user)
            cache_key (str): The key for checking the cache
            use_cache (bool): Specify use_cache=False to skip cached values (default: True)

        Returns:
            user_model (dict): The user model if a user is identified; None if authentication fails.

        Raises an HTTPError if the request failed for a reason other than no such user.
        """
        if use_cache:
            if cache_key is None:
                raise ValueError("cache_key is required when using cache")
            # check for a cached reply, so we don't check with the Hub if we don't have to
            try:
                return self.cache[cache_key]
            except KeyError:
                app_log.debug("HubAuth cache miss: %s", cache_key)

        data = self._api_request(
            'GET',
            url,
            headers={"Authorization": "token " + api_token},
            allow_403=True,
        )
        if data is None:
            app_log.warning("No Hub user identified for request")
        else:
            app_log.debug("Received request from Hub user %s", data)
        if use_cache:
            # cache result
            self.cache[cache_key] = data
        return data

    def _api_request(self, method, url, **kwargs):
        """Make an API request"""
        allow_403 = kwargs.pop('allow_403', False)
        headers = kwargs.setdefault('headers', {})
        headers.setdefault('Authorization', 'token %s' % self.api_token)
        if "cert" not in kwargs and self.certfile and self.keyfile:
            kwargs["cert"] = (self.certfile, self.keyfile)
            if self.client_ca:
                kwargs["verify"] = self.client_ca
        try:
            r = requests.request(method, url, **kwargs)
        except requests.ConnectionError as e:
            app_log.error("Error connecting to %s: %s", self.api_url, e)
            msg = "Failed to connect to Hub API at %r." % self.api_url
            msg += ("  Is the Hub accessible at this URL (from host: %s)?" %
                    socket.gethostname())
            if '127.0.0.1' in self.api_url:
                msg += (
                    "  Make sure to set c.JupyterHub.hub_ip to an IP accessible to"
                    +
                    " single-user servers if the servers are not on the same host as the Hub."
                )
            raise HTTPError(500, msg)

        data = None
        if r.status_code == 403 and allow_403:
            pass
        elif r.status_code == 403:
            app_log.error(
                "I don't have permission to check authorization with JupyterHub, my auth token may have expired: [%i] %s",
                r.status_code,
                r.reason,
            )
            app_log.error(r.text)
            raise HTTPError(
                500,
                "Permission failure checking authorization, I may need a new token"
            )
        elif r.status_code >= 500:
            app_log.error(
                "Upstream failure verifying auth token: [%i] %s",
                r.status_code,
                r.reason,
            )
            app_log.error(r.text)
            raise HTTPError(
                502, "Failed to check authorization (upstream problem)")
        elif r.status_code >= 400:
            app_log.warning("Failed to check authorization: [%i] %s",
                            r.status_code, r.reason)
            app_log.warning(r.text)
            msg = "Failed to check authorization"
            # pass on error from oauth failure
            try:
                response = r.json()
                # prefer more specific 'error_description', fallback to 'error'
                description = response.get(
                    "error_description", response.get("error",
                                                      "Unknown error"))
            except Exception:
                pass
            else:
                msg += ": " + description
            raise HTTPError(500, msg)
        else:
            data = r.json()

        return data

    def user_for_cookie(self, encrypted_cookie, use_cache=True, session_id=''):
        """Deprecated and removed. Use HubOAuth to authenticate browsers."""
        raise RuntimeError(
            "Identifying users by shared cookie is removed in JupyterHub 2.0. Use OAuth tokens."
        )

    def user_for_token(self, token, use_cache=True, session_id=''):
        """Ask the Hub to identify the user for a given token.

        Args:
            token (str): the token
            use_cache (bool): Specify use_cache=False to skip cached values (default: True)

        Returns:
            user_model (dict): The user model if a user is identified; None if authentication fails.

            The 'name' field contains the user's name.
        """
        return self._check_hub_authorization(
            url=url_path_join(
                self.api_url,
                "user",
            ),
            api_token=token,
            cache_key='token:{}:{}'.format(
                session_id,
                hashlib.sha256(token.encode("utf8", "replace")).hexdigest(),
            ),
            use_cache=use_cache,
        )

    auth_header_name = 'Authorization'
    auth_header_pat = re.compile(r'(?:token|bearer)\s+(.+)', re.IGNORECASE)

    def get_token(self, handler):
        """Get the user token from a request

        - in URL parameters: ?token=<token>
        - in header: Authorization: token <token>
        """

        user_token = handler.get_argument('token', '')
        if not user_token:
            # get it from Authorization header
            m = self.auth_header_pat.match(
                handler.request.headers.get(self.auth_header_name, ''))
            if m:
                user_token = m.group(1)
        return user_token
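
    # Illustrative header forms matched by auth_header_pat above
    # (the token value is hypothetical):
    #   Authorization: token abc123
    #   Authorization: Bearer abc123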

    def _get_user_cookie(self, handler):
        """Get the user model from a cookie"""
        # overridden in HubOAuth to store the access token after oauth
        return None

    def get_session_id(self, handler):
        """Get the jupyterhub session id

        from the jupyterhub-session-id cookie.
        """
        return handler.get_cookie('jupyterhub-session-id', '')

    def get_user(self, handler):
        """Get the Hub user for a given tornado handler.

        Checks cookie with the Hub to identify the current user.

        Args:
            handler (tornado.web.RequestHandler): the current request handler

        Returns:
            user_model (dict): The user model if a user is identified; None if authentication fails.

            The 'name' field contains the user's name.
        """

        # only allow this to be called once per handler
        # avoids issues if an error is raised,
        # since this may be called again when trying to render the error page
        if hasattr(handler, '_cached_hub_user'):
            return handler._cached_hub_user

        handler._cached_hub_user = user_model = None
        session_id = self.get_session_id(handler)

        # check token first
        token = self.get_token(handler)
        if token:
            user_model = self.user_for_token(token, session_id=session_id)
            if user_model:
                handler._token_authenticated = True

        # no token, check cookie
        if user_model is None:
            user_model = self._get_user_cookie(handler)

        # cache result
        handler._cached_hub_user = user_model
        if not user_model:
            app_log.debug("No user identified")
        return user_model

    def check_scopes(self, required_scopes, user):
        """Check whether the user has required scope(s)"""
        return check_scopes(required_scopes, set(user["scopes"]))
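
# --- Usage sketch (added; not part of the original example) ---
# A minimal, non-tornado way to identify the owner of a token, assuming
# JUPYTERHUB_API_TOKEN and JUPYTERHUB_API_URL are set as described in the
# class docstring.  The token value below is hypothetical.
if __name__ == "__main__":
    auth = HubAuth.instance()
    user_model = auth.user_for_token("0123456789abcdef")
    if user_model is None:
        print("token not recognized by the Hub")
    else:
        print("authenticated as", user_model["name"])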
Exemplo n.º 11
0
class BaseIPythonApplication(Application):

    name = Unicode(u'ipython')
    description = Unicode(u'IPython: an enhanced interactive Python shell.')
    version = Unicode(release.version)

    aliases = Dict(base_aliases)
    flags = Dict(base_flags)
    classes = List([ProfileDir])

    # enable `load_subconfig('cfg.py', profile='name')`
    python_config_loader_class = ProfileAwareConfigLoader

    # Track whether the config_file has changed,
    # because some logic happens only if we aren't using the default.
    config_file_specified = Set()

    config_file_name = Unicode()

    def _config_file_name_default(self):
        return self.name.replace('-', '_') + u'_config.py'

    def _config_file_name_changed(self, name, old, new):
        if new != old:
            self.config_file_specified.add(new)

    # The directory that contains IPython's builtin profiles.
    builtin_profile_dir = Unicode(
        os.path.join(get_ipython_package_dir(), u'config', u'profile',
                     u'default'))

    config_file_paths = List(Unicode())

    def _config_file_paths_default(self):
        return [py3compat.getcwd()]

    extra_config_file = Unicode(config=True,
                                help="""Path to an extra config file to load.
    
    If specified, load this config file in addition to any other IPython config.
    """)

    def _extra_config_file_changed(self, name, old, new):
        try:
            self.config_files.remove(old)
        except ValueError:
            pass
        self.config_file_specified.add(new)
        self.config_files.append(new)

    profile = Unicode(u'default',
                      config=True,
                      help="""The IPython profile to use.""")

    def _profile_changed(self, name, old, new):
        self.builtin_profile_dir = os.path.join(get_ipython_package_dir(),
                                                u'config', u'profile', new)

    ipython_dir = Unicode(config=True,
                          help="""
        The name of the IPython directory. This directory is used for logging
        configuration (through profiles), history storage, etc. The default
        is usually $HOME/.ipython. This option can also be specified through
        the environment variable IPYTHONDIR.
        """)

    def _ipython_dir_default(self):
        d = get_ipython_dir()
        self._ipython_dir_changed('ipython_dir', d, d)
        return d

    _in_init_profile_dir = False
    profile_dir = Instance(ProfileDir, allow_none=True)

    def _profile_dir_default(self):
        # avoid recursion
        if self._in_init_profile_dir:
            return
        # profile_dir requested early, force initialization
        self.init_profile_dir()
        return self.profile_dir

    overwrite = Bool(
        False,
        config=True,
        help="""Whether to overwrite existing config files when copying""")
    auto_create = Bool(
        False,
        config=True,
        help="""Whether to create profile dir if it doesn't exist""")

    config_files = List(Unicode())

    def _config_files_default(self):
        return [self.config_file_name]

    copy_config_files = Bool(
        False,
        config=True,
        help="""Whether to install the default config files into the profile dir.
        If a new profile is being created, and IPython contains config files for that
        profile, then they will be staged into the new directory.  Otherwise,
        default config files will be automatically generated.
        """)

    verbose_crash = Bool(
        False,
        config=True,
        help=
        """Create a massive crash report when IPython encounters what may be an
        internal error.  The default is to append a short message to the
        usual traceback""")

    # The class to use as the crash handler.
    crash_handler_class = Type(crashhandler.CrashHandler)

    @catch_config_error
    def __init__(self, **kwargs):
        super(BaseIPythonApplication, self).__init__(**kwargs)
        # ensure current working directory exists
        try:
            py3compat.getcwd()
        except:
            # exit if cwd doesn't exist
            self.log.error("Current working directory doesn't exist.")
            self.exit(1)

    #-------------------------------------------------------------------------
    # Various stages of Application creation
    #-------------------------------------------------------------------------

    deprecated_subcommands = {}

    def initialize_subcommand(self, subc, argv=None):
        if subc in self.deprecated_subcommands:
            import time
            self.log.warning(
                "Subcommand `ipython {sub}` is deprecated and will be removed "
                "in future versions.".format(sub=subc))
            self.log.warning(
                "You likely want to use `jupyter {sub}`... continue "
                "in 5 sec. Press Ctrl-C to quit now.".format(sub=subc))
            try:
                time.sleep(5)
            except KeyboardInterrupt:
                sys.exit(1)
        return super(BaseIPythonApplication,
                     self).initialize_subcommand(subc, argv)

    def init_crash_handler(self):
        """Create a crash handler, typically setting sys.excepthook to it."""
        self.crash_handler = self.crash_handler_class(self)
        sys.excepthook = self.excepthook

        def unset_crashhandler():
            sys.excepthook = sys.__excepthook__

        atexit.register(unset_crashhandler)

    def excepthook(self, etype, evalue, tb):
        """this is sys.excepthook after init_crashhandler
        
        set self.verbose_crash=True to use our full crashhandler, instead of
        a regular traceback with a short message (crash_handler_lite)
        """

        if self.verbose_crash:
            return self.crash_handler(etype, evalue, tb)
        else:
            return crashhandler.crash_handler_lite(etype, evalue, tb)

    def _ipython_dir_changed(self, name, old, new):
        if old is not Undefined:
            str_old = py3compat.cast_bytes_py2(os.path.abspath(old),
                                               sys.getfilesystemencoding())
            if str_old in sys.path:
                sys.path.remove(str_old)
        str_path = py3compat.cast_bytes_py2(os.path.abspath(new),
                                            sys.getfilesystemencoding())
        sys.path.append(str_path)
        ensure_dir_exists(new)
        readme = os.path.join(new, 'README')
        readme_src = os.path.join(get_ipython_package_dir(), u'config',
                                  u'profile', 'README')
        if not os.path.exists(readme) and os.path.exists(readme_src):
            shutil.copy(readme_src, readme)
        for d in ('extensions', 'nbextensions'):
            path = os.path.join(new, d)
            try:
                ensure_dir_exists(path)
            except OSError as e:
                # this will not be EEXIST
                self.log.error("couldn't create path %s: %s", path, e)
        self.log.debug("IPYTHONDIR set to: %s" % new)

    def load_config_file(self, suppress_errors=True):
        """Load the config file.

        By default, errors in loading config are handled, and a warning
        printed on screen. For testing, the suppress_errors option is set
        to False, so errors will make tests fail.
        """
        self.log.debug("Searching path %s for config files",
                       self.config_file_paths)
        base_config = 'ipython_config.py'
        self.log.debug("Attempting to load config file: %s" % base_config)
        try:
            Application.load_config_file(self,
                                         base_config,
                                         path=self.config_file_paths)
        except ConfigFileNotFound:
            # ignore errors loading parent
            self.log.debug("Config file %s not found", base_config)

        for config_file_name in self.config_files:
            if not config_file_name or config_file_name == base_config:
                continue
            self.log.debug("Attempting to load config file: %s" %
                           self.config_file_name)
            try:
                Application.load_config_file(self,
                                             config_file_name,
                                             path=self.config_file_paths)
            except ConfigFileNotFound:
                # Only warn if the default config file was NOT being used.
                if config_file_name in self.config_file_specified:
                    msg = self.log.warning
                else:
                    msg = self.log.debug
                msg("Config file not found, skipping: %s", config_file_name)
            except Exception:
                # For testing purposes.
                if not suppress_errors:
                    raise
                self.log.warning("Error loading config file: %s" %
                                 self.config_file_name,
                                 exc_info=True)

    def init_profile_dir(self):
        """initialize the profile dir"""
        self._in_init_profile_dir = True
        if self.profile_dir is not None:
            # already ran
            return
        if 'ProfileDir.location' not in self.config:
            # location not specified, find by profile name
            try:
                p = ProfileDir.find_profile_dir_by_name(
                    self.ipython_dir, self.profile, self.config)
            except ProfileDirError:
                # not found, maybe create it (always create default profile)
                if self.auto_create or self.profile == 'default':
                    try:
                        p = ProfileDir.create_profile_dir_by_name(
                            self.ipython_dir, self.profile, self.config)
                    except ProfileDirError:
                        self.log.fatal("Could not create profile: %r" %
                                       self.profile)
                        self.exit(1)
                    else:
                        self.log.info("Created profile dir: %r" % p.location)
                else:
                    self.log.fatal("Profile %r not found." % self.profile)
                    self.exit(1)
            else:
                self.log.debug("Using existing profile dir: %r" % p.location)
        else:
            location = self.config.ProfileDir.location
            # location is fully specified
            try:
                p = ProfileDir.find_profile_dir(location, self.config)
            except ProfileDirError:
                # not found, maybe create it
                if self.auto_create:
                    try:
                        p = ProfileDir.create_profile_dir(
                            location, self.config)
                    except ProfileDirError:
                        self.log.fatal(
                            "Could not create profile directory: %r" %
                            location)
                        self.exit(1)
                    else:
                        self.log.debug("Creating new profile dir: %r" %
                                       location)
                else:
                    self.log.fatal("Profile directory %r not found." %
                                   location)
                    self.exit(1)
            else:
                self.log.info("Using existing profile dir: %r" % location)
            # if profile_dir is specified explicitly, set profile name
            dir_name = os.path.basename(p.location)
            if dir_name.startswith('profile_'):
                self.profile = dir_name[8:]

        self.profile_dir = p
        self.config_file_paths.append(p.location)
        self._in_init_profile_dir = False

    def init_config_files(self):
        """[optionally] copy default config files into profile dir."""
        self.config_file_paths.extend(SYSTEM_CONFIG_DIRS)
        # copy config files
        path = self.builtin_profile_dir
        if self.copy_config_files:
            src = self.profile

            cfg = self.config_file_name
            if path and os.path.exists(os.path.join(path, cfg)):
                self.log.warning(
                    "Staging %r from %s into %r [overwrite=%s]" %
                    (cfg, src, self.profile_dir.location, self.overwrite))
                self.profile_dir.copy_config_file(cfg,
                                                  path=path,
                                                  overwrite=self.overwrite)
            else:
                self.stage_default_config_file()
        else:
            # Still stage *bundled* config files, but not generated ones
            # This is necessary for `ipython profile=sympy` to load the profile
            # on the first go
            files = glob.glob(os.path.join(path, '*.py'))
            for fullpath in files:
                cfg = os.path.basename(fullpath)
                if self.profile_dir.copy_config_file(cfg,
                                                     path=path,
                                                     overwrite=False):
                    # file was copied
                    self.log.warning(
                        "Staging bundled %s from %s into %r" %
                        (cfg, self.profile, self.profile_dir.location))

    def stage_default_config_file(self):
        """auto generate default config file, and stage it into the profile."""
        s = self.generate_config_file()
        fname = os.path.join(self.profile_dir.location, self.config_file_name)
        if self.overwrite or not os.path.exists(fname):
            self.log.warning("Generating default config file: %r" % (fname))
            with open(fname, 'w') as f:
                f.write(s)

    @catch_config_error
    def initialize(self, argv=None):
        # don't hook up crash handler before parsing command-line
        self.parse_command_line(argv)
        self.init_crash_handler()
        if self.subapp is not None:
            # stop here if subapp is taking over
            return
        cl_config = self.config
        self.init_profile_dir()
        self.init_config_files()
        self.load_config_file()
        # enforce cl-opts override configfile opts:
        self.update_config(cl_config)
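
# --- Usage sketch (added; not part of the original example) ---
# BaseIPythonApplication is designed to be subclassed; a minimal subclass
# only needs a name, and inherits profile and config-file handling.  The
# class and name below are hypothetical.
class MyIPythonTool(BaseIPythonApplication):
    name = Unicode(u'my-ipython-tool')

# app = MyIPythonTool()
# app.initialize(argv=[])   # parse args, set up profile dir, load config
# app.config_file_name      # -> u'my_ipython_tool_config.py'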
Exemplo n.º 12
0
class LocalProcessSpawner(Spawner):
    """
    A Spawner that uses `subprocess.Popen` to start single-user servers as local processes.

    Requires local UNIX users matching the authenticated users to exist.
    Does not work on Windows.

    This is the default spawner for JupyterHub.

    Note: This spawner does not implement CPU / memory guarantees and limits.
    """

    interrupt_timeout = Integer(10,
                                help="""
        Seconds to wait for single-user server process to halt after SIGINT.

        If the process has not exited cleanly after this many seconds, a SIGTERM is sent.
        """).tag(config=True)

    term_timeout = Integer(5,
                           help="""
        Seconds to wait for single-user server process to halt after SIGTERM.

        If the process does not exit cleanly after this many seconds of SIGTERM, a SIGKILL is sent.
        """).tag(config=True)

    kill_timeout = Integer(5,
                           help="""
        Seconds to wait for process to halt after SIGKILL before giving up.

        If the process does not exit cleanly after this many seconds of SIGKILL, it becomes a zombie
        process. The hub process will log a warning and then give up.
        """).tag(config=True)

    popen_kwargs = Dict(help="""Extra keyword arguments to pass to Popen

        when spawning single-user servers.

        For example::

            popen_kwargs = dict(shell=True)

        """).tag(config=True)
    shell_cmd = Command(minlen=0,
                        help="""Specify a shell command to launch.

        The single-user command will be appended to this list,
        so it should end with `-c` (for bash) or equivalent.

        For example::

            c.LocalProcessSpawner.shell_cmd = ['bash', '-l', '-c']

        to launch with a bash login shell, which would set up the user's own complete environment.

        .. warning::

            Using shell_cmd gives users control over PATH, etc.,
            which could change what the jupyterhub-singleuser launch command does.
            Only use this for trusted users.
        """).tag(config=True)

    proc = Instance(Popen,
                    allow_none=True,
                    help="""
        The process representing the single-user server process spawned for the current user.

        Is None if no process has been spawned yet.
        """)
    pid = Integer(0,
                  help="""
        The process id (pid) of the single-user server process spawned for the current user.
        """)

    def make_preexec_fn(self, name):
        """
        Return a function that can be used to set the user id of the spawned process to that of the user with name `name`.

        This function can be safely passed to `preexec_fn` of `Popen`
        """
        return set_user_setuid(name)

    def load_state(self, state):
        """Restore state about spawned single-user server after a hub restart.

        Local processes only need the process id.
        """
        super(LocalProcessSpawner, self).load_state(state)
        if 'pid' in state:
            self.pid = state['pid']

    def get_state(self):
        """Save state that is needed to restore this spawner instance after a hub restore.

        Local processes only need the process id.
        """
        state = super(LocalProcessSpawner, self).get_state()
        if self.pid:
            state['pid'] = self.pid
        return state

    def clear_state(self):
        """Clear stored state about this spawner (pid)"""
        super(LocalProcessSpawner, self).clear_state()
        self.pid = 0

    def user_env(self, env):
        """Augment environment of spawned process with user specific env variables."""
        import pwd
        unixname = self.user.name.replace('@', '')
        unixname = unixname.replace('.', '')
        #env['USER'] = self.user.name
        env['USER'] = unixname
        #home = pwd.getpwnam(self.user.name).pw_dir
        home = pwd.getpwnam(unixname).pw_dir
        #shell = pwd.getpwnam(self.user.name).pw_shell
        shell = pwd.getpwnam(unixname).pw_shell
        # These will be empty if undefined,
        # in which case don't set the env:
        if home:
            env['HOME'] = home
        if shell:
            env['SHELL'] = shell
        return env

    def get_env(self):
        """Get the complete set of environment variables to be set in the spawned process."""
        env = super().get_env()
        env = self.user_env(env)
        return env

    async def start(self):
        """Start the single-user server."""
        self.port = random_port()
        cmd = []
        env = self.get_env()

        cmd.extend(self.cmd)
        cmd.extend(self.get_args())

        if self.shell_cmd:
            # using shell_cmd (e.g. bash -c),
            # add our cmd list as the last (single) argument:
            cmd = self.shell_cmd + [' '.join(pipes.quote(s) for s in cmd)]

        self.log.info("Spawning %s", ' '.join(pipes.quote(s) for s in cmd))
        unixname = self.user.name.replace('@', '')
        unixname = unixname.replace('.', '')
        popen_kwargs = dict(
            #preexec_fn=self.make_preexec_fn(self.user.name),
            preexec_fn=self.make_preexec_fn(unixname),
            start_new_session=True,  # don't forward signals
        )
        popen_kwargs.update(self.popen_kwargs)
        # don't let user config override env
        popen_kwargs['env'] = env
        try:
            self.proc = Popen(cmd, **popen_kwargs)
        except PermissionError:
            # use which to get abspath
            script = shutil.which(cmd[0]) or cmd[0]
            self.log.error(
                "Permission denied trying to run %r. Does %s have access to this file?",
                #script, self.user.name,
                script,
                unixname,
            )
            raise

        self.pid = self.proc.pid

        if self.__class__ is not LocalProcessSpawner:
            # subclasses may not pass through return value of super().start,
            # relying on deprecated 0.6 way of setting ip, port,
            # so keep a redundant copy here for now.
            # A deprecation warning will be shown if the subclass
            # does not return ip, port.
            if self.ip:
                self.server.ip = self.ip
            self.server.port = self.port
            self.db.commit()
        return (self.ip or '127.0.0.1', self.port)

    async def poll(self):
        """Poll the spawned process to see if it is still running.

        If the process is still running, we return None. If it is not running,
        we return the exit code of the process if we have access to it, or 0 otherwise.
        """
        # if we started the process, poll with Popen
        if self.proc is not None:
            status = self.proc.poll()
            if status is not None:
                # clear state if the process is done
                self.clear_state()
            return status

        # if we resumed from stored state,
        # we don't have the Popen handle anymore, so rely on self.pid
        if not self.pid:
            # no pid, not running
            self.clear_state()
            return 0

        # send signal 0 to check if PID exists
        # this doesn't work on Windows, but that's okay because we don't support Windows.
        alive = await self._signal(0)
        if not alive:
            self.clear_state()
            return 0
        else:
            return None

    async def _signal(self, sig):
        """Send given signal to a single-user server's process.

        Returns True if the process still exists, False otherwise.

        The hub process is assumed to have enough privileges to do this (e.g. root).
        """
        try:
            os.kill(self.pid, sig)
        except OSError as e:
            if e.errno == errno.ESRCH:
                return False  # process is gone
            else:
                raise
        return True  # process exists

    async def stop(self, now=False):
        """Stop the single-user server process for the current user.

        If `now` is False (default), shut down the server as gracefully as possible,
        e.g. starting with SIGINT, then SIGTERM, then SIGKILL.
        If `now` is True, terminate the server immediately.

        The coroutine should return when the process is no longer running.
        """
        if not now:
            status = await self.poll()
            if status is not None:
                return
            self.log.debug("Interrupting %i", self.pid)
            await self._signal(signal.SIGINT)
            await self.wait_for_death(self.interrupt_timeout)

        # clean shutdown failed, use TERM
        status = await self.poll()
        if status is not None:
            return
        self.log.debug("Terminating %i", self.pid)
        await self._signal(signal.SIGTERM)
        await self.wait_for_death(self.term_timeout)

        # TERM failed, use KILL
        status = await self.poll()
        if status is not None:
            return
        self.log.debug("Killing %i", self.pid)
        await self._signal(signal.SIGKILL)
        await self.wait_for_death(self.kill_timeout)

        status = await self.poll()
        if status is None:
            # it all failed, zombie process
            self.log.warning("Process %i never died", self.pid)
Exemplo n.º 13
0
class InlineBackend(InlineBackendConfig):
    """An object to store configuration of the inline backend."""

    # The typical default figure size is too large for inline use,
    # so we shrink the figure size to 6x4, and tweak fonts to
    # make that fit.
    rc = Dict(
        {
            'figure.figsize': (6.0, 4.0),
            # play nicely with white background in the Qt and notebook frontend
            'figure.facecolor': (1, 1, 1, 0),
            'figure.edgecolor': (1, 1, 1, 0),
            # 12pt labels get cut off on 6x4 logplots, so use 10pt.
            'font.size': 10,
            # 72 dpi matches SVG/qtconsole
            # this only affects PNG export, as SVG has no dpi setting
            'figure.dpi': 72,
            # 10pt still needs a little more room on the xlabel:
            'figure.subplot.bottom': .125
        },
        help="""Subset of matplotlib rcParams that should be different for the
        inline backend."""
    ).tag(config=True)

    figure_formats = Set(
        {'png'},
        help="""A set of figure formats to enable: 'png',
                'retina', 'jpeg', 'svg', 'pdf'.""").tag(config=True)

    def _update_figure_formatters(self):
        if self.shell is not None:
            from IPython.core.pylabtools import select_figure_formats
            select_figure_formats(self.shell, self.figure_formats, **self.print_figure_kwargs)

    def _figure_formats_changed(self, name, old, new):
        if 'jpg' in new or 'jpeg' in new:
            if not pil_available():
                raise TraitError("Requires PIL/Pillow for JPG figures")
        self._update_figure_formatters()

    figure_format = Unicode(help="""The figure format to enable (deprecated;
                                         use `figure_formats` instead)""").tag(config=True)

    def _figure_format_changed(self, name, old, new):
        if new:
            self.figure_formats = {new}

    print_figure_kwargs = Dict(
        {'bbox_inches': 'tight'},
        help="""Extra kwargs to be passed to fig.canvas.print_figure.

        Logical examples include: bbox_inches, quality (for jpeg figures), etc.
        """
    ).tag(config=True)
    _print_figure_kwargs_changed = _update_figure_formatters

    close_figures = Bool(
        True,
        help="""Close all figures at the end of each cell.

        When True, ensures that each cell starts with no active figures, but it
        also means that one must keep track of references in order to edit or
        redraw figures in subsequent cells. This mode is ideal for the notebook,
        where residual plots from other cells might be surprising.

        When False, one must call figure() to create new figures. This means
        that gcf() and getfigs() can reference figures created in other cells,
        and the active figure can continue to be edited with pylab/pyplot
        methods that reference the current active figure. This mode facilitates
        iterative editing of figures, and behaves most consistently with
        other matplotlib backends, but figure barriers between cells must
        be explicit.
        """).tag(config=True)

    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC',
                     allow_none=True)
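
# --- Config sketch (added; not part of the original example) ---
# The traits above map to IPython config in the usual way, where `c` is the
# config object provided by the config file:
#
#     c.InlineBackend.figure_formats = {'png', 'retina'}
#     c.InlineBackend.rc = {'figure.figsize': (8.0, 5.0)}
#     c.InlineBackend.close_figures = False  # keep figures editable across cells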
Exemplo n.º 14
0
class DaskGateway(Application):
    """A gateway for managing dask clusters across multiple users"""

    name = "dask-gateway-server"
    version = VERSION

    description = """Start a Dask Gateway server"""

    examples = """

    Start the server with config file ``config.py``

        dask-gateway-server -f config.py
    """

    subcommands = {
        "generate-config": (
            "dask_gateway_server.app.GenerateConfig",
            "Generate a default config file",
        ),
        "scheduler-proxy": (
            "dask_gateway_server.proxy.core.SchedulerProxyApp",
            "Start the scheduler proxy",
        ),
        "web-proxy": (
            "dask_gateway_server.proxy.core.WebProxyApp",
            "Start the web proxy",
        ),
    }

    aliases = {
        "log-level": "DaskGateway.log_level",
        "f": "DaskGateway.config_file",
        "config": "DaskGateway.config_file",
    }

    config_file = Unicode(
        "dask_gateway_config.py", help="The config file to load", config=True
    )

    # Fail if the config file errors
    raise_config_file_errors = True

    scheduler_proxy_class = Type(
        "dask_gateway_server.proxy.SchedulerProxy",
        help="The gateway scheduler proxy class to use",
    )

    web_proxy_class = Type(
        "dask_gateway_server.proxy.WebProxy", help="The gateway web proxy class to use"
    )

    authenticator_class = Type(
        "dask_gateway_server.auth.DummyAuthenticator",
        klass="dask_gateway_server.auth.Authenticator",
        help="The gateway authenticator class to use",
        config=True,
    )

    cluster_manager_class = Type(
        "dask_gateway_server.managers.local.UnsafeLocalClusterManager",
        klass="dask_gateway_server.managers.ClusterManager",
        help="The gateway cluster manager class to use",
        config=True,
    )

    cluster_manager_options = Instance(
        Options,
        args=(),
        help="""
        User options for configuring the cluster manager.

        Allows users to specify configuration overrides when creating a new
        cluster manager. See the documentation for more information:

        :doc:`cluster-options`.
        """,
        config=True,
    )

    public_url = Unicode(
        "http://:8000",
        help="The public facing URL of the whole Dask Gateway application",
        config=True,
    )

    public_connect_url = Unicode(
        help="""
        The address that the public URL can be connected to.

        Useful if the address the web proxy should listen at is different than
        the address it's reachable at (by e.g. the scheduler/workers).

        Defaults to ``public_url``.
        """,
        config=True,
    )

    gateway_url = Unicode(help="The URL that Dask clients will connect to", config=True)

    private_url = Unicode(
        "http://127.0.0.1:0",
        help="""
        The gateway's private URL, used for internal communication.

        This must be reachable from the web proxy, but shouldn't be publicly
        accessible (if possible). Default is ``http://127.0.0.1:{random-port}``.
        """,
        config=True,
    )

    private_connect_url = Unicode(
        help="""
        The address that the private URL can be connected to.

        Useful if the address the gateway should listen at is different than
        the address it's reachable at (by e.g. the web proxy).

        Defaults to ``private_url``.
        """,
        config=True,
    )

    @validate(
        "public_url",
        "public_connect_url",
        "gateway_url",
        "private_url",
        "private_connect_url",
    )
    def _validate_url(self, proposal):
        url = proposal.value
        name = proposal.trait.name
        scheme = urlparse(url).scheme
        if name.startswith("gateway"):
            if scheme != "tls":
                raise ValueError("'gateway_url' must be a tls url, got %s" % url)
        else:
            if scheme not in {"http", "https"}:
                raise ValueError("%r must be an http/https url, got %s" % (name, url))
        return url
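
    # e.g. gateway_url='tls://myhost:8786' passes validation, while an
    # http/https value for gateway_url raises ValueError.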

    tls_key = Unicode(
        "",
        help="""Path to TLS key file for the public url of the web proxy.

        When setting this, you should also set tls_cert.
        """,
        config=True,
    )

    tls_cert = Unicode(
        "",
        help="""Path to TLS certificate file for the public url of the web proxy.

        When setting this, you should also set tls_key.
        """,
        config=True,
    )

    cookie_secret = Bytes(
        help="""The cookie secret to use to encrypt cookies.

        A 32 byte hex-encoded random string. Commonly created with the
        ``openssl`` CLI:

        .. code-block:: shell

            $ openssl rand -hex 32

        Loaded from the DASK_GATEWAY_COOKIE_SECRET environment variable by
        default.
        """,
        config=True,
    )

    @default("cookie_secret")
    def _cookie_secret_default(self):
        secret = os.environb.get(b"DASK_GATEWAY_COOKIE_SECRET", b"")
        if not secret:
            self.log.debug("Generating new cookie secret")
            secret = os.urandom(32)
        return secret

    @validate("cookie_secret")
    def _cookie_secret_validate(self, proposal):
        if len(proposal["value"]) != 32:
            raise ValueError(
                "Cookie secret is %d bytes, it must be "
                "32 bytes" % len(proposal["value"])
            )
        return proposal["value"]

    cookie_max_age_days = Float(
        7,
        help="""Number of days for a login cookie to be valid.
        Default is one week.
        """,
        config=True,
    )

    stop_clusters_on_shutdown = Bool(
        True,
        help="""
        Whether to stop active clusters on gateway shutdown.

        If true, all active clusters will be stopped before shutting down the
        gateway. Set to False to leave active clusters running.
        """,
        config=True,
    )

    @validate("stop_clusters_on_shutdown")
    def _stop_clusters_on_shutdown_validate(self, proposal):
        if not proposal.value and is_in_memory_db(self.db_url):
            raise ValueError(
                "When using an in-memory database, `stop_clusters_on_shutdown` "
                "must be True"
            )
        return proposal.value

    check_cluster_timeout = Float(
        10,
        help="""
        Timeout (in seconds) before deciding a cluster is no longer active.

        When the gateway restarts, any clusters still marked as active have
        their status checked. This timeout sets the max time we allocate for
        checking a cluster's status before deciding that the cluster is no
        longer active.
        """,
        config=True,
    )

    db_url = Unicode(
        "sqlite:///:memory:",
        help="""
        The URL for the database. Default is in-memory only.

        If not in-memory, ``db_encrypt_keys`` must also be set.
        """,
        config=True,
    )

    db_encrypt_keys = List(
        help="""
        A list of keys to use to encrypt private data in the database. Can also
        be set by the environment variable ``DASK_GATEWAY_ENCRYPT_KEYS``, where
        the value is a ``;`` delimited string of encryption keys.

        Each key should be a base64-encoded 32 byte value, and should be
        cryptographically random. Lacking other options, openssl can be used to
        generate a single key via:

        .. code-block:: shell

            $ openssl rand -base64 32

        A single key is valid; multiple keys can be used to support key rotation.
        """,
        config=True,
    )

    @default("db_encrypt_keys")
    def _db_encrypt_keys_default(self):
        keys = os.environb.get(b"DASK_GATEWAY_ENCRYPT_KEYS", b"").strip()
        if not keys:
            return []
        return [normalize_encrypt_key(k) for k in keys.split(b";") if k.strip()]
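
    # e.g. DASK_GATEWAY_ENCRYPT_KEYS="$(openssl rand -base64 32)" for a single
    # key, or several such keys joined with ';' to support rotation.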

    @validate("db_encrypt_keys")
    def _db_encrypt_keys_validate(self, proposal):
        if not proposal.value and not is_in_memory_db(self.db_url):
            raise ValueError(
                "Must configure `db_encrypt_keys`/`DASK_GATEWAY_ENCRYPT_KEYS` "
                "when not using an in-memory database"
            )
        return [normalize_encrypt_key(k) for k in proposal.value]

    db_debug = Bool(
        False, help="If True, all database operations will be logged", config=True
    )

    db_cleanup_period = Float(
        600,
        help="""
        Time (in seconds) between database cleanup tasks.

        This sets how frequently old records are removed from the database.
        This shouldn't be too small (to keep the overhead low), but should be
        smaller than ``db_record_max_age`` (probably by an order of magnitude).
        """,
        config=True,
    )

    db_cluster_max_age = Float(
        3600 * 24,
        help="""
        Max time (in seconds) to keep around records of completed clusters.

        Every ``db_cleanup_period``, completed clusters older than
        ``db_cluster_max_age`` are removed from the database.
        """,
        config=True,
    )

    temp_dir = Unicode(
        None,
        help="""
        The directory to use when creating temporary runtime files.

        Defaults to the platform's temporary directory, see
        https://docs.python.org/3/library/tempfile.html#tempfile.gettempdir for
        more information.
        """,
        config=True,
        allow_none=True,
    )

    _log_formatter_cls = LogFormatter

    classes = List([ClusterManager, Authenticator, WebProxy, SchedulerProxy])

    @catch_config_error
    def initialize(self, argv=None):
        super().initialize(argv)
        if self.subapp is not None:
            return
        self.log.info("Starting dask-gateway-server - version %s", VERSION)
        self.load_config_file(self.config_file)
        self.log.info("Cluster manager: %r", classname(self.cluster_manager_class))
        self.log.info("Authenticator: %r", classname(self.authenticator_class))
        self.init_logging()
        self.init_asyncio()
        self.init_server_urls()
        self.init_scheduler_proxy()
        self.init_web_proxy()
        self.init_authenticator()
        self.init_user_limits()
        self.init_database()
        self.init_tornado_application()

    def init_logging(self):
        # Prevent double log messages from tornado
        self.log.propagate = False

        # hook up tornado's loggers to our app handlers
        from tornado.log import app_log, access_log, gen_log

        for log in (app_log, access_log, gen_log):
            log.name = self.log.name
            log.handlers[:] = []
        logger = logging.getLogger("tornado")
        logger.handlers[:] = []
        logger.propagate = True
        logger.parent = self.log
        logger.setLevel(self.log.level)

    def init_asyncio(self):
        self.task_pool = TaskPool()

    def init_server_urls(self):
        """Initialize addresses from configuration"""
        self.public_urls = ServerUrls(self.public_url, self.public_connect_url)
        self.private_urls = ServerUrls(self.private_url, self.private_connect_url)
        if not self.gateway_url:
            gateway_url = f"tls://{self.public_urls.bind_host}:8786"
        else:
            gateway_url = self.gateway_url
        self.gateway_urls = ServerUrls(gateway_url)
        # Additional common url
        self.api_url = self.public_urls.connect_url + "/gateway/api"

    def init_scheduler_proxy(self):
        self.scheduler_proxy = self.scheduler_proxy_class(
            parent=self, log=self.log, public_urls=self.gateway_urls
        )

    def init_web_proxy(self):
        self.web_proxy = self.web_proxy_class(
            parent=self,
            log=self.log,
            public_urls=self.public_urls,
            tls_cert=self.tls_cert,
            tls_key=self.tls_key,
        )

    def init_authenticator(self):
        self.authenticator = self.authenticator_class(parent=self, log=self.log)

    def init_user_limits(self):
        self.user_limits = UserLimits(parent=self, log=self.log)

    def init_database(self):
        self.db = DataManager(
            url=self.db_url, echo=self.db_debug, encrypt_keys=self.db_encrypt_keys
        )

    def init_tornado_application(self):
        self.handlers = list(handlers.default_handlers)
        self.tornado_application = web.Application(
            self.handlers,
            log=self.log,
            gateway=self,
            authenticator=self.authenticator,
            cookie_secret=self.cookie_secret,
            cookie_max_age_days=self.cookie_max_age_days,
        )

    async def start_async(self):
        self.init_signal()
        await self.start_scheduler_proxy()
        await self.start_web_proxy()
        await self.load_database_state()
        await self.start_tornado_application()

    async def start_scheduler_proxy(self):
        await self.scheduler_proxy.start()

    async def start_web_proxy(self):
        await self.web_proxy.start()

    def create_cluster_manager(self, options):
        config = self.cluster_manager_options.get_configuration(options)
        return self.cluster_manager_class(
            parent=self,
            log=self.log,
            task_pool=self.task_pool,
            temp_dir=self.temp_dir,
            api_url=self.api_url,
            **config,
        )

    def init_cluster_manager(self, manager, cluster):
        manager.username = cluster.user.name
        manager.cluster_name = cluster.name
        manager.api_token = cluster.token
        manager.tls_cert = cluster.tls_cert
        manager.tls_key = cluster.tls_key

    async def load_database_state(self):
        self.db.load_database_state()

        active_clusters = list(self.db.active_clusters())
        if active_clusters:
            self.log.info(
                "Gateway was stopped with %d active clusters, "
                "checking cluster status...",
                len(active_clusters),
            )

            tasks = (self.check_cluster(c) for c in active_clusters)
            results = await asyncio.gather(*tasks, return_exceptions=True)

            n_clusters = 0
            for c, r in zip(active_clusters, results):
                if isinstance(r, Exception):
                    self.log.error(
                        "Error while checking status of cluster %s", c.name, exc_info=r
                    )
                elif r:
                    n_clusters += 1

            self.log.info(
                "All clusters have been checked, there are %d active clusters",
                n_clusters,
            )

        self.task_pool.create_background_task(self.cleanup_database())

    async def cleanup_database(self):
        while True:
            try:
                n = self.db.cleanup_expired(self.db_cluster_max_age)
            except Exception as exc:
                self.log.error(
                    "Error while cleaning expired database records", exc_info=exc
                )
            else:
                self.log.debug("Removed %d expired clusters from the database", n)
            await asyncio.sleep(self.db_cleanup_period)

    async def check_cluster(self, cluster):
        cluster.manager = self.create_cluster_manager(cluster.options)
        self.init_cluster_manager(cluster.manager, cluster)

        if cluster.status == ClusterStatus.RUNNING:
            client = AsyncHTTPClient()
            url = "%s/api/status" % cluster.api_address
            req = HTTPRequest(
                url, method="GET", headers={"Authorization": "token %s" % cluster.token}
            )
            try:
                resp = await asyncio.wait_for(
                    client.fetch(req), timeout=self.check_cluster_timeout
                )
                workers = json.loads(resp.body.decode("utf8", "replace"))["workers"]
                running = True
            except asyncio.CancelledError:
                raise
            except Exception:
                running = False
                workers = []
        else:
            # Gateway was stopped before cluster fully started.
            running = False
            workers = []

        if running:
            # Cluster is running, update our state to match
            await self.add_cluster_to_proxies(cluster)

            # Update our set of workers to match
            actual_workers = set(workers)
            to_stop = []
            for w in cluster.active_workers():
                if w.name in actual_workers:
                    self.mark_worker_running(cluster, w)
                else:
                    to_stop.append(w)

            tasks = (self.stop_worker(cluster, w, failed=True) for w in to_stop)
            await asyncio.gather(*tasks, return_exceptions=False)

            # Start the periodic monitor
            self.start_cluster_status_monitor(cluster)
        else:
            # Cluster is not available, shut it down
            await self.stop_cluster(cluster, failed=True)

        return running

    async def start_tornado_application(self):
        self.http_server = self.tornado_application.listen(
            self.private_urls.bind_port, address=self.private_urls.bind_host
        )
        self.log.info("Gateway private API serving at %s", self.private_urls.bind_url)
        await self.web_proxy.add_route(
            self.public_urls.connect.path + "/gateway/", self.private_urls.connect_url
        )
        self.log.info("Dask-Gateway started successfully!")
        for name, urls in [
            ("Public address", self.public_urls._to_log),
            ("Proxy address", self.gateway_urls._to_log),
        ]:
            if len(urls) == 2:
                self.log.info("- %s at %s or %s", name, *urls)
            else:
                self.log.info("- %s at %s", name, *urls)

    async def start_or_exit(self):
        try:
            await self.start_async()
        except Exception:
            self.log.critical("Failed to start gateway, shutting down", exc_info=True)
            await self.stop_async(stop_event_loop=False)
            self.exit(1)

    def start(self):
        if self.subapp is not None:
            return self.subapp.start()
        AsyncIOMainLoop().install()
        loop = IOLoop.current()
        loop.add_callback(self.start_or_exit)
        try:
            loop.start()
        except KeyboardInterrupt:
            print("\nInterrupted")

    def init_signal(self):
        loop = asyncio.get_event_loop()
        for s in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(s, self.handle_shutdown_signal, s)

    def handle_shutdown_signal(self, sig):
        self.log.info("Received signal %s, initiating shutdown...", sig.name)
        asyncio.ensure_future(self.stop_async())

    async def _stop_async(self, timeout=5):
        # Stop the server to prevent new requests
        if hasattr(self, "http_server"):
            self.http_server.stop()

        # If requested, shutdown any active clusters
        if self.stop_clusters_on_shutdown:
            tasks = {
                asyncio.ensure_future(self.stop_cluster(c, failed=True)): c
                for c in self.db.active_clusters()
            }
            if tasks:
                self.log.info("Stopping all active clusters...")
                done, pending = await asyncio.wait(tasks.keys())
                for f in done:
                    try:
                        await f
                    except Exception as exc:
                        cluster = tasks[f]
                        self.log.error(
                            "Failed to stop cluster %s for user %s",
                            cluster.name,
                            cluster.user.name,
                            exc_info=exc,
                        )
        else:
            self.log.info("Leaving any active clusters running")

        if hasattr(self, "task_pool"):
            await self.task_pool.close(timeout=timeout)

        # Shutdown the proxies
        if hasattr(self, "scheduler_proxy"):
            self.scheduler_proxy.stop()
        if hasattr(self, "web_proxy"):
            self.web_proxy.stop()

    async def stop_async(self, timeout=5, stop_event_loop=True):
        try:
            await self._stop_async(timeout=timeout)
        except Exception:
            self.log.error("Error while shutting down:", exc_info=True)
        # Stop the event loop
        if stop_event_loop:
            IOLoop.current().stop()

    def start_cluster_status_monitor(self, cluster):
        cluster._status_monitor = self.task_pool.create_background_task(
            self._cluster_status_monitor(cluster)
        )

    def stop_cluster_status_monitor(self, cluster):
        if cluster._status_monitor is not None:
            cluster._status_monitor.cancel()
            cluster._status_monitor = None

    async def _cluster_status_monitor(self, cluster):
        while True:
            try:
                res = await cluster.manager.cluster_status(cluster.state)
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                self.log.error(
                    "Error while checking cluster %s status", cluster.name, exc_info=exc
                )
            else:
                running, msg = res if isinstance(res, tuple) else (res, None)
                if not running:
                    if msg:
                        self.log.warning(
                            "Cluster %s stopped unexpectedly: %s", cluster.name, msg
                        )
                    else:
                        self.log.warning(
                            "Cluster %s stopped unexpectedly", cluster.name
                        )
                    self.schedule_stop_cluster(cluster, failed=True)
                    return
            await asyncio.sleep(cluster.manager.cluster_status_period)

    async def _start_cluster(self, cluster):
        self.log.info(
            "Starting cluster %s for user %s...", cluster.name, cluster.user.name
        )

        # Walk through the startup process, saving state as updates occur
        async for state in cluster.manager.start_cluster():
            self.log.debug("State update for cluster %s", cluster.name)
            self.db.update_cluster(cluster, state=state)

        # Move cluster to started
        self.db.update_cluster(cluster, status=ClusterStatus.STARTED)

    async def start_cluster(self, cluster):
        """Start the cluster.

        Returns True if successfully started, False otherwise.
        """
        try:
            async with timeout(cluster.manager.cluster_start_timeout):
                await self._start_cluster(cluster)
                self.log.info(
                    "Cluster %s has started, waiting for connection", cluster.name
                )
                self.start_cluster_status_monitor(cluster)
                addresses = await cluster._connect_future
        except asyncio.TimeoutError:
            self.log.warning(
                "Cluster %s startup timed out after %.1f seconds",
                cluster.name,
                cluster.manager.cluster_start_timeout,
            )
            return False
        except asyncio.CancelledError:
            # Catch separately to avoid catching it in the generic handler below
            raise
        except Exception as exc:
            self.log.error(
                "Error while starting cluster %s", cluster.name, exc_info=exc
            )
            return False

        scheduler_address, dashboard_address, api_address = addresses
        self.log.info("Cluster %s connected at %s", cluster.name, scheduler_address)

        # Mark cluster as running
        self.db.update_cluster(
            cluster,
            scheduler_address=scheduler_address,
            dashboard_address=dashboard_address,
            api_address=api_address,
            status=ClusterStatus.RUNNING,
        )

        # Register routes with proxies
        await self.add_cluster_to_proxies(cluster)

        return True

    async def add_cluster_to_proxies(self, cluster):
        if cluster.dashboard_address:
            await self.web_proxy.add_route(
                self.public_urls.connect.path + "/gateway/clusters/" + cluster.name,
                cluster.dashboard_address,
            )
        await self.scheduler_proxy.add_route(
            "/" + cluster.name, cluster.scheduler_address
        )

    def start_new_cluster(self, user, request):
        # Process the user provided options
        options = self.cluster_manager_options.parse_options(request)
        manager = self.create_cluster_manager(options)

        # Check if allowed to create cluster
        allowed, msg = self.user_limits.check_cluster_limits(
            user, manager.scheduler_memory, manager.scheduler_cores
        )
        if not allowed:
            raise Exception(msg)

        # Finish initializing the object states
        cluster = self.db.create_cluster(
            user, options, manager.scheduler_memory, manager.scheduler_cores
        )
        cluster.manager = manager
        self.init_cluster_manager(cluster.manager, cluster)

        # Launch the cluster startup task
        f = self.task_pool.create_task(self.start_cluster(cluster))
        f.add_done_callback(partial(self._monitor_start_cluster, cluster=cluster))
        cluster._start_future = f

        return cluster

    def _monitor_start_cluster(self, future, cluster=None):
        try:
            if future.result():
                # Startup succeeded, nothing to do
                return
        except asyncio.CancelledError:
            # Startup cancelled, cleanup is handled separately
            return
        except Exception as exc:
            self.log.error(
                "Unexpected error while starting cluster %s", cluster.name, exc_info=exc
            )

        self.schedule_stop_cluster(cluster, failed=True)

    async def stop_cluster(self, cluster, failed=False):
        if cluster.status >= ClusterStatus.STOPPING:
            return

        self.log.info("Stopping cluster %s...", cluster.name)

        # Move cluster to stopping
        self.db.update_cluster(cluster, status=ClusterStatus.STOPPING)

        # Stop the periodic monitor, if present
        self.stop_cluster_status_monitor(cluster)

        # If running, cancel running start task
        await cancel_task(cluster._start_future)

        # Remove routes from proxies if already set
        await self.web_proxy.delete_route("/gateway/clusters/" + cluster.name)
        await self.scheduler_proxy.delete_route("/" + cluster.name)

        # Shutdown individual workers if no bulk shutdown supported
        if not cluster.manager.supports_bulk_shutdown:
            tasks = (self.stop_worker(cluster, w) for w in cluster.active_workers())
            await asyncio.gather(*tasks, return_exceptions=True)

        # Shutdown the cluster
        try:
            await cluster.manager.stop_cluster(cluster.state)
        except Exception as exc:
            self.log.error("Failed to shutdown cluster %s", cluster.name, exc_info=exc)

        # If we shut the workers down in bulk, clean up their internal state now
        if cluster.manager.supports_bulk_shutdown:
            tasks = (
                self.stop_worker(cluster, w, state_only=True)
                for w in cluster.active_workers()
            )
            await asyncio.gather(*tasks, return_exceptions=True)

        # Update the cluster status
        status = ClusterStatus.FAILED if failed else ClusterStatus.STOPPED
        self.db.update_cluster(cluster, status=status, stop_time=timestamp())
        cluster.pending.clear()

        self.log.info("Stopped cluster %s", cluster.name)

    def schedule_stop_cluster(self, cluster, failed=False):
        self.task_pool.create_task(self.stop_cluster(cluster, failed=failed))

    async def scale(self, cluster, total, is_adapt=False):
        """Scale cluster to total workers"""
        async with cluster.lock:
            n_active = len(cluster.active_workers())
            delta = total - n_active
            self.log.info(
                "Scaling cluster %s to %d workers, a delta of %d",
                cluster.name,
                total,
                delta,
            )
            if delta > 0:
                return await self.scale_up(cluster, delta, is_adapt)
            elif delta < 0:
                return await self.scale_down(cluster, total)
            else:
                if not is_adapt:
                    await self.adapt(cluster, active=False, lock=False)
                return {"added": [], "removed": [], "message": None}

    async def scale_up(self, cluster, n_start, is_adapt=False):
        # Check how many workers we're allowed to start
        n_allowed, msg = self.user_limits.check_scale_limits(
            cluster,
            n_start,
            memory=cluster.manager.worker_memory,
            cores=cluster.manager.worker_cores,
        )
        created = []
        for _ in range(n_allowed):
            w = self.db.create_worker(
                cluster, cluster.manager.worker_memory, cluster.manager.worker_cores
            )
            created.append(w)

        new_workers = [w.name for w in created]

        if not is_adapt:
            client = AsyncHTTPClient()
            body = json.dumps({"op": "add", "workers": new_workers})
            url = "%s/api/pending_workers" % cluster.api_address
            req = HTTPRequest(
                url,
                method="POST",
                headers={
                    "Authorization": "token %s" % cluster.token,
                    "Content-type": "application/json",
                },
                body=body,
            )
            await client.fetch(req)
            if cluster.adaptive:
                self.db.update_cluster(cluster, adaptive=False)

        for w in created:
            w._start_future = self.task_pool.create_task(self.start_worker(cluster, w))
            w._start_future.add_done_callback(
                partial(self._monitor_start_worker, worker=w, cluster=cluster)
            )

        return {"added": new_workers, "removed": [], "message": msg}

    async def _start_worker(self, cluster, worker):
        self.log.info("Starting worker %s for cluster %s...", worker.name, cluster.name)

        # Walk through the startup process, saving state as updates occur
        async for state in cluster.manager.start_worker(worker.name, cluster.state):
            self.db.update_worker(worker, state=state)

        # Move worker to started
        self.db.update_worker(worker, status=WorkerStatus.STARTED)

    async def _worker_status_monitor(self, cluster, worker):
        while True:
            try:
                res = await cluster.manager.worker_status(
                    worker.name, worker.state, cluster.state
                )
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                self.log.error(
                    "Error while checking worker %s status", worker.name, exc_info=exc
                )
            else:
                running, msg = res if isinstance(res, tuple) else (res, None)
                if not running:
                    return msg
            await asyncio.sleep(cluster.manager.worker_status_period)

    async def start_worker(self, cluster, worker):
        worker_status_monitor = None
        try:
            async with timeout(cluster.manager.worker_start_timeout):
                # Submit the worker
                await self._start_worker(cluster, worker)

                self.log.info(
                    "Worker %s has started, waiting for connection", worker.name
                )

                # Wait for the worker to connect, periodically checking its status
                worker_status_monitor = asyncio.ensure_future(
                    self._worker_status_monitor(cluster, worker)
                )
                done, pending = await asyncio.wait(
                    (worker_status_monitor, worker._connect_future),
                    return_when=asyncio.FIRST_COMPLETED,
                )
        except Exception as exc:
            if worker_status_monitor is not None:
                worker_status_monitor.cancel()
            if type(exc) is asyncio.CancelledError:
                raise
            elif type(exc) is asyncio.TimeoutError:
                self.log.warning(
                    "Worker %s startup timed out after %.1f seconds",
                    worker.name,
                    cluster.manager.worker_start_timeout,
                )
            else:
                self.log.error("Error while starting worker %s", worker, exc_info=exc)
            return False

        # Check monitor for failures
        if worker_status_monitor in done:
            # Failure occurred
            msg = worker_status_monitor.result()
            if msg:
                self.log.warning(
                    "Worker %s failed during startup: %s", worker.name, msg
                )
            else:
                self.log.warning("Worker %s failed during startup", worker.name)
            return False
        else:
            worker_status_monitor.cancel()

        self.log.info("Worker %s connected to cluster %s", worker.name, cluster.name)

        # Mark worker as running
        self.mark_worker_running(cluster, worker)

        return True

    def mark_worker_running(self, cluster, worker):
        if worker.status != WorkerStatus.RUNNING:
            cluster.manager.on_worker_running(worker.name, worker.state, cluster.state)
            self.db.update_worker(worker, status=WorkerStatus.RUNNING)
            cluster.pending.discard(worker.name)

    def _monitor_start_worker(self, future, worker=None, cluster=None):
        try:
            if future.result():
                # Startup succeeded, nothing to do
                return
        except asyncio.CancelledError:
            # Startup cancelled, cleanup is handled separately
            self.log.debug("Cancelled worker %s", worker.name)
            return
        except Exception as exc:
            self.log.error(
                "Unexpected error while starting worker %s for cluster %s",
                worker.name,
                cluster.name,
                exc_info=exc,
            )

        self.schedule_stop_worker(cluster, worker, failed=True, notify=True)

    async def stop_worker(
        self, cluster, worker, failed=False, notify=False, state_only=False
    ):
        # Already stopping elsewhere, return
        if worker.status >= WorkerStatus.STOPPING:
            return

        self.log.info("Stopping worker %s for cluster %s", worker.name, cluster.name)

        # Move worker to stopping
        self.db.update_worker(worker, status=WorkerStatus.STOPPING)
        cluster.pending.discard(worker.name)

        # Cancel a pending start if needed
        await cancel_task(worker._start_future)

        # Shutdown the worker
        if not state_only:
            try:
                await cluster.manager.stop_worker(
                    worker.name, worker.state, cluster.state
                )
            except Exception as exc:
                self.log.error(
                    "Failed to shutdown worker %s for cluster %s",
                    worker.name,
                    cluster.name,
                    exc_info=exc,
                )

        # Update the worker status
        status = WorkerStatus.FAILED if failed else WorkerStatus.STOPPED
        self.db.update_worker(worker, status=status, stop_time=timestamp())

        # Maybe notify the scheduler of worker death
        if notify:
            client = AsyncHTTPClient()
            body = json.dumps({"op": "remove", "workers": [worker.name]})
            url = "%s/api/pending_workers" % cluster.api_address
            req = HTTPRequest(
                url,
                method="POST",
                headers={
                    "Authorization": "token %s" % cluster.token,
                    "Content-type": "application/json",
                },
                body=body,
            )
            await client.fetch(req)

        self.log.info("Stopped worker %s", worker.name)

    def schedule_stop_worker(self, cluster, worker, failed=False, notify=False):
        self.task_pool.create_task(
            self.stop_worker(cluster, worker, failed=failed, notify=notify)
        )

    def maybe_fail_worker(self, cluster, worker):
        # Ignore if the cluster or worker isn't active
        if (
            cluster.status != ClusterStatus.RUNNING
            or worker.status >= WorkerStatus.STOPPING
        ):
            return
        self.schedule_stop_worker(cluster, worker, failed=True)

    async def scale_down(self, cluster, target=None, workers=None):
        if target is not None:
            client = AsyncHTTPClient()
            body = json.dumps({"target": target})
            url = "%s/api/scale_down" % cluster.api_address
            req = HTTPRequest(
                url,
                method="POST",
                headers={
                    "Authorization": "token %s" % cluster.token,
                    "Content-type": "application/json",
                },
                body=body,
            )
            resp = await client.fetch(req)
            data = json.loads(resp.body.decode("utf8", "replace"))
            workers = data["workers_closed"]
            if cluster.adaptive:
                self.db.update_cluster(cluster, adaptive=False)
        else:
            assert workers is not None

        to_stop = [cluster.workers[n] for n in workers]

        self.log.debug("Stopping %d workers for cluster %s", len(to_stop), cluster.name)
        for w in to_stop:
            self.schedule_stop_worker(cluster, w)

        return {"added": [], "removed": [w.name for w in to_stop], "message": None}

    async def adapt(self, cluster, minimum=None, maximum=None, active=True, lock=True):
        if not active and not cluster.adaptive:
            # Nothing to do
            return
        cl = cluster.lock if lock else nullcontext()
        async with cl:
            client = AsyncHTTPClient()
            body = json.dumps(
                {"minimum": minimum, "maximum": maximum, "active": active}
            )
            url = "%s/api/adapt" % cluster.api_address
            req = HTTPRequest(
                url,
                method="POST",
                headers={
                    "Authorization": "token %s" % cluster.token,
                    "Content-type": "application/json",
                },
                body=body,
            )
            await client.fetch(req)
            self.db.update_cluster(cluster, adaptive=active)
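
A note on the pattern above: ``load_database_state`` checks every active cluster concurrently with ``asyncio.gather(..., return_exceptions=True)``, so a single failing status probe is reported without aborting the others. A minimal, self-contained sketch of that pattern, assuming nothing from the gateway itself (the ``probe`` coroutine and cluster names are invented for illustration):

import asyncio

async def probe(name):
    # Hypothetical status check; one cluster fails to show error handling.
    if name == "bad":
        raise RuntimeError("unreachable")
    await asyncio.sleep(0.01)
    return True

async def main():
    names = ["a", "bad", "c"]
    results = await asyncio.gather(
        *(probe(n) for n in names), return_exceptions=True
    )
    running = 0
    for name, res in zip(names, results):
        if isinstance(res, Exception):
            print("error checking cluster %s: %s" % (name, res))
        elif res:
            running += 1
    print("%d active clusters" % running)

asyncio.run(main())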
Exemplo n.º 15
class ZMQTerminalInteractiveShell(SingletonConfigurable):
    readline_use = False

    pt_cli = None

    _executing = False
    _execution_state = Unicode('')
    _pending_clearoutput = False
    _eventloop = None
    own_kernel = False  # Changed by ZMQTerminalIPythonApp

    editing_mode = Unicode(
        'emacs',
        config=True,
        help="Shortcut style to use at the prompt. 'vi' or 'emacs'.",
    )

    highlighting_style = Unicode(
        '',
        config=True,
        help="The name of a Pygments style to use for syntax highlighting")

    highlighting_style_overrides = Dict(
        config=True, help="Override highlighting format for specific tokens")

    true_color = Bool(
        False,
        config=True,
        help=("Use 24bit colors instead of 256 colors in prompt highlighting. "
              "If your terminal supports true color, the following command "
              "should print 'TRUECOLOR' in orange: "
              "printf \"\\x1b[38;2;255;100;0mTRUECOLOR\\x1b[0m\\n\""))

    history_load_length = Integer(
        1000, config=True, help="How many history items to load into memory")

    banner = Unicode(
        'Jupyter console {version}\n\n{kernel_banner}',
        config=True,
        help=(
            "Text to display before the first prompt. Will be formatted with "
            "variables {version} and {kernel_banner}."))

    kernel_timeout = Float(
        60,
        config=True,
        help="""Timeout for giving up on a kernel (in seconds).

        On first connect and restart, the console tests whether the
        kernel is running and responsive by sending kernel_info_requests.
        This sets the timeout in seconds for how long the kernel can take
        before being presumed dead.
        """)

    image_handler = Enum(('PIL', 'stream', 'tempfile', 'callable'),
                         'PIL',
                         config=True,
                         allow_none=True,
                         help="""
        Handler for image type output.  This is useful, for example,
        when connecting to the kernel in which pylab inline backend is
        activated.  There are four handlers defined.  'PIL': Use
        Python Imaging Library to popup image; 'stream': Use an
        external program to show the image.  Image will be fed into
        the STDIN of the program.  You will need to configure
        `stream_image_handler`; 'tempfile': Use an external program to
        show the image.  Image will be saved in a temporally file and
        the program is called with the temporally file.  You will need
        to configure `tempfile_image_handler`; 'callable': You can set
        any Python callable which is called with the image data.  You
        will need to configure `callable_image_handler`.
        """)

    stream_image_handler = List(config=True,
                                help="""
        Command to invoke an image viewer program when you are using the
        'stream' image handler.  This option is a list of strings, where
        the first element is the command itself and the remaining
        elements are its options.  Raw image data is given on STDIN to
        the program.
        """)

    tempfile_image_handler = List(config=True,
                                  help="""
        Command to invoke an image viewer program when you are using the
        'tempfile' image handler.  This option is a list of strings,
        where the first element is the command itself and the remaining
        elements are its options.  You can use {file} and {format} in
        the strings to represent the location of the generated image
        file and the image format.
        """)

    callable_image_handler = Any(config=True,
                                 help="""
        Callable object called via 'callable' image handler with one
        argument, `data`, which is `msg["content"]["data"]` where
        `msg` is the message from iopub channel.  For example, you can
        find base64 encoded PNG data as `data['image/png']`. If your function
        can't handle the data supplied, it should return `False` to indicate
        this.
        """)

    mime_preference = List(
        default_value=['image/png', 'image/jpeg', 'image/svg+xml'],
        config=True,
        help="""
        Preferred object representation MIME types, in order of
        preference.  The first matching MIME type will be used.
        """)

    use_kernel_is_complete = Bool(
        True,
        config=True,
        help="""Whether to use the kernel's is_complete message
        handling. If False, then the frontend will use its
        own is_complete handler.
        """)
    kernel_is_complete_timeout = Float(
        1,
        config=True,
        help="""Timeout (in seconds) for giving up on a kernel's is_complete
        response.

        If the kernel does not respond at any point within this time,
        the kernel will no longer be asked if code is complete, and the
        console will default to the built-in is_complete test.
        """)

    # This is configurable on JupyterConsoleApp; this copy is not configurable
    # to avoid a duplicate config option.
    confirm_exit = Bool(True,
                        help="""Set to display confirmation dialog on exit.
        You can always use 'exit' or 'quit', to force a
        direct exit without any confirmation.
        """)

    highlight_matching_brackets = Bool(
        True,
        help="Highlight matching brackets.",
    ).tag(config=True)

    manager = Instance('jupyter_client.KernelManager', allow_none=True)
    client = Instance('jupyter_client.KernelClient', allow_none=True)

    def _client_changed(self, name, old, new):
        self.session_id = new.session.session

    session_id = Unicode()

    def _banner1_default(self):
        return "Jupyter Console {version}\n".format(version=__version__)

    simple_prompt = Bool(
        False,
        help="""Use simple fallback prompt. Features may be limited.""").tag(
            config=True)

    def __init__(self, **kwargs):
        # This is where traits with a config_key argument are updated
        # from the values on config.
        super(ZMQTerminalInteractiveShell, self).__init__(**kwargs)
        self.configurables = [self]

        self.init_history()
        self.init_completer()
        self.init_io()

        self.init_kernel_info()
        self.init_prompt_toolkit_cli()
        self.keep_running = True
        self.execution_count = 1

    def init_completer(self):
        """Initialize the completion machinery.

        This creates completion machinery that can be used by client code,
        either interactively in-process (typically triggered by the readline
        library), programmatically (such as in test suites) or out-of-process
        (typically over the network by remote frontends).
        """
        self.Completer = ZMQCompleter(self, self.client, config=self.config)

    def init_history(self):
        """Sets up the command history. """
        self.history_manager = ZMQHistoryManager(client=self.client)
        self.configurables.append(self.history_manager)

    def get_prompt_tokens(self, ec=None):
        if ec is None:
            ec = self.execution_count
        return [
            (Token.Prompt, 'In ['),
            (Token.PromptNum, str(ec)),
            (Token.Prompt, ']: '),
        ]

    def get_continuation_tokens(self, width):
        return [
            (Token.Prompt, (' ' * (width - 2)) + ': '),
        ]

    def get_out_prompt_tokens(self):
        return [(Token.OutPrompt, 'Out['),
                (Token.OutPromptNum, str(self.execution_count)),
                (Token.OutPrompt, ']: ')]

    def print_out_prompt(self):
        tokens = self.get_out_prompt_tokens()
        print_formatted_text(PygmentsTokens(tokens),
                             end='',
                             style=self.pt_cli.app.style)

    def get_remote_prompt_tokens(self):
        return [
            (Token.RemotePrompt, self.other_output_prefix),
        ]

    def print_remote_prompt(self, ec=None):
        tokens = self.get_remote_prompt_tokens() + self.get_prompt_tokens(
            ec=ec)
        print_formatted_text(PygmentsTokens(tokens),
                             end='',
                             style=self.pt_cli.app.style)

    kernel_info = {}

    def init_kernel_info(self):
        """Wait for a kernel to be ready, and store kernel info"""
        timeout = self.kernel_timeout
        tic = time.time()
        self.client.hb_channel.unpause()
        msg_id = self.client.kernel_info()
        while True:
            try:
                reply = self.client.get_shell_msg(timeout=1)
            except Empty:
                if (time.time() - tic) > timeout:
                    raise RuntimeError(
                        "Kernel didn't respond to kernel_info_request")
            else:
                if reply['parent_header'].get('msg_id') == msg_id:
                    self.kernel_info = reply['content']
                    return

    def show_banner(self):
        print(self.banner.format(version=__version__,
                                 kernel_banner=self.kernel_info.get(
                                     'banner', '')),
              end='',
              flush=True)

    def init_prompt_toolkit_cli(self):
        if self.simple_prompt or ('JUPYTER_CONSOLE_TEST' in os.environ):
            # Simple restricted interface for tests so we can find prompts with
            # pexpect. Multi-line input not supported.
            @asyncio.coroutine
            def prompt():
                prompt = 'In [%d]: ' % self.execution_count
                raw = yield from async_input(prompt)
                return raw

            self.prompt_for_code = prompt
            self.print_out_prompt = \
                lambda: print('Out[%d]: ' % self.execution_count, end='')
            return

        kb = KeyBindings()
        insert_mode = vi_insert_mode | emacs_insert_mode

        @kb.add("enter",
                filter=(has_focus(DEFAULT_BUFFER)
                        & ~has_selection
                        & insert_mode))
        def _(event):
            b = event.current_buffer
            d = b.document
            if not (d.on_last_line or d.cursor_position_row >=
                    d.line_count - d.empty_line_count_at_the_end()):
                b.newline()
                return

            # Pressing enter flushes any pending display. This also ensures
            # the displayed execution_count is correct.
            self.handle_iopub()

            more, indent = self.check_complete(d.text)

            if (not more) and b.accept_handler:
                b.validate_and_handle()
            else:
                b.insert_text('\n' + indent)

        @kb.add("c-c", filter=has_focus(DEFAULT_BUFFER))
        def _(event):
            event.current_buffer.reset()

        @kb.add("c-\\", filter=has_focus(DEFAULT_BUFFER))
        def _(event):
            raise EOFError

        @kb.add("c-z",
                filter=Condition(lambda: suspend_to_background_supported()))
        def _(event):
            event.cli.suspend_to_background()

        # Pre-populate history from IPython's history database
        history = InMemoryHistory()
        last_cell = u""
        for _, _, cell in self.history_manager.get_tail(
                self.history_load_length, include_latest=True):
            # Ignore blank lines and consecutive duplicates
            cell = cell.rstrip()
            if cell and (cell != last_cell):
                history.append_string(cell)

        style_overrides = {
            Token.Prompt: '#009900',
            Token.PromptNum: '#00ff00 bold',
            Token.OutPrompt: '#ff2200',
            Token.OutPromptNum: '#ff0000 bold',
            Token.RemotePrompt: '#999900',
        }
        if self.highlighting_style:
            style_cls = get_style_by_name(self.highlighting_style)
        else:
            style_cls = get_style_by_name('default')
            # The default theme needs to be visible on both a dark background
            # and a light background, because we can't tell what the terminal
            # looks like. These tweaks to the default theme help with that.
            style_overrides.update({
                Token.Number: '#007700',
                Token.Operator: 'noinherit',
                Token.String: '#BB6622',
                Token.Name.Function: '#2080D0',
                Token.Name.Class: 'bold #2080D0',
                Token.Name.Namespace: 'bold #2080D0',
            })
        style_overrides.update(self.highlighting_style_overrides)
        style = merge_styles([
            style_from_pygments_cls(style_cls),
            style_from_pygments_dict(style_overrides),
        ])

        editing_mode = getattr(EditingMode, self.editing_mode.upper())
        langinfo = self.kernel_info.get('language_info', {})
        lexer = langinfo.get('pygments_lexer', langinfo.get('name', 'text'))

        # If enabled in the settings, highlight matching brackets
        # when the DEFAULT_BUFFER has the focus
        input_processors = [
            ConditionalProcessor(
                processor=HighlightMatchingBracketProcessor(chars='[](){}'),
                filter=has_focus(DEFAULT_BUFFER) & ~is_done
                & Condition(lambda: self.highlight_matching_brackets))
        ]

        # Tell prompt_toolkit to use the asyncio event loop.
        # Obsolete in prompt_toolkit.v3
        if not PTK3:
            use_asyncio_event_loop()

        self.pt_cli = PromptSession(
            message=(lambda: PygmentsTokens(self.get_prompt_tokens())),
            multiline=True,
            editing_mode=editing_mode,
            lexer=PygmentsLexer(get_pygments_lexer(lexer)),
            prompt_continuation=(
                lambda width, lineno, is_soft_wrap: PygmentsTokens(
                    self.get_continuation_tokens(width))),
            key_bindings=kb,
            history=history,
            completer=JupyterPTCompleter(self.Completer),
            enable_history_search=True,
            style=style,
            input_processors=input_processors,
            color_depth=(ColorDepth.TRUE_COLOR if self.true_color else None),
        )

    @asyncio.coroutine
    def prompt_for_code(self):
        if self.next_input:
            default = self.next_input
            self.next_input = None
        else:
            default = ''

        if PTK3:
            text = yield from self.pt_cli.prompt_async(default=default)
        else:
            text = yield from self.pt_cli.prompt(default=default, async_=True)

        return text

    def init_io(self):
        if sys.platform not in {'win32', 'cli'}:
            return

        import colorama
        colorama.init()

    def check_complete(self, code):
        if self.use_kernel_is_complete:
            msg_id = self.client.is_complete(code)
            try:
                return self.handle_is_complete_reply(
                    msg_id, timeout=self.kernel_is_complete_timeout)
            except SyntaxError:
                return False, ""
        else:
            lines = code.splitlines()
            if len(lines):
                more = (lines[-1] != "")
                return more, ""
            else:
                return False, ""

    def ask_exit(self):
        self.keep_running = False

    # This is set from payloads in handle_execute_reply
    next_input = None

    def pre_prompt(self):
        if self.next_input:
            # We can't set the buffer here, because it will be reset just after
            # this. Adding a callable to pre_run_callables does what we need
            # after the buffer is reset.
            s = self.next_input

            def set_doc():
                self.pt_cli.app.buffer.document = Document(s)

            if hasattr(self.pt_cli, 'pre_run_callables'):
                self.pt_cli.app.pre_run_callables.append(set_doc)
            else:
                # Older version of prompt_toolkit; it's OK to set the document
                # directly here.
                set_doc()
            self.next_input = None

    @asyncio.coroutine
    def interact(self, loop=None, display_banner=None):
        while self.keep_running:
            print('\n', end='')

            try:
                code = yield from self.prompt_for_code()
            except EOFError:
                if (not self.confirm_exit) or \
                        ask_yes_no('Do you really want to exit ([y]/n)?', 'y', 'n'):
                    self.ask_exit()

            else:
                if code:
                    self.run_cell(code, store_history=True)

    def mainloop(self):
        self.keepkernel = not self.own_kernel
        loop = asyncio.get_event_loop()
        # An extra layer of protection in case someone mashing Ctrl-C breaks
        # out of our internal code.
        while True:
            try:
                tasks = [self.interact(loop=loop)]

                if self.include_other_output:
                    # only poll the iopub channel asynchronously if we
                    # wish to include external content
                    tasks.append(self.handle_external_iopub(loop=loop))

                main_task = asyncio.wait(tasks,
                                         loop=loop,
                                         return_when=asyncio.FIRST_COMPLETED)
                _, pending = loop.run_until_complete(main_task)

                for task in pending:
                    task.cancel()
                try:
                    loop.run_until_complete(asyncio.gather(*pending))
                except asyncio.CancelledError:
                    pass
                loop.stop()
                loop.close()
                break
            except KeyboardInterrupt:
                print("\nKeyboardInterrupt escaped interact()\n")

        if self._eventloop:
            self._eventloop.close()
        if self.keepkernel and not self.own_kernel:
            print('keeping kernel alive')
        elif self.keepkernel and self.own_kernel:
            print("owning kernel, cannot keep it alive")
            self.client.shutdown()
        else:
            print("Shutting down kernel")
            self.client.shutdown()

    def run_cell(self, cell, store_history=True):
        """Run a complete IPython cell.

        Parameters
        ----------
        cell : str
          The code (including IPython code such as %magic functions) to run.
        store_history : bool
          If True, the raw and translated cell will be stored in IPython's
          history. For user code calling back into IPython's machinery, this
          should be set to False.
        """
        if (not cell) or cell.isspace():
            # pressing enter flushes any pending display
            self.handle_iopub()
            return

        # flush stale replies, which could have been ignored due to missed heartbeats
        while self.client.shell_channel.msg_ready():
            self.client.shell_channel.get_msg()
        # execute takes 'hidden', which is the inverse of store_history
        msg_id = self.client.execute(cell, not store_history)

        # first thing is wait for any side effects (output, stdin, etc.)
        self._executing = True
        self._execution_state = "busy"
        while self._execution_state != 'idle' and self.client.is_alive():
            try:
                self.handle_input_request(msg_id, timeout=0.05)
            except Empty:
                # display intermediate print statements, etc.
                self.handle_iopub(msg_id)
            except ZMQError as e:
                # Carry on if polling was interrupted by a signal
                if e.errno != errno.EINTR:
                    raise

        # after all of that is done, wait for the execute reply
        while self.client.is_alive():
            try:
                self.handle_execute_reply(msg_id, timeout=0.05)
            except Empty:
                pass
            else:
                break
        self._executing = False

    #-----------------
    # message handlers
    #-----------------

    def handle_execute_reply(self, msg_id, timeout=None):
        msg = self.client.shell_channel.get_msg(block=False, timeout=timeout)
        if msg["parent_header"].get("msg_id", None) == msg_id:

            self.handle_iopub(msg_id)

            content = msg["content"]
            status = content['status']

            if status == 'aborted':
                self.write('Aborted\n')
                return
            elif status == 'ok':
                # handle payloads
                for item in content.get("payload", []):
                    source = item['source']
                    if source == 'page':
                        page.page(item['data']['text/plain'])
                    elif source == 'set_next_input':
                        self.next_input = item['text']
                    elif source == 'ask_exit':
                        self.keepkernel = item.get('keepkernel', False)
                        self.ask_exit()

            elif status == 'error':
                pass

            self.execution_count = int(content["execution_count"] + 1)

    def handle_is_complete_reply(self, msg_id, timeout=None):
        """
        Wait for a response from the kernel, and return two values:
            more? - (boolean) should the frontend ask for more input
            indent - an indent string to prefix the input
        Overriding methods may want to examine the complete source. It is
        in the self._source_lines_buffered list.
        """
        ## Get the is_complete response:
        msg = None
        try:
            msg = self.client.shell_channel.get_msg(block=True,
                                                    timeout=timeout)
        except Empty:
            warn('The kernel did not respond to an is_complete_request. '
                 'Setting `use_kernel_is_complete` to False.')
            self.use_kernel_is_complete = False
            return False, ""
        ## Handle response:
        if msg["parent_header"].get("msg_id", None) != msg_id:
            warn(
                'The kernel did not respond properly to an is_complete_request: %s.'
                % str(msg))
            return False, ""
        else:
            status = msg["content"].get("status", None)
            indent = msg["content"].get("indent", "")
        ## Return more? and indent string
        if status == "complete":
            return False, indent
        elif status == "incomplete":
            return True, indent
        elif status == "invalid":
            raise SyntaxError()
        elif status == "unknown":
            return False, indent
        else:
            warn('The kernel sent an invalid is_complete_reply status: "%s".' %
                 status)
            return False, indent

    include_other_output = Bool(False,
                                config=True,
                                help="""Whether to include output from clients
        other than this one sharing the same kernel.
        """)
    other_output_prefix = Unicode(
        "Remote ",
        config=True,
        help="""Prefix to add to outputs coming from clients other than this one.

        Only relevant if include_other_output is True.
        """)

    def from_here(self, msg):
        """Return whether a message is from this session"""
        return msg['parent_header'].get("session",
                                        self.session_id) == self.session_id

    def include_output(self, msg):
        """Return whether we should include a given output message"""
        from_here = self.from_here(msg)
        if msg['msg_type'] == 'execute_input':
            # only echo inputs not from here
            return self.include_other_output and not from_here

        if self.include_other_output:
            return True
        else:
            return from_here

    @asyncio.coroutine
    def handle_external_iopub(self, loop=None):
        while self.keep_running:
            # we need to check for keep_running from time to time as
            # we are blocking in an executor block which cannot be cancelled.
            poll_result = yield from loop.run_in_executor(
                None, self.client.iopub_channel.socket.poll, 500)
            if (poll_result):
                self.handle_iopub()

    def handle_iopub(self, msg_id=''):
        """Process messages on the IOPub channel

           This method consumes and processes messages on the IOPub channel,
           such as stdout, stderr, execute_result and status.

           By default it only displays output caused by this session;
           set ``include_other_output`` to also show output from other
           clients.
        """
        while self.client.iopub_channel.msg_ready():
            sub_msg = self.client.iopub_channel.get_msg()
            msg_type = sub_msg['header']['msg_type']
            parent = sub_msg["parent_header"]

            # Update execution_count in case it changed in another session
            if msg_type == "execute_input":
                self.execution_count = int(
                    sub_msg["content"]["execution_count"]) + 1

            if self.include_output(sub_msg):
                if msg_type == 'status':
                    self._execution_state = sub_msg["content"][
                        "execution_state"]

                elif msg_type == 'stream':
                    if sub_msg["content"]["name"] == "stdout":
                        if self._pending_clearoutput:
                            print("\r", end="")
                            self._pending_clearoutput = False
                        print(sub_msg["content"]["text"], end="")
                        sys.stdout.flush()
                    elif sub_msg["content"]["name"] == "stderr":
                        if self._pending_clearoutput:
                            print("\r", file=sys.stderr, end="")
                            self._pending_clearoutput = False
                        print(sub_msg["content"]["text"],
                              file=sys.stderr,
                              end="")
                        sys.stderr.flush()

                elif msg_type == 'execute_result':
                    if self._pending_clearoutput:
                        print("\r", end="")
                        self._pending_clearoutput = False
                    self.execution_count = int(
                        sub_msg["content"]["execution_count"])
                    if not self.from_here(sub_msg):
                        sys.stdout.write(self.other_output_prefix)
                    format_dict = sub_msg["content"]["data"]
                    self.handle_rich_data(format_dict)

                    if 'text/plain' not in format_dict:
                        continue

                    # prompt_toolkit writes the prompt at a slightly lower level,
                    # so flush streams first to ensure correct ordering.
                    sys.stdout.flush()
                    sys.stderr.flush()
                    self.print_out_prompt()
                    text_repr = format_dict['text/plain']
                    if '\n' in text_repr:
                        # For multi-line results, start a new line after prompt
                        print()
                    print(text_repr)

                    # Remote: add new prompt
                    if not self.from_here(sub_msg):
                        sys.stdout.write('\n')
                        sys.stdout.flush()
                        self.print_remote_prompt()

                elif msg_type == 'display_data':
                    data = sub_msg["content"]["data"]
                    handled = self.handle_rich_data(data)
                    if not handled:
                        if not self.from_here(sub_msg):
                            sys.stdout.write(self.other_output_prefix)
                        # if it was an image, we handled it by now
                        if 'text/plain' in data:
                            print(data['text/plain'])

                # If execute input: print it
                elif msg_type == 'execute_input':
                    content = sub_msg['content']
                    ec = content.get('execution_count',
                                     self.execution_count - 1)

                    # New line
                    sys.stdout.write('\n')
                    sys.stdout.flush()

                    # With `Remote In [3]: `
                    self.print_remote_prompt(ec=ec)

                    # And the code
                    sys.stdout.write(content['code'] + '\n')

                elif msg_type == 'clear_output':
                    if sub_msg["content"]["wait"]:
                        self._pending_clearoutput = True
                    else:
                        print("\r", end="")

                elif msg_type == 'error':
                    for frame in sub_msg["content"]["traceback"]:
                        print(frame, file=sys.stderr)

    _imagemime = {
        'image/png': 'png',
        'image/jpeg': 'jpeg',
        'image/svg+xml': 'svg',
    }

    def handle_rich_data(self, data):
        for mime in self.mime_preference:
            if mime in data and mime in self._imagemime:
                if self.handle_image(data, mime):
                    return True
        return False

    def handle_image(self, data, mime):
        handler = getattr(self, 'handle_image_{0}'.format(self.image_handler),
                          None)
        if handler:
            return handler(data, mime)

    def handle_image_PIL(self, data, mime):
        if mime not in ('image/png', 'image/jpeg'):
            return False
        try:
            from PIL import Image, ImageShow
        except ImportError:
            return False
        raw = base64.decodebytes(data[mime].encode('ascii'))
        img = Image.open(BytesIO(raw))
        return ImageShow.show(img)

    def handle_image_stream(self, data, mime):
        raw = base64.decodebytes(data[mime].encode('ascii'))
        imageformat = self._imagemime[mime]
        fmt = dict(format=imageformat)
        args = [s.format(**fmt) for s in self.stream_image_handler]
        with open(os.devnull, 'w') as devnull:
            proc = subprocess.Popen(args,
                                    stdin=subprocess.PIPE,
                                    stdout=devnull,
                                    stderr=devnull)
            proc.communicate(raw)
        return (proc.returncode == 0)

    def handle_image_tempfile(self, data, mime):
        raw = base64.decodebytes(data[mime].encode('ascii'))
        imageformat = self._imagemime[mime]
        filename = 'tmp.{0}'.format(imageformat)
        with NamedFileInTemporaryDirectory(filename) as f, \
                open(os.devnull, 'w') as devnull:
            f.write(raw)
            f.flush()
            fmt = dict(file=f.name, format=imageformat)
            args = [s.format(**fmt) for s in self.tempfile_image_handler]
            rc = subprocess.call(args, stdout=devnull, stderr=devnull)
        return (rc == 0)

    def handle_image_callable(self, data, mime):
        res = self.callable_image_handler(data)
        if res is not False:
            # If handler func returns e.g. None, assume it has handled the data.
            res = True
        return res

    def handle_input_request(self, msg_id, timeout=0.1):
        """ Method to capture raw_input
        """
        req = self.client.stdin_channel.get_msg(timeout=timeout)
        # in case any iopub came while we were waiting:
        self.handle_iopub(msg_id)
        if msg_id == req["parent_header"].get("msg_id"):
            # wrap SIGINT handler
            real_handler = signal.getsignal(signal.SIGINT)

            def double_int(sig, frame):
                # call real handler (forwards sigint to kernel),
                # then raise local interrupt, stopping local raw_input
                real_handler(sig, frame)
                raise KeyboardInterrupt

            signal.signal(signal.SIGINT, double_int)
            content = req['content']
            read = getpass if content.get('password', False) else input
            try:
                raw_data = read(content["prompt"])
            except EOFError:
                # turn EOFError into EOF character
                raw_data = '\x04'
            except KeyboardInterrupt:
                sys.stdout.write('\n')
                return
            finally:
                # restore SIGINT handler
                signal.signal(signal.SIGINT, real_handler)

            # only send a stdin reply if no other request arrived and
            # execution did not finish while we were reading.
            if not (self.client.stdin_channel.msg_ready()
                    or self.client.shell_channel.msg_ready()):
                self.client.input(raw_data)
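A minimal sketch of the template substitution used by handle_image_stream
and handle_image_tempfile above (the `imgviewer` command is hypothetical;
any external viewer the user configures would do):

    # each element of the configured handler list is formatted with the
    # image format (the tempfile variant also substitutes {file}):
    stream_image_handler = ['imgviewer', '--format', '{format}', '-']
    fmt = dict(format='png')
    args = [s.format(**fmt) for s in stream_image_handler]
    # args == ['imgviewer', '--format', 'png', '-']; the raw image bytes
    # are then piped to the spawned process on stdin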
Example no. 16
class FXManager(HasTraits):
    """
    Manages device lighting effects
    """
    current_fx = Tuple(Unicode(allow_none=True),
                       Instance(klass=BaseFX, allow_none=True),
                       default_value=(None, None))

    def __init__(self, driver, fxmod: FXModule, *args, **kwargs):
        """
        :param driver: The UChromaDevice to control
        """
        super(FXManager, self).__init__(*args, **kwargs)
        self._driver = driver
        self._logger = driver.logger
        self._fxmod = fxmod

        driver.restore_prefs.connect(self._restore_prefs)

    def _restore_prefs(self, prefs):
        """
        Restore last FX from preferences
        """
        if prefs.fx is not None:
            args = {}
            if prefs.fx_args is not None:
                args = prefs.fx_args

            self.activate(prefs.fx, **args)

    @property
    def available_fx(self):
        return self._fxmod.available_fx

    def get_fx(self, fx_name) -> BaseFX:
        """
        Get the requested effects implementation.

        Returns the last active object if appropriate.

        :param fx_name: The string name of the effect object
        """
        if self.current_fx[0] == fx_name:
            return self.current_fx[1]

        return self._fxmod.create_fx(fx_name)

    def disable(self) -> bool:
        if 'disable' in self.available_fx:
            return self.activate('disable')
        return False

    def _activate(self, fx_name, fx, future=None):
        # need to do this as a callback if an animation
        # is shutting down
        if fx.apply():
            if fx_name != self.current_fx[0]:
                self.current_fx = (fx_name, fx)
            if fx_name == CUSTOM:
                return True

            self._driver.preferences.fx = fx_name
            argsdict = get_args_dict(fx)
            if len(argsdict) == 0:
                argsdict = None
            self._driver.preferences.fx_args = argsdict
        return True

    def activate(self, fx_name, **kwargs) -> bool:
        fx = self.get_fx(fx_name)
        if fx is None:
            return False

        if fx_name != CUSTOM and fx_name != 'disable':
            for k, v in kwargs.items():
                if fx.has_trait(k):
                    setattr(fx, k, v)

            if self._driver.is_animating:
                self._driver.animation_manager.stop( \
                        cb=functools.partial(self._activate, fx_name, fx))
                return True

        return self._activate(fx_name, fx)
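A hedged usage sketch for FXManager (the 'wave' effect name and its
'direction' trait are illustrative; real names come from the device's
FXModule):

    # fxmgr is an FXManager constructed with a driver and its FXModule
    if 'wave' in fxmgr.available_fx:
        # kwargs matching traits on the effect are set, the effect is
        # applied, and the name/args are persisted to preferences
        fxmgr.activate('wave', direction=2)

    # 'disable' goes through the same activate() path when supported
    fxmgr.disable()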
Example no. 17
class FrontendWidget(HistoryConsoleWidget, BaseFrontendMixin):
    """ A Qt frontend for a generic Python kernel.
    """

    # The text to show when the kernel is (re)started.
    banner = Unicode(config=True)
    kernel_banner = Unicode()
    # Whether to show the banner
    _display_banner = Bool(False)

    # An option and corresponding signal for overriding the default kernel
    # interrupt behavior.
    custom_interrupt = Bool(False)
    custom_interrupt_requested = QtCore.Signal()

    # An option and corresponding signals for overriding the default kernel
    # restart behavior.
    custom_restart = Bool(False)
    custom_restart_kernel_died = QtCore.Signal(float)
    custom_restart_requested = QtCore.Signal()

    # Whether to automatically show calltips on open-parentheses.
    enable_calltips = Bool(
        True,
        config=True,
        help="Whether to draw information calltips on open-parentheses.")

    clear_on_kernel_restart = Bool(
        True,
        config=True,
        help="Whether to clear the console when the kernel is restarted")

    confirm_restart = Bool(
        True,
        config=True,
        help="Whether to ask for user confirmation when restarting kernel")

    lexer_class = DottedObjectName(config=True,
                                   help="The pygments lexer class to use.")

    def _lexer_class_changed(self, name, old, new):
        lexer_class = import_item(new)
        self.lexer = lexer_class()

    def _lexer_class_default(self):
        if py3compat.PY3:
            return 'pygments.lexers.Python3Lexer'
        else:
            return 'pygments.lexers.PythonLexer'

    lexer = Any()

    def _lexer_default(self):
        lexer_class = import_item(self.lexer_class)
        return lexer_class()

    # Emitted when a user visible 'execute_request' has been submitted to the
    # kernel from the FrontendWidget. Contains the code to be executed.
    executing = QtCore.Signal(object)

    # Emitted when a user-visible 'execute_reply' has been received from the
    # kernel and processed by the FrontendWidget. Contains the response message.
    executed = QtCore.Signal(object)

    # Emitted when an exit request has been received from the kernel.
    exit_requested = QtCore.Signal(object)

    _CallTipRequest = namedtuple('_CallTipRequest', ['id', 'pos'])
    _CompletionRequest = namedtuple('_CompletionRequest', ['id', 'pos'])
    _ExecutionRequest = namedtuple('_ExecutionRequest', ['id', 'kind'])
    _local_kernel = False
    _highlighter = Instance(FrontendHighlighter, allow_none=True)

    # -------------------------------------------------------------------------
    # 'Object' interface
    # -------------------------------------------------------------------------

    def __init__(self, local_kernel=_local_kernel, *args, **kw):
        super(FrontendWidget, self).__init__(*args, **kw)
        # FIXME: remove this when PySide min version is updated past 1.0.7
        # forcefully disable calltips if PySide is < 1.0.7, because they crash
        if qt.QT_API == qt.QT_API_PYSIDE:
            import PySide
            if PySide.__version_info__ < (1, 0, 7):
                self.log.warning("PySide %s < 1.0.7 found; disabling calltips",
                                 PySide.__version__)
                self.enable_calltips = False

        # FrontendWidget protected variables.
        self._bracket_matcher = BracketMatcher(self._control)
        self._call_tip_widget = CallTipWidget(self._control)
        self._copy_raw_action = QtGui.QAction('Copy (Raw Text)', None)
        self._hidden = False
        self._highlighter = FrontendHighlighter(self, lexer=self.lexer)
        self._kernel_manager = None
        self._kernel_client = None
        self._request_info = {}
        self._request_info['execute'] = {}
        self._callback_dict = {}
        self._display_banner = True

        # Configure the ConsoleWidget.
        self.tab_width = 4
        self._set_continuation_prompt('... ')

        # Configure the CallTipWidget.
        self._call_tip_widget.setFont(self.font)
        self.font_changed.connect(self._call_tip_widget.setFont)

        # Configure actions.
        action = self._copy_raw_action
        key = QtCore.Qt.CTRL | QtCore.Qt.SHIFT | QtCore.Qt.Key_C
        action.setEnabled(False)
        action.setShortcut(QtGui.QKeySequence(key))
        action.setShortcutContext(QtCore.Qt.WidgetWithChildrenShortcut)
        action.triggered.connect(self.copy_raw)
        self.copy_available.connect(action.setEnabled)
        self.addAction(action)

        # Connect signal handlers.
        document = self._control.document()
        document.contentsChange.connect(self._document_contents_change)

        # Set flag for whether we are connected via localhost.
        self._local_kernel = local_kernel

        # Whether or not a clear_output call is pending new output.
        self._pending_clearoutput = False

    #---------------------------------------------------------------------------
    # 'ConsoleWidget' public interface
    #---------------------------------------------------------------------------

    def copy(self):
        """ Copy the currently selected text to the clipboard, removing prompts.
        """
        if self._page_control is not None and self._page_control.hasFocus():
            self._page_control.copy()
        elif self._control.hasFocus():
            text = self._control.textCursor().selection().toPlainText()
            if text:
                # Remove prompts.
                lines = text.splitlines()
                lines = map(self._highlighter.transform_classic_prompt, lines)
                lines = map(self._highlighter.transform_ipy_prompt, lines)
                text = '\n'.join(lines)
                # Needed to prevent errors when copying the prompt.
                # See issue 264
                try:
                    was_newline = text[-1] == '\n'
                except IndexError:
                    was_newline = False
                if was_newline:  # user doesn't need newline
                    text = text[:-1]
                QtGui.QApplication.clipboard().setText(text)
        else:
            self.log.debug("frontend widget : unknown copy target")

    #---------------------------------------------------------------------------
    # 'ConsoleWidget' abstract interface
    #---------------------------------------------------------------------------

    def _execute(self, source, hidden):
        """ Execute 'source'. If 'hidden', do not show any output.

        See parent class :meth:`execute` docstring for full details.
        """
        msg_id = self.kernel_client.execute(source, hidden)
        self._request_info['execute'][msg_id] = self._ExecutionRequest(
            msg_id, 'user')
        self._hidden = hidden
        if not hidden:
            self.executing.emit(source)

    def _prompt_started_hook(self):
        """ Called immediately after a new prompt is displayed.
        """
        if not self._reading:
            self._highlighter.highlighting_on = True

    def _prompt_finished_hook(self):
        """ Called immediately after a prompt is finished, i.e. when some input
            will be processed and a new prompt displayed.
        """
        if not self._reading:
            self._highlighter.highlighting_on = False

    def _tab_pressed(self):
        """ Called when the tab key is pressed. Returns whether to continue
            processing the event.
        """
        # Perform tab completion if:
        # 1) The cursor is in the input buffer.
        # 2) There is a non-whitespace character before the cursor.
        # 3) There is no active selection.
        text = self._get_input_buffer_cursor_line()
        if text is None:
            return False
        non_ws_before = bool(
            text[:self._get_input_buffer_cursor_column()].strip())
        complete = non_ws_before and self._get_cursor().selectedText() == ''
        if complete:
            self._complete()
        return not complete

    #---------------------------------------------------------------------------
    # 'ConsoleWidget' protected interface
    #---------------------------------------------------------------------------

    def _context_menu_make(self, pos):
        """ Reimplemented to add an action for raw copy.
        """
        menu = super(FrontendWidget, self)._context_menu_make(pos)
        for before_action in menu.actions():
            if before_action.shortcut().matches(QtGui.QKeySequence.Paste) == \
                    QtGui.QKeySequence.ExactMatch:
                menu.insertAction(before_action, self._copy_raw_action)
                break
        return menu

    def request_interrupt_kernel(self):
        if self._executing:
            self.interrupt_kernel()

    def request_restart_kernel(self):
        message = 'Are you sure you want to restart the kernel?'
        self.restart_kernel(message, now=False)

    def _event_filter_console_keypress(self, event):
        """ Reimplemented for execution interruption and smart backspace.
        """
        key = event.key()
        if self._control_key_down(event.modifiers(), include_command=False):

            if key == QtCore.Qt.Key_C and self._executing:
                self.request_interrupt_kernel()
                return True

            elif key == QtCore.Qt.Key_Period:
                self.request_restart_kernel()
                return True

        elif not event.modifiers() & QtCore.Qt.AltModifier:

            # Smart backspace: remove four characters in one backspace if:
            # 1) everything left of the cursor is whitespace
            # 2) the four characters immediately left of the cursor are spaces
            if key == QtCore.Qt.Key_Backspace:
                col = self._get_input_buffer_cursor_column()
                cursor = self._control.textCursor()
                if col > 3 and not cursor.hasSelection():
                    text = self._get_input_buffer_cursor_line()[:col]
                    if text.endswith('    ') and not text.strip():
                        cursor.movePosition(QtGui.QTextCursor.Left,
                                            QtGui.QTextCursor.KeepAnchor, 4)
                        cursor.removeSelectedText()
                        return True

        return super(FrontendWidget,
                     self)._event_filter_console_keypress(event)

    #---------------------------------------------------------------------------
    # 'BaseFrontendMixin' abstract interface
    #---------------------------------------------------------------------------
    def _handle_clear_output(self, msg):
        """Handle clear output messages."""
        if self.include_output(msg):
            wait = msg['content'].get('wait', True)
            if wait:
                self._pending_clearoutput = True
            else:
                self.clear_output()

    def _silent_exec_callback(self, expr, callback):
        """Silently execute `expr` in the kernel and call `callback` with reply

        The `expr` is evaluated silently in the kernel (without output in
        the frontend). `callback` is then called with the
        `repr <http://docs.python.org/library/functions.html#repr>`_ of the result as its first argument.

        Parameters
        ----------
        expr : string
            valid string to be executed by the kernel.
        callback : function
            function accepting one argument, as a string. The string will be
            the `repr` of the result of evaluating `expr`

        The `callback` is called with the `repr()` of the result of `expr` as
        first argument. To get the object, do `eval()` on the passed value.

        See Also
        --------
        _handle_exec_callback : private method, deals with calling callback with reply

        """

        # generate uuid, which would be used as an indication of whether or
        # not the unique request originated from here (can use msg id ?)
        local_uuid = str(uuid.uuid1())
        msg_id = self.kernel_client.execute(
            '', silent=True, user_expressions={local_uuid: expr})
        self._callback_dict[local_uuid] = callback
        self._request_info['execute'][msg_id] = self._ExecutionRequest(
            msg_id, 'silent_exec_callback')
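    # Usage sketch (hedged): `widget` is a FrontendWidget with a live
    # kernel_client; the reply arrives asynchronously via
    # _handle_execute_reply -> _handle_exec_callback:
    #
    #     def on_reply(value):
    #         print("kernel replied with:", value)
    #
    #     widget._silent_exec_callback('1 + 1', on_reply)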

    def _handle_exec_callback(self, msg):
        """Execute `callback` corresponding to `msg` reply, after ``_silent_exec_callback``

        Parameters
        ----------
        msg : dict
            Raw message sent by the kernel, containing `user_expressions`
            and having a 'silent_exec_callback' kind.

        Notes
        -----
        This function will look for a `callback` associated with the
        corresponding message id. Association has been made by
        `_silent_exec_callback`. `callback` is then called with the `repr()`
        of the value of corresponding `user_expressions` as argument.
        `callback` is then removed from the known list so that any message
        coming again with the same id won't trigger it.
        """
        user_exp = msg['content'].get('user_expressions')
        if not user_exp:
            return
        for expression in user_exp:
            if expression in self._callback_dict:
                self._callback_dict.pop(expression)(user_exp[expression])

    def _handle_execute_reply(self, msg):
        """ Handles replies for code execution.
        """
        self.log.debug("execute_reply: %s", msg.get('content', ''))
        msg_id = msg['parent_header']['msg_id']
        info = self._request_info['execute'].get(msg_id)
        # unset reading flag, because if execute finished, raw_input can't
        # still be pending.
        self._reading = False
        # Note: if info is None, this is ignored
        if info and info.kind == 'user' and not self._hidden:
            # Make sure that all output from the SUB channel has been processed
            # before writing a new prompt.
            self.kernel_client.iopub_channel.flush()

            # Reset the ANSI style information to prevent bad text in stdout
            # from messing up our colors. We're not a true terminal so we're
            # allowed to do this.
            if self.ansi_codes:
                self._ansi_processor.reset_sgr()

            content = msg['content']
            status = content['status']
            if status == 'ok':
                self._process_execute_ok(msg)
            elif status == 'aborted':
                self._process_execute_abort(msg)

            self._show_interpreter_prompt_for_reply(msg)
            self.executed.emit(msg)
            self._request_info['execute'].pop(msg_id)
        elif info and info.kind == 'silent_exec_callback' and not self._hidden:
            self._handle_exec_callback(msg)
            self._request_info['execute'].pop(msg_id)
        elif info and not self._hidden:
            raise RuntimeError("Unknown handler for %s" % info.kind)

    def _handle_error(self, msg):
        """ Handle error messages.
        """
        self._process_execute_error(msg)

    def _handle_input_request(self, msg):
        """ Handle requests for raw_input.
        """
        self.log.debug("input: %s", msg.get('content', ''))
        if self._hidden:
            raise RuntimeError(
                'Request for raw input during hidden execution.')

        # Make sure that all output from the SUB channel has been processed
        # before entering readline mode.
        self.kernel_client.iopub_channel.flush()

        def callback(line):
            self.kernel_client.input(line)

        if self._reading:
            self.log.debug(
                "Got second input request, assuming first was interrupted.")
            self._reading = False
        self._readline(msg['content']['prompt'],
                       callback=callback,
                       password=msg['content']['password'])

    def _kernel_restarted_message(self, died=True):
        msg = "Kernel died, restarting" if died else "Kernel restarting"
        self._append_html("<br>%s<hr><br>" % msg, before_prompt=False)

    def _handle_kernel_died(self, since_last_heartbeat):
        """Handle the kernel's death (if we do not own the kernel).
        """
        self.log.warning("kernel died: %s", since_last_heartbeat)
        if self.custom_restart:
            self.custom_restart_kernel_died.emit(since_last_heartbeat)
        else:
            self._kernel_restarted_message(died=True)
            self.reset()

    def _handle_kernel_restarted(self, died=True):
        """Notice that the autorestarter restarted the kernel.

        There's nothing to do but show a message.
        """
        self.log.warning("kernel restarted")
        self._kernel_restarted_message(died=died)
        self.reset()

    def _handle_inspect_reply(self, rep):
        """Handle replies for call tips."""
        self.log.debug("oinfo: %s", rep.get('content', ''))
        cursor = self._get_cursor()
        info = self._request_info.get('call_tip')
        if info and info.id == rep['parent_header']['msg_id'] and \
                info.pos == cursor.position():
            content = rep['content']
            if content.get('status') == 'ok' and content.get('found', False):
                self._call_tip_widget.show_inspect_data(content)

    def _handle_execute_result(self, msg):
        """ Handle display hook output.
        """
        self.log.debug("execute_result: %s", msg.get('content', ''))
        if self.include_output(msg):
            self.flush_clearoutput()
            text = msg['content']['data']
            self._append_plain_text(text + '\n', before_prompt=True)

    def _handle_stream(self, msg):
        """ Handle stdout, stderr, and stdin.
        """
        self.log.debug("stream: %s", msg.get('content', ''))
        if self.include_output(msg):
            self.flush_clearoutput()
            self.append_stream(msg['content']['text'])

    def _handle_shutdown_reply(self, msg):
        """ Handle shutdown signal, only if from other console.
        """
        self.log.debug("shutdown: %s", msg.get('content', ''))
        restart = msg.get('content', {}).get('restart', False)
        if not self._hidden and not self.from_here(msg):
            # got shutdown reply, request came from session other than ours
            if restart:
                # someone restarted the kernel, handle it
                self._handle_kernel_restarted(died=False)
            else:
                # kernel was shutdown permanently
                # this triggers exit_requested if the kernel was local,
                # and a dialog if the kernel was remote,
                # so we don't suddenly clear the qtconsole without asking.
                if self._local_kernel:
                    self.exit_requested.emit(self)
                else:
                    title = self.window().windowTitle()
                    reply = QtGui.QMessageBox.question(
                        self, title, "Kernel has been shutdown permanently. "
                        "Close the Console?", QtGui.QMessageBox.Yes,
                        QtGui.QMessageBox.No)
                    if reply == QtGui.QMessageBox.Yes:
                        self.exit_requested.emit(self)

    def _handle_status(self, msg):
        """Handle status message"""
        # This is where a busy/idle indicator would be triggered,
        # when we make one.
        state = msg['content'].get('execution_state', '')
        if state == 'starting':
            # kernel started while we were running
            if self._executing:
                self._handle_kernel_restarted(died=True)
        elif state == 'idle':
            pass
        elif state == 'busy':
            pass

    def _started_channels(self):
        """ Called when the KernelManager channels have started listening or
            when the frontend is assigned an already listening KernelManager.
        """
        self.reset(clear=True)

    #---------------------------------------------------------------------------
    # 'FrontendWidget' public interface
    #---------------------------------------------------------------------------

    def copy_raw(self):
        """ Copy the currently selected text to the clipboard without attempting
            to remove prompts or otherwise alter the text.
        """
        self._control.copy()

    def interrupt_kernel(self):
        """ Attempts to interrupt the running kernel.
        
        Also unsets _reading flag, to avoid runtime errors
        if raw_input is called again.
        """
        if self.custom_interrupt:
            self._reading = False
            self.custom_interrupt_requested.emit()
        elif self.kernel_manager:
            self._reading = False
            self.kernel_manager.interrupt_kernel()
        else:
            self._append_plain_text(
                'Cannot interrupt a kernel I did not start.\n')

    def reset(self, clear=False):
        """ Resets the widget to its initial state if ``clear`` parameter
        is True, otherwise
        prints a visual indication of the fact that the kernel restarted, but
        does not clear the traces from previous usage of the kernel before it
        was restarted.  With ``clear=True``, it is similar to ``%clear``, but
        also re-writes the banner and aborts execution if necessary.
        """
        if self._executing:
            self._executing = False
            self._request_info['execute'] = {}
        self._reading = False
        self._highlighter.highlighting_on = False

        if clear:
            self._control.clear()
            if self._display_banner:
                self._append_plain_text(self.banner)
                if self.kernel_banner:
                    self._append_plain_text(self.kernel_banner)

        # update output marker for stdout/stderr, so that startup
        # messages appear after banner:
        self._show_interpreter_prompt()

    def restart_kernel(self, message, now=False):
        """ Attempts to restart the running kernel.
        """
        # FIXME: now should be configurable via a checkbox in the dialog.  Right
        # now at least the heartbeat path sets it to True and the manual restart
        # to False.  But those should just be the pre-selected states of a
        # checkbox that the user could override if so desired.  But I don't know
        # enough Qt to go implementing the checkbox now.

        if self.custom_restart:
            self.custom_restart_requested.emit()
            return

        if self.kernel_manager:
            # Pause the heart beat channel to prevent further warnings.
            self.kernel_client.hb_channel.pause()

            # Prompt the user to restart the kernel. Un-pause the heartbeat if
            # they decline. (If they accept, the heartbeat will be un-paused
            # automatically when the kernel is restarted.)
            if self.confirm_restart:
                buttons = QtGui.QMessageBox.Yes | QtGui.QMessageBox.No
                result = QtGui.QMessageBox.question(self, 'Restart kernel?',
                                                    message, buttons)
                do_restart = result == QtGui.QMessageBox.Yes
            else:
                # confirm_restart is False, so we don't need to ask user
                # anything, just do the restart
                do_restart = True
            if do_restart:
                try:
                    self.kernel_manager.restart_kernel(now=now)
                except RuntimeError as e:
                    self._append_plain_text('Error restarting kernel: %s\n' %
                                            e,
                                            before_prompt=True)
                else:
                    self._append_html(
                        "<br>Restarting kernel...\n<hr><br>",
                        before_prompt=True,
                    )
            else:
                self.kernel_client.hb_channel.unpause()

        else:
            self._append_plain_text(
                'Cannot restart a Kernel I did not start\n',
                before_prompt=True)

    def append_stream(self, text):
        """Appends text to the output stream."""
        # Most consoles treat tabs as being 8 space characters. Convert tabs
        # to spaces so that output looks as expected regardless of this
        # widget's tab width.
        text = text.expandtabs(8)
        self._append_plain_text(text, before_prompt=True)

    def flush_clearoutput(self):
        """If a clearoutput is pending, execute it."""
        if self._pending_clearoutput:
            self._pending_clearoutput = False
            self.clear_output()

    def clear_output(self):
        """Clears the current line of output."""
        cursor = self._control.textCursor()
        cursor.beginEditBlock()
        cursor.movePosition(cursor.StartOfLine, cursor.KeepAnchor)
        cursor.insertText('')
        cursor.endEditBlock()

    #---------------------------------------------------------------------------
    # 'FrontendWidget' protected interface
    #---------------------------------------------------------------------------

    def _auto_call_tip(self):
        """Trigger call tip automatically on open parenthesis
        
        Call tips can be requested explicitly with `_call_tip`.
        """
        cursor = self._get_cursor()
        cursor.movePosition(QtGui.QTextCursor.Left)
        if cursor.document().characterAt(cursor.position()) == '(':
            # trigger auto call tip on open paren
            self._call_tip()

    def _call_tip(self):
        """Shows a call tip, if appropriate, at the current cursor location."""
        # Decide if it makes sense to show a call tip
        if (not self.enable_calltips
                or not self.kernel_client.shell_channel.is_alive()):
            return False
        cursor_pos = self._get_input_buffer_cursor_pos()
        code = self.input_buffer
        # Send the metadata request to the kernel
        msg_id = self.kernel_client.inspect(code, cursor_pos)
        pos = self._get_cursor().position()
        self._request_info['call_tip'] = self._CallTipRequest(msg_id, pos)
        return True

    def _complete(self):
        """ Performs completion at the current cursor location.
        """
        # Send the completion request to the kernel
        msg_id = self.kernel_client.complete(
            code=self.input_buffer,
            cursor_pos=self._get_input_buffer_cursor_pos(),
        )
        pos = self._get_cursor().position()
        info = self._CompletionRequest(msg_id, pos)
        self._request_info['complete'] = info

    def _process_execute_abort(self, msg):
        """ Process a reply for an aborted execution request.
        """
        self._append_plain_text("ERROR: execution aborted\n")

    def _process_execute_error(self, msg):
        """ Process a reply for an execution request that resulted in an error.
        """
        content = msg['content']
        # If a SystemExit is passed along, this means exit() was called - this
        # also lets the IPython %exit magic's '-k' syntax be used to keep
        # the kernel running.
        if content['ename'] == 'SystemExit':
            keepkernel = content['evalue'] in ('-k', 'True')
            self._keep_kernel_on_exit = keepkernel
            self.exit_requested.emit(self)
        else:
            traceback = ''.join(content['traceback'])
            self._append_plain_text(traceback)

    def _process_execute_ok(self, msg):
        """ Process a reply for a successful execution request.
        """
        payload = msg['content'].get('payload', [])
        for item in payload:
            if not self._process_execute_payload(item):
                warning = 'Warning: received unknown payload of type %s'
                print(warning % repr(item['source']))

    def _process_execute_payload(self, item):
        """ Process a single payload item from the list of payload items in an
            execution reply. Returns whether the payload was handled.
        """
        # The basic FrontendWidget doesn't handle payloads, as they are a
        # mechanism for going beyond the standard Python interpreter model.
        return False

    def _show_interpreter_prompt(self):
        """ Shows a prompt for the interpreter.
        """
        self._show_prompt('>>> ')

    def _show_interpreter_prompt_for_reply(self, msg):
        """ Shows a prompt for the interpreter given an 'execute_reply' message.
        """
        self._show_interpreter_prompt()

    #------ Signal handlers ----------------------------------------------------

    def _document_contents_change(self, position, removed, added):
        """ Called whenever the document's content changes. Display a call tip
            if appropriate.
        """
        # Calculate where the cursor should be *after* the change:
        position += added

        if position == self._get_cursor().position():
            self._auto_call_tip()

    #------ Trait default initializers -----------------------------------------

    @default('banner')
    def _banner_default(self):
        """ Returns the standard Python banner.
        """
        banner = 'Python %s on %s\nType "help", "copyright", "credits" or ' \
            '"license" for more information.'
        return banner % (sys.version, sys.platform)
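A hedged sketch of overriding the default interrupt/restart behavior via the
custom_* options and signals defined on FrontendWidget (handler names are
illustrative):

    widget = FrontendWidget()

    def my_interrupt():
        # e.g. notify a remote process manager instead of the local kernel
        print("interrupt requested")

    widget.custom_interrupt = True
    widget.custom_interrupt_requested.connect(my_interrupt)

    widget.custom_restart = True
    widget.custom_restart_requested.connect(lambda: print("restart requested"))
    widget.custom_restart_kernel_died.connect(
        lambda since: print("kernel died; %.1f since last heartbeat" % since))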
Example no. 18
class Session(Configurable):
    """Object for handling serialization and sending of messages.

    The Session object handles building messages and sending them
    with ZMQ sockets or ZMQStream objects.  Objects can communicate with each
    other over the network via Session objects, and only need to work with the
    dict-based IPython message spec. The Session will handle
    serialization/deserialization, security, and metadata.

    Sessions support configurable serialization via packer/unpacker traits,
    and signing with HMAC digests via the key/keyfile traits.

    Parameters
    ----------

    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        importstrings for methods to serialize message parts.  If just
        'json' or 'pickle', predefined JSON and pickle packers will be used.
        Otherwise, the entire importstring must be used.

        The functions must accept at least valid JSON input, and output *bytes*.

        For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        You can also set the pack/unpack callables for serialization directly.
    session : bytes
        the ID of this Session object.  The default is to generate a new UUID.
    username : unicode
        username added to message headers.  The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature.  If unset, messages
        will not be signed or checked.
    keyfile : filepath
        The file containing a key.  If this is set, `key` will be initialized
        to the contents of the file.

    """

    debug = Bool(False, config=True, help="""Debug output in the Session""")

    check_pid = Bool(
        True,
        config=True,
        help="""Whether to check PID to protect against calls after fork.

        This check can be disabled if fork-safety is handled elsewhere.
        """)

    packer = DottedObjectName(
        'json',
        config=True,
        help="""The name of the packer for serializing messages.
            Should be one of 'json', 'pickle', or an import name
            for a custom callable serializer.""")

    @observe('packer')
    def _packer_changed(self, change):
        new = change['new']
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.unpacker = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.unpacker = new
        else:
            self.pack = import_item(str(new))

    unpacker = DottedObjectName(
        'json',
        config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""")

    @observe('unpacker')
    def _unpacker_changed(self, change):
        new = change['new']
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.packer = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.packer = new
        else:
            self.unpack = import_item(str(new))

    session = CUnicode('',
                       config=True,
                       help="""The UUID identifying this session.""")

    def _session_default(self):
        u = new_id()
        self.bsession = u.encode('ascii')
        return u

    @observe('session')
    def _session_changed(self, change):
        self.bsession = self.session.encode('ascii')

    # bsession is the session as bytes
    bsession = CBytes(b'')

    username = Unicode(
        os.environ.get("USER", "username"),
        help="""Username for the Session. Default is your system username.""",
        config=True)

    metadata = Dict(
        {},
        config=True,
        help=
        """Metadata dictionary, which serves as the default top-level metadata dict for each message."""
    )

    # if 0, no adapting to do.
    adapt_version = Integer(0)

    # message signature related traits:

    key = CBytes(config=True, help="""execution key, for signing messages.""")

    def _key_default(self):
        return new_id_bytes()

    @observe('key')
    def _key_changed(self, change):
        self._new_auth()

    signature_scheme = Unicode(
        'hmac-sha256',
        config=True,
        help="""The digest scheme used to construct the message signatures.
        Must have the form 'hmac-HASH'.""")

    @observe('signature_scheme')
    def _signature_scheme_changed(self, change):
        new = change['new']
        if not new.startswith('hmac-'):
            raise TraitError(
                "signature_scheme must start with 'hmac-', got %r" % new)
        hash_name = new.split('-', 1)[1]
        try:
            self.digest_mod = getattr(hashlib, hash_name)
        except AttributeError as e:
            raise TraitError("hashlib has no such attribute: %s" %
                             hash_name) from e
        self._new_auth()

    digest_mod = Any()

    def _digest_mod_default(self):
        return hashlib.sha256

    auth = Instance(hmac.HMAC, allow_none=True)

    def _new_auth(self):
        if self.key:
            self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
        else:
            self.auth = None

    digest_history = Set()
    digest_history_size = Integer(
        2**16,
        config=True,
        help="""The maximum number of digests to remember.

        The digest history will be culled when it exceeds this value.
        """)

    keyfile = Unicode('',
                      config=True,
                      help="""path to file containing execution key.""")

    @observe('keyfile')
    def _keyfile_changed(self, change):
        with open(change['new'], 'rb') as f:
            self.key = f.read().strip()

    # for protecting against sends from forks
    pid = Integer()

    # serialization traits:

    pack = Any(default_packer)  # the actual packer function

    @observe('pack')
    def _pack_changed(self, change):
        new = change['new']
        if not callable(new):
            raise TypeError("packer must be callable, not %s" % type(new))

    unpack = Any(default_unpacker)  # the actual unpacker function

    @observe('unpack')
    def _unpack_changed(self, change):
        # ensure the new unpack callable is actually callable
        new = change['new']
        if not callable(new):
            raise TypeError("unpacker must be callable, not %s" % type(new))

    # thresholds:
    copy_threshold = Integer(
        2**16,
        config=True,
        help=
        "Threshold (in bytes) beyond which a buffer should be sent without copying."
    )
    buffer_threshold = Integer(
        MAX_BYTES,
        config=True,
        help=
        "Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling."
    )
    item_threshold = Integer(
        MAX_ITEMS,
        config=True,
        help=
        """The maximum number of items for a container to be introspected for custom serialization.
        Containers larger than this are pickled outright.
        """)

    def __init__(self, **kwargs):
        """create a Session object

        Parameters
        ----------

        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts.  If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : unicode (must be ascii)
            the ID of this Session object.  The default is to generate a new
            UUID.
        bsession : bytes
            The session as bytes
        username : unicode
            username added to message headers.  The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature.  If unset, messages
            will not be signed or checked.
        signature_scheme : str
            The message digest scheme. Currently must be of the form 'hmac-HASH',
            where 'HASH' is a hashing function available in Python's hashlib.
            The default is 'hmac-sha256'.
            This is ignored if 'key' is empty.
        keyfile : filepath
            The file containing a key.  If this is set, `key` will be
            initialized to the contents of the file.
        """
        super().__init__(**kwargs)
        self._check_packers()
        self.none = self.pack({})
        # ensure self._session_default() if necessary, so bsession is defined:
        self.session
        self.pid = os.getpid()
        self._new_auth()
        if not self.key:
            get_logger().warning(
                "Message signing is disabled.  This is insecure and not recommended!"
            )

    def clone(self):
        """Create a copy of this Session

        Useful when connecting multiple times to a given kernel.
        This prevents a shared digest_history warning about duplicate digests
        due to multiple connections to IOPub in the same process.

        .. versionadded:: 5.1
        """
        # make a copy
        new_session = type(self)()
        for name in self.traits():
            setattr(new_session, name, getattr(self, name))
        # fork digest_history
        new_session.digest_history = set()
        new_session.digest_history.update(self.digest_history)
        return new_session

    message_count = 0

    @property
    def msg_id(self):
        message_number = self.message_count
        self.message_count += 1
        return '{}_{}'.format(self.session, message_number)

    def _check_packers(self):
        """check packers for datetime support."""
        pack = self.pack
        unpack = self.unpack

        # check simple serialization
        msg = dict(a=[1, 'hi'])
        try:
            packed = pack(msg)
        except Exception as e:
            msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
            if self.packer == 'json':
                jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            else:
                jsonmsg = ""
            raise ValueError(
                msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)) from e

        # ensure packed message is bytes
        if not isinstance(packed, bytes):
            raise ValueError("message packed to %r, but bytes are required" %
                             type(packed))

        # check that unpack is pack's inverse
        try:
            unpacked = unpack(packed)
            assert unpacked == msg
        except Exception as e:
            msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
            if self.packer == 'json':
                jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            else:
                jsonmsg = ""
            raise ValueError(
                msg.format(packer=self.packer,
                           unpacker=self.unpacker,
                           e=e,
                           jsonmsg=jsonmsg)) from e

        # check datetime support
        msg = dict(t=utcnow())
        try:
            unpacked = unpack(pack(msg))
            if isinstance(unpacked['t'], datetime):
                raise ValueError("Shouldn't deserialize to datetime")
        except Exception:
            self.pack = lambda o: pack(squash_dates(o))
            self.unpack = lambda s: unpack(s)

    def msg_header(self, msg_type):
        return msg_header(self.msg_id, msg_type, self.username, self.session)

    def msg(self,
            msg_type,
            content=None,
            parent=None,
            header=None,
            metadata=None):
        """Return the nested message dict.

        This format is different from what is sent over the wire. The
        serialize/deserialize methods converts this nested message dict to the wire
        format, which is a list of message parts.
        """
        msg = {}
        header = self.msg_header(msg_type) if header is None else header
        msg['header'] = header
        msg['msg_id'] = header['msg_id']
        msg['msg_type'] = header['msg_type']
        msg['parent_header'] = {} if parent is None else extract_header(parent)
        msg['content'] = {} if content is None else content
        msg['metadata'] = self.metadata.copy()
        if metadata is not None:
            msg['metadata'].update(metadata)
        return msg

    def sign(self, msg_list):
        """Sign a message with HMAC digest. If no auth, return b''.

        Parameters
        ----------
        msg_list : list
            The [p_header,p_parent,p_content] part of the message list.
        """
        if self.auth is None:
            return b''
        h = self.auth.copy()
        for m in msg_list:
            h.update(m)
        return h.hexdigest().encode()

    def serialize(self, msg, ident=None):
        """Serialize the message components to bytes.

        This is roughly the inverse of deserialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg : dict or Message
            The next message dict as returned by the self.msg method.

        Returns
        -------
        msg_list : list
            The list of bytes objects to be sent with the format::

                [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
                 p_metadata, p_content, buffer1, buffer2, ...]

            In this list, the ``p_*`` entities are the packed or serialized
            versions, so if JSON is used, these are utf8 encoded JSON strings.
        """
        content = msg.get('content', {})
        if content is None:
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, bytes):
            # content is already packed, as in a relayed message
            pass
        elif isinstance(content, str):
            # should be bytes, but JSON often spits out unicode
            content = content.encode('utf8')
        else:
            raise TypeError("Content incorrect type: %s" % type(content))

        real_message = [
            self.pack(msg['header']),
            self.pack(msg['parent_header']),
            self.pack(msg['metadata']),
            content,
        ]

        to_send = []

        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)

        signature = self.sign(real_message)
        to_send.append(signature)

        to_send.extend(real_message)

        return to_send

    def send(self,
             stream,
             msg_or_type,
             content=None,
             parent=None,
             ident=None,
             buffers=None,
             track=False,
             header=None,
             metadata=None):
        """Build and send a message via stream or socket.

        The message format used by this function internally is as follows:

        [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
         buffer1,buffer2,...]

        The serialize/deserialize methods convert the nested message dict into this
        format.

        Parameters
        ----------

        stream : zmq.Socket or ZMQStream
            The socket-like object used to send the data.
        msg_or_type : str or Message/dict
            Normally, msg_or_type will be a msg_type unless a message is being
            sent more than once. If a header is supplied, this can be set to
            None and the msg_type will be pulled from the header.

        content : dict or None
            The content of the message (ignored if msg_or_type is a message).
        header : dict or None
            The header dict for the message (ignored if msg_to_type is a message).
        parent : Message or dict or None
            The parent or parent header describing the parent of this message
            (ignored if msg_or_type is a message).
        ident : bytes or list of bytes
            The zmq.IDENTITY routing path.
        metadata : dict or None
            The metadata describing the message
        buffers : list or None
            The already-serialized buffers to be appended to the message.
        track : bool
            Whether to track.  Only for use with Sockets, because ZMQStream
            objects cannot track messages.


        Returns
        -------
        msg : dict
            The constructed message.
        """
        if not isinstance(stream, zmq.Socket):
            # ZMQStreams and dummy sockets do not support tracking.
            track = False

        if isinstance(msg_or_type, (Message, dict)):
            # We got a Message or message dict, not a msg_type so don't
            # build a new Message.
            msg = msg_or_type
            buffers = buffers or msg.get('buffers', [])
        else:
            msg = self.msg(msg_or_type,
                           content=content,
                           parent=parent,
                           header=header,
                           metadata=metadata)
        if self.check_pid and not os.getpid() == self.pid:
            get_logger().warning(
                "WARNING: attempted to send message from fork\n%s", msg)
            return
        buffers = [] if buffers is None else buffers
        for idx, buf in enumerate(buffers):
            if isinstance(buf, memoryview):
                view = buf
            else:
                try:
                    # check to see if buf supports the buffer protocol.
                    view = memoryview(buf)
                except TypeError as e:
                    raise TypeError(
                        "Buffer objects must support the buffer protocol."
                    ) from e
            # memoryview.contiguous is new in 3.3,
            # just skip the check on Python 2
            if hasattr(view, 'contiguous') and not view.contiguous:
                # zmq requires memoryviews to be contiguous
                raise ValueError("Buffer %i (%r) is not contiguous" %
                                 (idx, buf))

        if self.adapt_version:
            msg = adapt(msg, self.adapt_version)
        to_send = self.serialize(msg, ident)
        to_send.extend(buffers)
        longest = max([len(s) for s in to_send])
        copy = (longest < self.copy_threshold)

        if buffers and track and not copy:
            # only really track when we are doing zero-copy buffers
            tracker = stream.send_multipart(to_send, copy=False, track=True)
        else:
            # use dummy tracker, which will be done immediately
            tracker = DONE
            stream.send_multipart(to_send, copy=copy)

        if self.debug:
            pprint.pprint(msg)
            pprint.pprint(to_send)
            pprint.pprint(buffers)

        msg['tracker'] = tracker

        return msg

    def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
        """Send a raw message via ident path.

        This method is used to send an already serialized message.

        Parameters
        ----------
        stream : ZMQStream or Socket
            The ZMQ stream or socket to use for sending the message.
        msg_list : list
            The serialized list of messages to send. This only includes the
            [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
            the message.
        ident : ident or list
            A single ident or a list of idents to use in sending.
        """
        to_send = []
        if isinstance(ident, bytes):
            ident = [ident]
        if ident is not None:
            to_send.extend(ident)

        to_send.append(DELIM)
        # Don't include buffers in signature (per spec).
        to_send.append(self.sign(msg_list[0:4]))
        to_send.extend(msg_list)
        stream.send_multipart(to_send, flags, copy=copy)

    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
        """Receive and unpack a message.

        Parameters
        ----------
        socket : ZMQStream or Socket
            The socket or stream to use in receiving.

        Returns
        -------
        [idents], msg
            [idents] is a list of idents and msg is a nested message dict of
            same format as self.msg returns.
        """
        if isinstance(socket, ZMQStream):
            socket = socket.socket
        try:
            msg_list = socket.recv_multipart(mode, copy=copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_multipart won't return None.
                return None, None
            else:
                raise
        # split multipart message into identity list and message dict
        # invalid large messages can cause very expensive string comparisons
        idents, msg_list = self.feed_identities(msg_list, copy)
        try:
            return idents, self.deserialize(msg_list,
                                            content=content,
                                            copy=copy)
        except Exception as e:
            # TODO: handle it
            raise e

    def feed_identities(self, msg_list, copy=True):
        """Split the identities from the rest of the message.

        Feed until DELIM is reached, then return the prefix as idents and
        remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
        but that would be silly.

        Parameters
        ----------
        msg_list : a list of Message or bytes objects
            The message to be split.
        copy : bool
            flag determining whether the arguments are bytes or Messages

        Returns
        -------
        (idents, msg_list) : two lists
            idents will always be a list of bytes, each of which is a ZMQ
            identity. msg_list will be a list of bytes or zmq.Messages of the
            form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
            should be unpackable/unserializable via self.deserialize at this
            point.
        """
        if copy:
            idx = msg_list.index(DELIM)
            return msg_list[:idx], msg_list[idx + 1:]
        else:
            failed = True
            for idx, m in enumerate(msg_list):
                if m.bytes == DELIM:
                    failed = False
                    break
            if failed:
                raise ValueError("DELIM not in msg_list")
            idents, msg_list = msg_list[:idx], msg_list[idx + 1:]
            return [m.bytes for m in idents], msg_list

    def _add_digest(self, signature):
        """add a digest to history to protect against replay attacks"""
        if self.digest_history_size == 0:
            # no history, never add digests
            return

        self.digest_history.add(signature)
        if len(self.digest_history) > self.digest_history_size:
            # threshold reached, cull 10%
            self._cull_digest_history()

    def _cull_digest_history(self):
        """cull the digest history

        Removes a randomly selected 10% of the digest history
        """
        current = len(self.digest_history)
        n_to_cull = max(int(current // 10), current - self.digest_history_size)
        if n_to_cull >= current:
            self.digest_history = set()
            return
        to_cull = random.sample(tuple(sorted(self.digest_history)), n_to_cull)
        self.digest_history.difference_update(to_cull)

    def deserialize(self, msg_list, content=True, copy=True):
        """Unserialize a msg_list to a nested message dict.

        This is roughly the inverse of serialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg_list : list of bytes or Message objects
            The list of message parts of the form [HMAC,p_header,p_parent,
            p_metadata,p_content,buffer1,buffer2,...].
        content : bool (True)
            Whether to unpack the content dict (True), or leave it packed
            (False).
        copy : bool (True)
            Whether msg_list contains bytes (True) or the non-copying Message
            objects in each place (False).

        Returns
        -------
        msg : dict
            The nested message dict with top-level keys [header, parent_header,
            metadata, content, buffers].  The buffers are returned as
            memoryviews.
        """
        minlen = 5
        message = {}
        if not copy:
            # pyzmq didn't copy the first parts of the message, so we'll do it
            for i in range(minlen):
                msg_list[i] = msg_list[i].bytes
        if self.auth is not None:
            signature = msg_list[0]
            if not signature:
                raise ValueError("Unsigned Message")
            if signature in self.digest_history:
                raise ValueError("Duplicate Signature: %r" % signature)
            if content:
                # Only store signature if we are unpacking content, don't store if just peeking.
                self._add_digest(signature)
            check = self.sign(msg_list[1:5])
            if not compare_digest(signature, check):
                raise ValueError("Invalid Signature: %r" % signature)
        if len(msg_list) < minlen:
            raise TypeError(
                "malformed message, must have at least %i elements" % minlen)
        header = self.unpack(msg_list[1])
        message['header'] = extract_dates(header)
        message['msg_id'] = header['msg_id']
        message['msg_type'] = header['msg_type']
        message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
        message['metadata'] = self.unpack(msg_list[3])
        if content:
            message['content'] = self.unpack(msg_list[4])
        else:
            message['content'] = msg_list[4]
        buffers = [memoryview(b) for b in msg_list[5:]]
        if buffers and buffers[0].shape is None:
            # force copy to workaround pyzmq #646
            buffers = [memoryview(b.bytes) for b in msg_list[5:]]
        message['buffers'] = buffers
        if self.debug:
            pprint.pprint(message)
        # adapt to the current version
        return adapt(message)

    def unserialize(self, *args, **kwargs):
        warnings.warn(
            "Session.unserialize is deprecated. Use Session.deserialize.",
            DeprecationWarning,
        )
        return self.deserialize(*args, **kwargs)
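
A quick sketch of how this receive path is typically driven (hedged: the key, socket type and endpoint below are illustrative, not taken from the excerpt):

import zmq
from jupyter_client.session import Session

session = Session(key=b'illustrative-key')
sock = zmq.Context.instance().socket(zmq.DEALER)
sock.connect("tcp://127.0.0.1:5555")  # illustrative endpoint

# Non-blocking receive: recv returns (None, None) when nothing is queued.
idents, msg = session.recv(sock, mode=zmq.NOBLOCK)
if msg is not None:
    print(msg['msg_type'], msg['content'])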
Example No. 19
class WrapSpawner(Spawner):

    # Grab this from constructor args in case some Spawner ever wants it
    config = Any()

    child_class = Type(
        LocalProcessSpawner,
        Spawner,
        config=True,
        help="""The class to wrap for spawning single-user servers.
                Should be a subclass of Spawner.
                """)

    child_config = Dict(
        default_value={},
        config=True,
        help="Dictionary of config values to apply to wrapped spawner class.")

    child_state = Dict(default_value={})

    child_spawner = Instance(Spawner, allow_none=True)

    def construct_child(self):
        if self.child_spawner is None:
            self.child_spawner = self.child_class(
                user=self.user,
                db=self.db,
                hub=self.hub,
                authenticator=self.authenticator,
                config=self.config,
                **self.child_config)
            # initial state will always be wrong since it will see *our* state
            self.child_spawner.clear_state()
            if self.child_state:
                self.child_spawner.load_state(self.child_state)
            self.child_spawner.api_token = self.api_token
        return self.child_spawner

    def load_child_class(self, state):
        # Subclasses must arrange for correct child_class setting from load_state
        pass

    def load_state(self, state):
        super().load_state(state)
        self.load_child_class(state)
        self.child_config.update(state.get('child_conf', {}))
        self.child_state = state.get('child_state', {})
        self.construct_child()

    def get_state(self):
        state = super().get_state()
        state['child_conf'] = self.child_config
        if self.child_spawner:
            self.child_state = state[
                'child_state'] = self.child_spawner.get_state()
        return state

    def clear_state(self):
        super().clear_state()
        if self.child_spawner:
            self.child_spawner.clear_state()
        self.child_state = {}
        self.child_spawner = None

    # proxy functions for start/poll/stop
    # pass back the child's Future, or create a dummy if needed

    def start(self):
        if not self.child_spawner:
            self.construct_child()
        return self.child_spawner.start()

    def stop(self, now=False):
        if self.child_spawner:
            return self.child_spawner.stop(now)
        else:
            return _yield_val()

    def poll(self):
        if self.child_spawner:
            return self.child_spawner.poll()
        else:
            return _yield_val(1)
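
A minimal sketch of how WrapSpawner is meant to be specialized: a subclass decides which Spawner class to wrap, typically by restoring it in load_child_class (the pinned choice below is illustrative):

class FixedWrapSpawner(WrapSpawner):
    """Illustrative subclass that always wraps LocalProcessSpawner."""

    def load_child_class(self, state):
        # A real subclass would recover the class choice from `state`;
        # here it is simply pinned.
        self.child_class = LocalProcessSpawner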
Example No. 20
class Figure(GMapsWidgetMixin, widgets.DOMWidget):
    """
    Figure widget

    This is the base widget for a Figure. Prefer instantiating
    instances of ``Figure`` using the :func:`gmaps.figure`
    factory method.
    """
    _view_name = Unicode("FigureView").tag(sync=True)
    _model_name = Unicode("FigureModel").tag(sync=True)
    _toolbar = Instance(Toolbar, allow_none=True, default=None).tag(
        sync=True, **widgets.widget_serialization)
    _errors_box = Instance(ErrorsBox, allow_none=True, default=None).tag(
        sync=True, **widgets.widget_serialization)
    _map = Instance(Map).tag(sync=True, **widgets.widget_serialization)

    def add_layer(self, layer):
        """
        Add a data layer to this figure.

        :param layer: a `gmaps` layer.

        :Examples:

        >>> f = figure()
        >>> f.add_layer(gmaps.heatmap_layer(locations))

        .. seealso:: layer creation functions

            :func:`gmaps.heatmap_layer`
                Create a heatmap layer

            :func:`gmaps.symbol_layer`
                Create a layer of symbols

            :func:`gmaps.marker_layer`
                Create a layer of markers

            :func:`gmaps.geojson_layer`
                Create a GeoJSON layer

            :func:`gmaps.drawing_layer`
                Create a layer of custom features, and allow users to draw
                on the map

            :func:`gmaps.directions_layer`
                Create a layer with directions

            :func:`gmaps.bicycling_layer`
                Create a layer showing cycle routes

            :func:`gmaps.transit_layer`
                Create a layer showing public transport

            :func:`gmaps.traffic_layer`
                Create a layer showing current traffic information
        """
        try:
            toolbar_controls = layer.toolbar_controls
            if self._toolbar is not None and toolbar_controls is not None:
                self._toolbar.add_controls(toolbar_controls)
        except AttributeError:
            pass
        self._map.add_layer(layer)
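
A short usage sketch for the widget above (hedged: the API key and the `locations` data are illustrative):

import gmaps

gmaps.configure(api_key='...')           # illustrative key
locations = [(51.5, -0.1), (48.9, 2.4)]  # illustrative (lat, lng) pairs
fig = gmaps.figure()
fig.add_layer(gmaps.heatmap_layer(locations))
fig  # render the figure in a notebook cell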
Example No. 21
class Authenticator(LoggingConfigurable):
    """Base class for authenticators.

    An authenticator manages authenticating user API requests.

    Subclasses must define ``authenticate``, and may optionally also define
    ``setup``, ``cleanup`` and ``pre_response``.
    """

    cookie_name = Unicode(
        help="The cookie name to use for caching authentication information.",
        config=True,
    )

    @default("cookie_name")
    def _default_cookie_name(self):
        return "dask-gateway-%s" % uuid.uuid4().hex

    cache_max_age = Integer(
        300,
        help="""The maximum time in seconds to cache authentication information.

        Helps reduce load on the backing authentication service by caching
        responses between requests. After this time the user will need to be
        reauthenticated before making additional requests (note this is usually
        transparent to the user).
        """,
        config=True,
    )

    cache = Instance(UserCache)

    @default("cache")
    def _default_cache(self):
        return UserCache(max_age=self.cache_max_age)

    async def authenticate_and_handle(self, request, handler):
        # Try to authenticate with the cookie first
        cookie = request.cookies.get(self.cookie_name)
        if cookie is not None:
            user = self.cache.get(cookie)
            if user is not None:
                request["user"] = user
                return await handler(request)

        # Otherwise go through full authentication process
        user = await self.authenticate(request)
        if type(user) is tuple:
            user, context = user
        else:
            context = None

        request["user"] = user
        response = await handler(request)

        await self.pre_response(request, response, context)
        cookie = self.cache.put(user)
        response.set_cookie(self.cookie_name,
                            cookie,
                            max_age=self.cache_max_age * 2)

        return response

    async def setup(self, app):
        """Called when the server is starting up.

        Do any initialization here.

        Parameters
        ----------
        app : aiohttp.web.Application
            The aiohttp application. Can be used to add additional routes if
            needed.
        """
        pass

    async def cleanup(self):
        """Called when the server is shutting down.

        Do any cleanup here."""
        pass

    async def authenticate(self, request):
        """Perform the authentication process.

        Parameters
        ----------
        request : aiohttp.web.Request
            The current request.

        Returns
        -------
        user : User
            The authenticated user.
        context : object, optional
            If necessary, may optionally return an opaque object storing
            additional context needed to complete the authentication process.
            This will be passed to ``pre_response``.
        """
        raise NotImplementedError

    async def pre_response(self, request, response, context=None):
        """Called before returning a response.

        Allows modifying the outgoing response in-place to add additional
        headers, etc...

        Note that this is only called if ``authenticate`` was applied for this
        request.

        Parameters
        ----------
        request : aiohttp.web.Request
            The current request.
        response : aiohttp.web.Response
            The current response. May be modified in-place.
        context : object or None
            If present, the extra return value of ``authenticate``, providing
            any additional context needed to complete the authentication
            process.
        """
        pass
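
A minimal sketch of a concrete subclass following the pattern the docstrings above describe (hedged: the header name and the User constructor are illustrative, not shown in the excerpt):

from aiohttp import web

class HeaderAuthenticator(Authenticator):
    """Illustrative authenticator trusting an X-User header."""

    async def authenticate(self, request):
        user_name = request.headers.get("X-User")  # illustrative header
        if user_name is None:
            raise web.HTTPUnauthorized(reason="No user provided")
        return User(user_name)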
Example No. 22
class ContentsManager(LoggingConfigurable):
    """Base class for serving files and directories.

    This serves any text or binary file,
    as well as directories,
    with special handling for JSON notebook documents.

    Most APIs take a path argument,
    which is always an API-style unicode path,
    and always refers to a directory.

    - unicode, not url-escaped
    - '/'-separated
    - leading and trailing '/' will be stripped
    - if unspecified, path defaults to '',
      indicating the root path.

    """

    notary = Instance(sign.NotebookNotary)

    def _notary_default(self):
        return sign.NotebookNotary(parent=self)

    hide_globs = List(Unicode(), [
        u'__pycache__',
        '*.pyc',
        '*.pyo',
        '.DS_Store',
        '*.so',
        '*.dylib',
        '*~',
    ],
                      config=True,
                      help="""
        Glob patterns to hide in file and directory listings.
    """)

    untitled_notebook = Unicode(
        "Untitled",
        config=True,
        help="The base name used when creating untitled notebooks.")

    untitled_file = Unicode(
        "untitled",
        config=True,
        help="The base name used when creating untitled files.")

    untitled_directory = Unicode(
        "Untitled Folder",
        config=True,
        help="The base name used when creating untitled directories.")

    pre_save_hook = Any(None,
                        config=True,
                        help="""Python callable or importstring thereof

        To be called on a contents model prior to save.

        This can be used to process the structure,
        such as removing notebook outputs or other side effects that
        should not be saved.

        It will be called as (all arguments passed by keyword)::

            hook(path=path, model=model, contents_manager=self)

        - model: the model to be saved. Includes file contents.
          Modifying this dict will affect the file that is stored.
        - path: the API path of the save destination
        - contents_manager: this ContentsManager instance
        """)

    def _pre_save_hook_changed(self, name, old, new):
        if new and isinstance(new, string_types):
            self.pre_save_hook = import_item(self.pre_save_hook)
        elif new:
            if not callable(new):
                raise TraitError("pre_save_hook must be callable")

    def run_pre_save_hook(self, model, path, **kwargs):
        """Run the pre-save hook if defined, and log errors"""
        if self.pre_save_hook:
            try:
                self.log.debug("Running pre-save hook on %s", path)
                self.pre_save_hook(model=model,
                                   path=path,
                                   contents_manager=self,
                                   **kwargs)
            except Exception:
                self.log.error("Pre-save hook failed on %s",
                               path,
                               exc_info=True)

    checkpoints_class = Type(Checkpoints, config=True)
    checkpoints = Instance(Checkpoints, config=True)
    checkpoints_kwargs = Dict(config=True)

    def _checkpoints_default(self):
        return self.checkpoints_class(**self.checkpoints_kwargs)

    def _checkpoints_kwargs_default(self):
        return dict(
            parent=self,
            log=self.log,
        )

    # ContentsManager API part 1: methods that must be
    # implemented in subclasses.

    def dir_exists(self, path):
        """Does a directory exist at the given path?

        Like os.path.isdir

        Override this method in subclasses.

        Parameters
        ----------
        path : string
            The path to check

        Returns
        -------
        exists : bool
            Whether the path does indeed exist.
        """
        raise NotImplementedError

    def is_hidden(self, path):
        """Is path a hidden directory or file?

        Parameters
        ----------
        path : string
            The path to check. This is an API path (`/` separated,
            relative to root dir).

        Returns
        -------
        hidden : bool
            Whether the path is hidden.

        """
        raise NotImplementedError

    def file_exists(self, path=''):
        """Does a file exist at the given path?

        Like os.path.isfile

        Override this method in subclasses.

        Parameters
        ----------
        path : string
            The API path of a file to check for.

        Returns
        -------
        exists : bool
            Whether the file exists.
        """
        raise NotImplementedError('must be implemented in a subclass')

    def exists(self, path):
        """Does a file or directory exist at the given path?

        Like os.path.exists

        Parameters
        ----------
        path : string
            The API path of a file or directory to check for.

        Returns
        -------
        exists : bool
            Whether the target exists.
        """
        return self.file_exists(path) or self.dir_exists(path)

    def get(self, path, content=True, type=None, format=None):
        """Get a file or directory model."""
        raise NotImplementedError('must be implemented in a subclass')

    def save(self, model, path):
        """
        Save a file or directory model to path.

        Should return the saved model with no content.  Save implementations
        should call self.run_pre_save_hook(model=model, path=path) prior to
        writing any data.
        """
        raise NotImplementedError('must be implemented in a subclass')

    def delete_file(self, path):
        """Delete the file or directory at path."""
        raise NotImplementedError('must be implemented in a subclass')

    def rename_file(self, old_path, new_path):
        """Rename a file or directory."""
        raise NotImplementedError('must be implemented in a subclass')

    # ContentsManager API part 2: methods that have usable default
    # implementations, but can be overridden in subclasses.

    def delete(self, path):
        """Delete a file/directory and any associated checkpoints."""
        path = path.strip('/')
        if not path:
            raise HTTPError(400, "Can't delete root")
        self.delete_file(path)
        self.checkpoints.delete_all_checkpoints(path)

    def rename(self, old_path, new_path):
        """Rename a file and any checkpoints associated with that file."""
        self.rename_file(old_path, new_path)
        self.checkpoints.rename_all_checkpoints(old_path, new_path)

    def update(self, model, path):
        """Update the file's path

        For use in PATCH requests, to enable renaming a file without
        re-uploading its contents. Only used for renaming at the moment.
        """
        path = path.strip('/')
        new_path = model.get('path', path).strip('/')
        if path != new_path:
            self.rename(path, new_path)
        model = self.get(new_path, content=False)
        return model

    def info_string(self):
        return "Serving contents"

    def get_kernel_path(self, path, model=None):
        """Return the API path for the kernel
        
        KernelManagers can turn this value into a filesystem path,
        or ignore it altogether.

        The default value here will start kernels in the directory of the
        notebook server. FileContentsManager overrides this to use the
        directory containing the notebook.
        """
        return ''

    def increment_filename(self, filename, path='', insert=''):
        """Increment a filename until it is unique.

        Parameters
        ----------
        filename : unicode
            The name of a file, including extension
        path : unicode
            The API path of the target's directory
        insert : unicode
            A string inserted between the base name and the incrementing
            number (e.g. '-Copy' when copying files).

        Returns
        -------
        name : unicode
            A filename that is unique, based on the input filename.
        """
        path = path.strip('/')
        basename, ext = os.path.splitext(filename)
        for i in itertools.count():
            if i:
                insert_i = '{}{}'.format(insert, i)
            else:
                insert_i = ''
            name = u'{basename}{insert}{ext}'.format(basename=basename,
                                                     insert=insert_i,
                                                     ext=ext)
            if not self.exists(u'{}/{}'.format(path, name)):
                break
        return name
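
    # Illustrative behaviour of the loop above (doctest-style; assumes a
    # manager `cm` whose root already contains 'Untitled.ipynb'):
    #
    #   >>> cm.increment_filename('Untitled.ipynb')
    #   'Untitled1.ipynb'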

    def validate_notebook_model(self, model):
        """Add failed-validation message to model"""
        try:
            validate(model['content'])
        except ValidationError as e:
            model['message'] = u'Notebook Validation failed: {}:\n{}'.format(
                e.message,
                json.dumps(e.instance,
                           indent=1,
                           default=lambda obj: '<UNKNOWN>'),
            )
        return model

    def new_untitled(self, path='', type='', ext=''):
        """Create a new untitled file or directory in path
        
        path must be a directory
        
        File extension can be specified.
        
        Use `new` to create files with a fully specified path (including filename).
        """
        path = path.strip('/')
        if not self.dir_exists(path):
            raise HTTPError(404, 'No such directory: %s' % path)

        model = {}
        if type:
            model['type'] = type

        if ext == '.ipynb':
            model.setdefault('type', 'notebook')
        else:
            model.setdefault('type', 'file')

        insert = ''
        if model['type'] == 'directory':
            untitled = self.untitled_directory
            insert = ' '
        elif model['type'] == 'notebook':
            untitled = self.untitled_notebook
            ext = '.ipynb'
        elif model['type'] == 'file':
            untitled = self.untitled_file
        else:
            raise HTTPError(400, "Unexpected model type: %r" % model['type'])

        name = self.increment_filename(untitled + ext, path, insert=insert)
        path = u'{0}/{1}'.format(path, name)
        return self.new(model, path)

    def new(self, model=None, path=''):
        """Create a new file or directory and return its model with no content.
        
        To create a new untitled entity in a directory, use `new_untitled`.
        """
        path = path.strip('/')
        if model is None:
            model = {}

        if path.endswith('.ipynb'):
            model.setdefault('type', 'notebook')
        else:
            model.setdefault('type', 'file')

        # no content, not a directory, so fill out new-file model
        if 'content' not in model and model['type'] != 'directory':
            if model['type'] == 'notebook':
                model['content'] = new_notebook()
                model['format'] = 'json'
            else:
                model['content'] = ''
                model['type'] = 'file'
                model['format'] = 'text'

        model = self.save(model, path)
        return model

    def copy(self, from_path, to_path=None):
        """Copy an existing file and return its new model.

        If to_path not specified, it will be the parent directory of from_path.
        If to_path is a directory, filename will increment `from_path-Copy#.ext`.

        from_path must be a full path to a file.
        """
        path = from_path.strip('/')
        if to_path is not None:
            to_path = to_path.strip('/')

        if '/' in path:
            from_dir, from_name = path.rsplit('/', 1)
        else:
            from_dir = ''
            from_name = path

        model = self.get(path)
        model.pop('path', None)
        model.pop('name', None)
        if model['type'] == 'directory':
            raise HTTPError(400, "Can't copy directories")

        if to_path is None:
            to_path = from_dir
        if self.dir_exists(to_path):
            name = copy_pat.sub(u'.', from_name)
            to_name = self.increment_filename(name, to_path, insert='-Copy')
            to_path = u'{0}/{1}'.format(to_path, to_name)

        model = self.save(model, to_path)
        return model

    def log_info(self):
        self.log.info(self.info_string())

    def trust_notebook(self, path):
        """Explicitly trust a notebook

        Parameters
        ----------
        path : string
            The path of a notebook
        """
        model = self.get(path)
        nb = model['content']
        self.log.warn("Trusting notebook %s", path)
        self.notary.mark_cells(nb, True)
        self.save(model, path)

    def check_and_sign(self, nb, path=''):
        """Check for trusted cells, and sign the notebook.

        Called as a part of saving notebooks.

        Parameters
        ----------
        nb : dict
            The notebook dict
        path : string
            The notebook's path (for logging)
        """
        if self.notary.check_cells(nb):
            self.notary.sign(nb)
        else:
            self.log.warn("Saving untrusted notebook %s", path)

    def mark_trusted_cells(self, nb, path=''):
        """Mark cells as trusted if the notebook signature matches.

        Called as a part of loading notebooks.

        Parameters
        ----------
        nb : dict
            The notebook object (in current nbformat)
        path : string
            The notebook's path (for logging)
        """
        trusted = self.notary.check_signature(nb)
        if not trusted:
            self.log.warn("Notebook %s is not trusted", path)
        self.notary.mark_cells(nb, trusted)

    def should_list(self, name):
        """Should this file/directory name be displayed in a listing?"""
        return not any(fnmatch(name, glob) for glob in self.hide_globs)

    # Part 3: Checkpoints API
    def create_checkpoint(self, path):
        """Create a checkpoint."""
        return self.checkpoints.create_checkpoint(self, path)

    def restore_checkpoint(self, checkpoint_id, path):
        """
        Restore a checkpoint.
        """
        self.checkpoints.restore_checkpoint(self, checkpoint_id, path)

    def list_checkpoints(self, path):
        return self.checkpoints.list_checkpoints(path)

    def delete_checkpoint(self, checkpoint_id, path):
        return self.checkpoints.delete_checkpoint(checkpoint_id, path)
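
A hedged usage sketch of the API surface above, via the concrete FileContentsManager shipped with the notebook server (the root directory is illustrative):

from notebook.services.contents.filemanager import FileContentsManager

cm = FileContentsManager(root_dir='/tmp/nbroot')   # illustrative root
model = cm.new_untitled(type='notebook')           # creates Untitled.ipynb
cm.create_checkpoint(model['path'])
print(cm.get(model['path'], content=False))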
Example No. 23
class DataWidget(SimpleWidget):
    d = Instance(DataInstance).tag(sync=True,
                                   to_json=mview_serializer,
                                   from_json=deserializer)
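
The to_json/from_json hooks above follow the standard ipywidgets serializer convention: each is a callable taking the trait value and the widget. A sketch of what this pair could look like (hedged: DataInstance's layout is not shown in the excerpt, so both bodies are illustrative):

def mview_serializer(value, widget):
    # Ship the instance's underlying buffer to the front end as binary.
    return {'data': memoryview(value.data)}

def deserializer(json_value, widget):
    # Rebuild a DataInstance from the wire representation.
    return DataInstance(json_value['data'])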
Example No. 24
class LocalSubprocessEnvironment(RequestManager, Environment):
    """
    Implements functions for a local subprocess environment.
    """
    _provider_meta: ScriptProviderInfo

    write: Connection
    read: Connection
    provider: ScriptProvider = Instance(ScriptProvider)

    _framebuffer: Array
    _framebuffer_lock: Lock

    queue: Queue
    stopped: Event
    responder: Responder
    commands: ConvertingMappingProxy[str, Callable[..., Any], Callable[[Any],
                                                                       Future]]

    def additional_extensions(self) -> List[str]:
        """
        Defines additional extensions that should be
        loaded inside the environment
        """
        result = []

        for ext in self._provider_meta.extensions:
            if not ext.startswith("="):
                ext = "=" + ext

            name, extension = [s.strip() for s in ext.split("=")]
            extension = import_item(extension)

            if name:
                extension._name = name
            result.append(extension)

        return result

    def post_extension_load(self) -> None:
        """
        Called directly after extensions have been loaded
        (but not enabled)
        """
        self.queue = Queue()
        self.stopped = Event()
        self.handlers = {}
        self.commands = ConvertingMappingProxy(self.handlers, self._wrap2queue)

        # Let's initialize it here.
        provider_class = import_item(self._provider_meta.providercls)

        self.provider = provider_class(self.parent,
                                       **self._provider_meta.providerparams)

    def _wrap2queue(
            self, unwrapped: Callable[[Any], Any]) -> Callable[[Any], Future]:
        @functools.wraps(unwrapped)
        def _wrapper(data: Any) -> Future:
            fut = Future()
            self.queue.put(RequestQueueItem(fut, unwrapped, data))
            return fut

        return _wrapper

    def initialize(self) -> None:
        """
        Called by yuuno to tell it that yuuno has
        initialized to the point that it can now initialize
        interoperability for the given environment.
        """
        self.provider.initialize(self)
        self.handlers.update(
            BasicCommands(self.provider.get_script(), self).commands)
        self._framebuffer_lock = Lock()

    @contextmanager
    def framebuffer(self):
        with self._framebuffer_lock:
            yield memoryview(self._framebuffer).cast("B")

    def _copy_result(self, source: Future, destination: Future):
        def _done(_):
            if source.exception() is not None:
                destination.set_exception(source.exception())
            else:
                destination.set_result(source.result())
            self.queue.task_done()

        source.add_done_callback(_done)

    def run(self):
        """
        Wait for commands.
        """

        self.responder = Responder(self.read, self.write, self.commands)
        self.responder.start()

        self.responder.send(None)
        while not self.stopped.is_set():
            try:
                rqi: RequestQueueItem = self.queue.get(timeout=1)
            except Empty:
                continue

            if not rqi.future.set_running_or_notify_cancel():
                continue

            try:
                result = rqi.cb(**rqi.args)
            except KeyboardInterrupt:
                self.stop()
            except Exception as e:
                rqi.future.set_exception(e)
            else:
                if not isinstance(result, Future):
                    rqi.future.set_result(result)
                    self.queue.task_done()
                else:
                    self._copy_result(result, rqi.future)

        while True:
            try:
                rqi: RequestQueueItem = self.queue.get_nowait()
            except Empty:
                break

            rqi.future.set_exception(RuntimeError("System stoppped."))

        self.responder.stop()

    def stop(self):
        self.stopped.set()

    def deinitialize(self) -> None:
        """
        Called by yuuno before it deconfigures itself.
        """
        self.provider.deinitialize()

    @staticmethod
    def _preload():
        print(os.getpid(), ">", "Preloading.")
        from yuuno import init_standalone
        y = init_standalone()
        y.start()
        y.stop()
        print(os.getpid(), ">", "Preload complete.")

    @staticmethod
    def _check_parent():
        import psutil
        current = psutil.Process()
        current.parent().wait()
        print(os.getpid(), ">", "Parent died. Kill own process...")
        current.kill()

    @classmethod
    def execute(cls, read: Connection, write: Connection, framebuffer: Array):
        cls._preload()
        Thread(target=cls._check_parent, daemon=True).start()

        from yuuno import Yuuno
        yuuno = Yuuno.instance(parent=None)
        env = cls(parent=yuuno, read=read, write=write)
        env._framebuffer = framebuffer
        yuuno.environment = env

        # Wait for the ProviderMeta to be set.
        print(os.getpid(), ">", "Ready to deploy!")
        env._provider_meta = read.recv()
        yuuno.start()
        print(os.getpid(), ">", "Deployed", env._provider_meta)

        # Run the environment
        env.run()

        # Stop Yuuno.
        yuuno.stop()
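
A hedged sketch of how `execute` appears designed to be driven from the parent process, using the standard multiprocessing primitives that match its signature (the buffer size and the provider metadata are illustrative):

from multiprocessing import Array, Pipe, Process

child_read, parent_write = Pipe(duplex=False)
parent_read, child_write = Pipe(duplex=False)
framebuffer = Array('B', 1920 * 1080 * 4)  # illustrative size

proc = Process(
    target=LocalSubprocessEnvironment.execute,
    args=(child_read, child_write, framebuffer),
)
proc.start()
parent_write.send(provider_meta)  # a ScriptProviderInfo instance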
Example No. 25
class Range(widgets.Widget):
    value = Union([List(), List(Instance(list))],
                  default_value=[0, 1]).tag(sync=True)
Example No. 26
class NbConvertApp(JupyterApp):
    """Application used to convert from notebook file type (``*.ipynb``)"""

    version = __version__
    name = 'jupyter-nbconvert'
    aliases = nbconvert_aliases
    flags = nbconvert_flags

    @default('log_level')
    def _log_level_default(self):
        return logging.INFO

    classes = List()

    @default('classes')
    def _classes_default(self):
        classes = [NbConvertBase]
        for pkg in (exporters, preprocessors, writers, postprocessors):
            for name in dir(pkg):
                cls = getattr(pkg, name)
                if isinstance(cls, type) and issubclass(cls, Configurable):
                    classes.append(cls)

        return classes

    description = Unicode(
        u"""This application is used to convert notebook files (*.ipynb)
        to various other formats.

        WARNING: THE COMMANDLINE INTERFACE MAY CHANGE IN FUTURE RELEASES.""")

    output_base = Unicode('',
                          help='''Overwrite the base name used for output files.
            Can only be used when converting one notebook at a time.
            ''').tag(config=True)

    use_output_suffix = Bool(
        True,
        help="""Whether to apply a suffix prior to the extension (only relevant
            when converting to notebook format). The suffix is determined by
            the exporter, and is usually '.nbconvert'.""").tag(config=True)

    output_files_dir = Unicode(
        '{notebook_name}_files',
        help='''Directory to copy extra files (figures) to.
               '{notebook_name}' in the string will be converted to notebook
               basename.''').tag(config=True)

    examples = Unicode(u"""
        The simplest way to use nbconvert is
        
        > jupyter nbconvert mynotebook.ipynb --to html

        Options include {formats}.
        
        > jupyter nbconvert --to latex mynotebook.ipynb

        Both HTML and LaTeX support multiple output templates. LaTeX includes
        'base', 'article' and 'report'.  HTML includes 'basic' and 'full'. You
        can specify the flavor of the format used.

        > jupyter nbconvert --to html --template lab mynotebook.ipynb
        
        You can also pipe the output to stdout, rather than a file
        
        > jupyter nbconvert mynotebook.ipynb --stdout

        PDF is generated via latex

        > jupyter nbconvert mynotebook.ipynb --to pdf
        
        You can get (and serve) a Reveal.js-powered slideshow
        
        > jupyter nbconvert myslides.ipynb --to slides --post serve
        
        Multiple notebooks can be given at the command line in a couple of 
        different ways:
  
        > jupyter nbconvert notebook*.ipynb
        > jupyter nbconvert notebook1.ipynb notebook2.ipynb
        
        or you can specify the notebooks list in a config file, containing::
        
            c.NbConvertApp.notebooks = ["my_notebook.ipynb"]
        
        > jupyter nbconvert --config mycfg.py
        """.format(formats=get_export_names()))

    # Writer specific variables
    writer = Instance('nbconvert.writers.base.WriterBase',
                      help="""Instance of the writer class used to write the 
                      results of the conversion.""",
                      allow_none=True)
    writer_class = DottedObjectName('FilesWriter',
                                    help="""Writer class used to write the 
                                    results of the conversion""").tag(
                                        config=True)
    writer_aliases = {
        'fileswriter': 'nbconvert.writers.files.FilesWriter',
        'debugwriter': 'nbconvert.writers.debug.DebugWriter',
        'stdoutwriter': 'nbconvert.writers.stdout.StdoutWriter'
    }
    writer_factory = Type(allow_none=True)

    @observe('writer_class')
    def _writer_class_changed(self, change):
        new = change['new']
        if new.lower() in self.writer_aliases:
            new = self.writer_aliases[new.lower()]
        self.writer_factory = import_item(new)

    # Post-processor specific variables
    postprocessor = Instance(
        'nbconvert.postprocessors.base.PostProcessorBase',
        help="""Instance of the PostProcessor class used to write the
                      results of the conversion.""",
        allow_none=True)

    postprocessor_class = DottedOrNone(
        help="""PostProcessor class used to write the
                                    results of the conversion""").tag(
            config=True)
    postprocessor_aliases = {
        'serve': 'nbconvert.postprocessors.serve.ServePostProcessor'
    }
    postprocessor_factory = Type(None, allow_none=True)

    @observe('postprocessor_class')
    def _postprocessor_class_changed(self, change):
        new = change['new']
        if new.lower() in self.postprocessor_aliases:
            new = self.postprocessor_aliases[new.lower()]
        if new:
            self.postprocessor_factory = import_item(new)

    jupyter_widgets_base_url = Unicode(
        "https://unpkg.com/",
        help="URL base for Jupyter widgets").tag(config=True)
    html_manager_semver_range = Unicode(
        '*',
        help="Semver range for Jupyter widgets HTML manager").tag(config=True)

    export_format = Unicode(
        allow_none=False,
        help="""The export format to be used, either one of the built-in formats
        {formats}
        or a dotted object name that represents the import path for an
        `Exporter` class""".format(formats=get_export_names())).tag(
            config=True)

    notebooks = List([],
                     help="""List of notebooks to convert.
                     Wildcards are supported.
                     Filenames passed positionally will be added to the list.
                     """).tag(config=True)
    from_stdin = Bool(
        False, help="read a single notebook from stdin.").tag(config=True)

    @catch_config_error
    def initialize(self, argv=None):
        """Initialize application, notebooks, writer, and postprocessor"""
        # See https://bugs.python.org/issue37373 :(
        if sys.version_info >= (3, 8) and sys.platform.startswith('win'):
            asyncio.set_event_loop_policy(
                asyncio.WindowsSelectorEventLoopPolicy())

        self.init_syspath()
        super().initialize(argv)
        self.init_notebooks()
        self.init_writer()
        self.init_postprocessor()

    def init_syspath(self):
        """Add the cwd to the sys.path ($PYTHONPATH)"""
        sys.path.insert(0, os.getcwd())

    def init_notebooks(self):
        """Construct the list of notebooks.

        If notebooks are passed on the command-line,
        they override (rather than add) notebooks specified in config files.
        Glob each notebook to replace notebook patterns with filenames.
        """

        # Specifying notebooks on the command-line overrides (rather than
        # adds) the notebook list
        if self.extra_args:
            patterns = self.extra_args
        else:
            patterns = self.notebooks

        # Use glob to replace all the notebook patterns with filenames.
        filenames = []
        for pattern in patterns:

            # Use glob to find matching filenames.  Allow the user to convert
            # notebooks without having to type the extension.
            globbed_files = glob.glob(pattern)
            globbed_files.extend(glob.glob(pattern + '.ipynb'))
            if not globbed_files:
                self.log.warning("pattern %r matched no files", pattern)

            for filename in globbed_files:
                if filename not in filenames:
                    filenames.append(filename)
        self.notebooks = filenames

    def init_writer(self):
        """Initialize the writer (which is stateless)"""
        self._writer_class_changed({'new': self.writer_class})
        self.writer = self.writer_factory(parent=self)
        if hasattr(self.writer,
                   'build_directory') and self.writer.build_directory != '':
            self.use_output_suffix = False

    def init_postprocessor(self):
        """Initialize the postprocessor (which is stateless)"""
        self._postprocessor_class_changed({'new': self.postprocessor_class})
        if self.postprocessor_factory:
            self.postprocessor = self.postprocessor_factory(parent=self)

    def start(self):
        """Run start after initialization process has completed"""
        super().start()
        self.convert_notebooks()

    def init_single_notebook_resources(self, notebook_filename):
        """Step 1: Initialize resources

        This initializes the resources dictionary for a single notebook.

        Returns
        -------
        dict
            resources dictionary for a single notebook that MUST include the following keys:
                - config_dir: the location of the Jupyter config directory
                - unique_key: the notebook name
                - output_files_dir: a directory where output files (not
                  including the notebook itself) should be saved
        """
        basename = os.path.basename(notebook_filename)
        notebook_name = basename[:basename.rfind('.')]
        if self.output_base:
            # strip duplicate extension from output_base, to avoid Basename.ext.ext
            if getattr(self.exporter, 'file_extension', False):
                base, ext = os.path.splitext(self.output_base)
                if ext == self.exporter.file_extension:
                    self.output_base = base
            notebook_name = self.output_base

        self.log.debug("Notebook name is '%s'", notebook_name)

        # first initialize the resources we want to use
        resources = {}
        resources['config_dir'] = self.config_dir
        resources['unique_key'] = notebook_name

        output_files_dir = (self.output_files_dir.format(
            notebook_name=notebook_name))

        resources['output_files_dir'] = output_files_dir
        resources['jupyter_widgets_base_url'] = self.jupyter_widgets_base_url
        resources['html_manager_semver_range'] = self.html_manager_semver_range

        return resources

    def export_single_notebook(self,
                               notebook_filename,
                               resources,
                               input_buffer=None):
        """Step 2: Export the notebook

        Exports the notebook to a particular format according to the specified
        exporter. This function returns the output and (possibly modified)
        resources from the exporter.

        Parameters
        ----------
        notebook_filename : str
            name of notebook file.
        resources : dict
        input_buffer :
            readable file-like object returning unicode.
            if not None, notebook_filename is ignored

        Returns
        -------
        output
        dict
            resources (possibly modified)
        """
        try:
            if input_buffer is not None:
                output, resources = self.exporter.from_file(
                    input_buffer, resources=resources)
            else:
                output, resources = self.exporter.from_filename(
                    notebook_filename, resources=resources)
        except ConversionException:
            self.log.error("Error while converting '%s'",
                           notebook_filename,
                           exc_info=True)
            self.exit(1)

        return output, resources

    def write_single_notebook(self, output, resources):
        """Step 3: Write the notebook to file

        This writes output from the exporter to file using the specified writer.
        It returns the results from the writer.

        Parameters
        ----------
        output :
        resources : dict
            resources for a single notebook including name, config directory
            and directory to save output

        Returns
        -------
        file
            results from the specified writer output of exporter
        """
        if 'unique_key' not in resources:
            raise KeyError(
                "unique_key MUST be specified in the resources, but it is not")

        notebook_name = resources['unique_key']
        if self.use_output_suffix and not self.output_base:
            notebook_name += resources.get('output_suffix', '')

        write_results = self.writer.write(output,
                                          resources,
                                          notebook_name=notebook_name)
        return write_results

    def postprocess_single_notebook(self, write_results):
        """Step 4: Post-process the written file

        Only used if a postprocessor has been specified. After the
        converted notebook is written to a file in Step 3, this post-processes
        the notebook.
        """
        # Post-process if post processor has been defined.
        if hasattr(self, 'postprocessor') and self.postprocessor:
            self.postprocessor(write_results)

    def convert_single_notebook(self, notebook_filename, input_buffer=None):
        """Convert a single notebook.

        Performs the following steps:

            1. Initialize notebook resources
            2. Export the notebook to a particular format
            3. Write the exported notebook to file
            4. (Maybe) postprocess the written file

        Parameters
        ----------
        notebook_filename : str
        input_buffer :
            If not None, the buffer is used as the conversion source, and
            notebook_filename only determines the base name used for the
            output file.
        """
        if input_buffer is None:
            self.log.info("Converting notebook %s to %s", notebook_filename,
                          self.export_format)
        else:
            self.log.info("Converting notebook into %s", self.export_format)

        resources = self.init_single_notebook_resources(notebook_filename)
        output, resources = self.export_single_notebook(
            notebook_filename, resources, input_buffer=input_buffer)
        write_results = self.write_single_notebook(output, resources)
        self.postprocess_single_notebook(write_results)

    def convert_notebooks(self):
        """Convert the notebooks in the self.notebook traitlet """
        # check that the output base isn't specified if there is more than
        # one notebook to convert
        if self.output_base != '' and len(self.notebooks) > 1:
            self.log.error("""
                UsageError: --output flag or `NbConvertApp.output_base` config option
                cannot be used when converting multiple notebooks.
                """)
            self.exit(1)

        # no notebooks to convert!
        if len(self.notebooks) == 0 and not self.from_stdin:
            self.print_help()
            sys.exit(-1)

        if not self.export_format:
            raise ValueError(
                "Please specify an output format with '--to <format>'."
                f"\nThe following formats are available: {get_export_names()}")

        # initialize the exporter
        cls = get_exporter(self.export_format)
        self.exporter = cls(config=self.config)

        # convert each notebook
        if not self.from_stdin:
            for notebook_filename in self.notebooks:
                self.convert_single_notebook(notebook_filename)
        else:
            input_buffer = unicode_stdin_stream()
            # default name when conversion from stdin
            self.convert_single_notebook("notebook.ipynb",
                                         input_buffer=input_buffer)

    def document_flag_help(self):
        """
        Return a string containing descriptions of all the flags.
        """
        flags = "The following flags are defined:\n\n"
        for flag, (cfg, fhelp) in self.flags.items():
            flags += "{}\n".format(flag)
            flags += indent(fill(fhelp, 80)) + '\n\n'
            flags += indent(fill("Long Form: " + str(cfg), 80)) + '\n\n'
        return flags

    def document_alias_help(self):
        """Return a string containing all of the aliases"""

        aliases = "The folowing aliases are defined:\n\n"
        for alias, longname in self.aliases.items():
            aliases += "\t**{}** ({})\n\n".format(alias, longname)
        return aliases

    def document_config_options(self):
        """
        Provides a much improved version of the configuration documentation by
        breaking the configuration options into app, exporter, writer,
        preprocessor, postprocessor, and other sections.
        """
        categories = {
            category: [
                c for c in self._classes_inc_parents()
                if category in c.__name__.lower()
            ]
            for category in
            ['app', 'exporter', 'writer', 'preprocessor', 'postprocessor']
        }
        accounted_for = {
            c
            for category in categories.values() for c in category
        }
        categories['other'] = [
            c for c in self._classes_inc_parents() if c not in accounted_for
        ]

        header = dedent("""
                        {section} Options
                        -----------------------

                        """)
        sections = ""
        for category in categories:
            sections += header.format(section=category.title())
            if category in ['exporter', 'preprocessor', 'writer']:
                sections += ".. image:: _static/{image}_inheritance.png\n\n".format(
                    image=category)
            sections += '\n'.join(c.class_config_rst_doc()
                                  for c in categories[category])

        return sections.replace(' : ', r' \: ')
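
Being a JupyterApp, the converter can also be launched programmatically; a minimal sketch equivalent to the command-line examples above:

from nbconvert.nbconvertapp import NbConvertApp

# Same effect as: jupyter nbconvert --to html mynotebook.ipynb
NbConvertApp.launch_instance(argv=['--to', 'html', 'mynotebook.ipynb'])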
Example No. 27
class NotebookApp(JupyterApp):

    name = 'jupyter-notebook'

    description = """
        The Jupyter HTML Notebook.
        
        This launches a Tornado based HTML Notebook Server that serves up an
        HTML5/Javascript Notebook client.
    """
    examples = _examples
    aliases = aliases
    flags = flags

    classes = [
        KernelManager,
        Session,
        MappingKernelManager,
        ContentsManager,
        FileContentsManager,
        NotebookNotary,
        KernelSpecManager,
    ]
    flags = Dict(flags)
    aliases = Dict(aliases)

    subcommands = dict(list=(NbserverListApp,
                             NbserverListApp.description.splitlines()[0]), )

    _log_formatter_cls = LogFormatter

    def _log_level_default(self):
        return logging.INFO

    def _log_datefmt_default(self):
        """Exclude date from default date format"""
        return "%H:%M:%S"

    def _log_format_default(self):
        """override default log format to include time"""
        return u"%(color)s[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s]%(end_color)s %(message)s"

    # create requested profiles by default, if they don't exist:
    auto_create = Bool(True)

    # file to be opened in the notebook server
    file_to_run = Unicode('', config=True)

    # Network related information

    allow_origin = Unicode('',
                           config=True,
                           help="""Set the Access-Control-Allow-Origin header
        
        Use '*' to allow any origin to access your server.
        
        Takes precedence over allow_origin_pat.
        """)

    allow_origin_pat = Unicode(
        '',
        config=True,
        help=
        """Use a regular expression for the Access-Control-Allow-Origin header
        
        Requests from an origin matching the expression will get replies with:
        
            Access-Control-Allow-Origin: origin
        
        where `origin` is the origin of the request.
        
        Ignored if allow_origin is set.
        """)

    allow_credentials = Bool(
        False,
        config=True,
        help="Set the Access-Control-Allow-Credentials: true header")

    default_url = Unicode('/tree',
                          config=True,
                          help="The default URL to redirect to from `/`")

    ip = Unicode('localhost',
                 config=True,
                 help="The IP address the notebook server will listen on.")

    def _ip_default(self):
        """Return localhost if available, 127.0.0.1 otherwise.
        
        On some (horribly broken) systems, localhost cannot be bound.
        """
        s = socket.socket()
        try:
            s.bind(('localhost', 0))
        except socket.error as e:
            self.log.warn(
                "Cannot bind to localhost, using 127.0.0.1 as default ip\n%s",
                e)
            return '127.0.0.1'
        else:
            s.close()
            return 'localhost'

    def _ip_changed(self, name, old, new):
        if new == u'*':
            self.ip = u''

    port = Integer(8888,
                   config=True,
                   help="The port the notebook server will listen on.")
    port_retries = Integer(
        50,
        config=True,
        help=
        "The number of additional ports to try if the specified port is not available."
    )

    certfile = Unicode(
        u'',
        config=True,
        help="""The full path to an SSL/TLS certificate file.""")

    keyfile = Unicode(
        u'',
        config=True,
        help="""The full path to a private key file for usage with SSL/TLS.""")

    cookie_secret_file = Unicode(
        config=True, help="""The file where the cookie secret is stored.""")

    def _cookie_secret_file_default(self):
        return os.path.join(self.runtime_dir, 'notebook_cookie_secret')

    cookie_secret = Bytes(b'',
                          config=True,
                          help="""The random bytes used to secure cookies.
        By default this is a new random number every time you start the Notebook.
        Set it to a value in a config file to enable logins to persist across server sessions.
        
        Note: Cookie secrets should be kept private, do not share config files with
        cookie_secret stored in plaintext (you can read the value from a file).
        """)

    def _cookie_secret_default(self):
        if os.path.exists(self.cookie_secret_file):
            with io.open(self.cookie_secret_file, 'rb') as f:
                return f.read()
        else:
            secret = base64.encodebytes(os.urandom(1024))
            self._write_cookie_secret_file(secret)
            return secret

    def _write_cookie_secret_file(self, secret):
        """write my secret to my secret_file"""
        self.log.info("Writing notebook server cookie secret to %s",
                      self.cookie_secret_file)
        with io.open(self.cookie_secret_file, 'wb') as f:
            f.write(secret)
        try:
            os.chmod(self.cookie_secret_file, 0o600)
        except OSError:
            self.log.warn("Could not set permissions on %s",
                          self.cookie_secret_file)

    password = Unicode(u'',
                       config=True,
                       help="""Hashed password to use for web authentication.

                      To generate, type in a python/IPython shell:

                        from jupyter_notebook.auth import passwd; passwd()

                      The string should be of the form type:salt:hashed-password.
                      """)

    open_browser = Bool(True,
                        config=True,
                        help="""Whether to open in a browser after starting.
                        The specific browser used is platform dependent and
                        determined by the python standard library `webbrowser`
                        module, unless it is overridden using the --browser
                        (NotebookApp.browser) configuration option.
                        """)

    browser = Unicode(u'',
                      config=True,
                      help="""Specify what command to use to invoke a web
                      browser when opening the notebook. If not specified, the
                      default browser will be determined by the `webbrowser`
                      standard library module, which allows setting of the
                      BROWSER environment variable to override it.
                      """)

    webapp_settings = Dict(config=True,
                           help="DEPRECATED, use tornado_settings")

    def _webapp_settings_changed(self, name, old, new):
        self.log.warn(
            "\n    webapp_settings is deprecated, use tornado_settings.\n")
        self.tornado_settings = new

    tornado_settings = Dict(
        config=True,
        help="Supply overrides for the tornado.web.Application that the "
        "IPython notebook uses.")

    ssl_options = Dict(config=True,
                       help="""Supply SSL options for the tornado HTTPServer.
            See the tornado docs for details.""")
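    # Hedged config sketch: both traits are plain dicts passed through to tornado,
    # e.g. in jupyter_notebook_config.py (values are illustrative):
    #
    #     c.NotebookApp.tornado_settings = {'compress_response': True}
    #     c.NotebookApp.ssl_options = {'ciphers': 'HIGH:!aNULL'}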

    jinja_environment_options = Dict(
        config=True,
        help="Supply extra arguments that will be passed to Jinja environment."
    )

    jinja_template_vars = Dict(
        config=True,
        help="Extra variables to supply to jinja templates when rendering.",
    )

    enable_mathjax = Bool(
        True,
        config=True,
        help="""Whether to enable MathJax for typesetting math/TeX

        MathJax is the JavaScript library IPython uses to render math/LaTeX. It is
        very large, so you may want to disable it if you have a slow internet
        connection, or for offline use of the notebook.

        When disabled, equations etc. will appear as their untransformed TeX source.
        """)

    def _enable_mathjax_changed(self, name, old, new):
        """set mathjax url to empty if mathjax is disabled"""
        if not new:
            self.mathjax_url = u''

    base_url = Unicode('/',
                       config=True,
                       help='''The base URL for the notebook server.

                               Leading and trailing slashes can be omitted,
                               and will automatically be added.
                               ''')

    def _base_url_changed(self, name, old, new):
        if not new.startswith('/'):
            self.base_url = '/' + new
        elif not new.endswith('/'):
            self.base_url = new + '/'

    base_project_url = Unicode('/',
                               config=True,
                               help="""DEPRECATED use base_url""")

    def _base_project_url_changed(self, name, old, new):
        self.log.warn("base_project_url is deprecated, use base_url")
        self.base_url = new

    extra_static_paths = List(
        Unicode,
        config=True,
        help="""Extra paths to search for serving static files.
        
        This allows adding JavaScript/CSS to be available from the notebook server machine,
        or overriding individual files in the IPython notebook's built-in static files.""")

    @property
    def static_file_path(self):
        """return extra paths + the default location"""
        return self.extra_static_paths + [DEFAULT_STATIC_FILES_PATH]

    static_custom_path = List(Unicode,
                              help="""Path to search for custom.js, css""")

    def _static_custom_path_default(self):
        return [
            os.path.join(d, 'custom') for d in (
                self.config_dir,
                # FIXME: serve IPython profile while we don't have `jupyter migrate`
                os.path.join(get_ipython_dir(), 'profile_default', 'static'),
                DEFAULT_STATIC_FILES_PATH)
        ]

    extra_template_paths = List(
        Unicode,
        config=True,
        help="""Extra paths to search for serving jinja templates.

        Can be used to override templates from jupyter_notebook.templates.""")

    @property
    def template_file_path(self):
        """return extra paths + the default locations"""
        return self.extra_template_paths + DEFAULT_TEMPLATE_PATH_LIST

    extra_nbextensions_path = List(
        Unicode,
        config=True,
        help="""extra paths to look for Javascript notebook extensions""")

    @property
    def nbextensions_path(self):
        """The path to look for Javascript notebook extensions"""
        path = self.extra_nbextensions_path + jupyter_path('nbextensions')
        # FIXME: remove IPython nbextensions path once migration is setup
        path.append(os.path.join(get_ipython_dir(), 'nbextensions'))
        return path

    websocket_url = Unicode("",
                            config=True,
                            help="""The base URL for websockets,
        if it differs from the HTTP server (hint: it almost certainly doesn't).
        
        Should be in the form of an HTTP origin: ws[s]://hostname[:port]
        """)
    mathjax_url = Unicode("", config=True, help="""The url for MathJax.js.""")

    def _mathjax_url_default(self):
        if not self.enable_mathjax:
            return u''
        static_url_prefix = self.tornado_settings.get(
            "static_url_prefix", url_path_join(self.base_url, "static"))
        return url_path_join(static_url_prefix, 'components', 'MathJax',
                             'MathJax.js')

    def _mathjax_url_changed(self, name, old, new):
        if new and not self.enable_mathjax:
            # enable_mathjax=False overrides mathjax_url
            self.mathjax_url = u''
        else:
            self.log.info("Using MathJax: %s", new)

    contents_manager_class = Type(default_value=FileContentsManager,
                                  klass=ContentsManager,
                                  config=True,
                                  help='The notebook manager class to use.')
    kernel_manager_class = Type(default_value=MappingKernelManager,
                                config=True,
                                help='The kernel manager class to use.')
    session_manager_class = Type(default_value=SessionManager,
                                 config=True,
                                 help='The session manager class to use.')
    cluster_manager_class = Type(default_value=ClusterManager,
                                 config=True,
                                 help='The cluster manager class to use.')

    config_manager_class = Type(default_value=ConfigManager,
                                config=True,
                                help='The config manager class to use')

    kernel_spec_manager = Instance(KernelSpecManager, allow_none=True)

    kernel_spec_manager_class = Type(default_value=KernelSpecManager,
                                     config=True,
                                     help="""
        The kernel spec manager class to use. Should be a subclass
        of `jupyter_client.kernelspec.KernelSpecManager`.

        The API of KernelSpecManager is provisional and might change
        without warning between this version of IPython and the next stable one.
        """)

    login_handler_class = Type(
        default_value=LoginHandler,
        klass=web.RequestHandler,
        config=True,
        help='The login handler class to use.',
    )

    logout_handler_class = Type(
        default_value=LogoutHandler,
        klass=web.RequestHandler,
        config=True,
        help='The logout handler class to use.',
    )

    trust_xheaders = Bool(
        False,
        config=True,
        help=("Whether to trust the X-Scheme/X-Forwarded-Proto and X-Real-Ip/X-Forwarded-For "
              "headers sent by the upstream reverse proxy. Necessary if the proxy handles SSL."))

    info_file = Unicode()

    def _info_file_default(self):
        info_file = "nbserver-%s.json" % os.getpid()
        return os.path.join(self.runtime_dir, info_file)

    pylab = Unicode('disabled',
                    config=True,
                    help="""
        DISABLED: use %pylab or %matplotlib in the notebook to enable matplotlib.
        """)

    def _pylab_changed(self, name, old, new):
        """when --pylab is specified, display a warning and exit"""
        if new != 'warn':
            backend = ' %s' % new
        else:
            backend = ''
        self.log.error(
            "Support for specifying --pylab on the command line has been removed."
        )
        self.log.error(
            "Please use `%pylab{0}` or `%matplotlib{0}` in the notebook itself."
            .format(backend))
        self.exit(1)

    notebook_dir = Unicode(
        config=True, help="The directory to use for notebooks and kernels.")

    def _notebook_dir_default(self):
        if self.file_to_run:
            return os.path.dirname(os.path.abspath(self.file_to_run))
        else:
            return py3compat.getcwd()

    def _notebook_dir_changed(self, name, old, new):
        """Do a bit of validation of the notebook dir."""
        if not os.path.isabs(new):
            # If we receive a non-absolute path, make it absolute.
            self.notebook_dir = os.path.abspath(new)
            return
        if not os.path.isdir(new):
            raise TraitError("No such notebook dir: %r" % new)

        # setting App.notebook_dir implies setting notebook and kernel dirs as well
        self.config.FileContentsManager.root_dir = new
        self.config.MappingKernelManager.root_dir = new

    server_extensions = List(
        Unicode(),
        config=True,
        help=(
            "Python modules to load as notebook server extensions. "
            "This is an experimental API, and may change in future releases."))

    reraise_server_extension_failures = Bool(
        False,
        config=True,
        help="Reraise exceptions encountered loading server extensions?",
    )

    def parse_command_line(self, argv=None):
        super(NotebookApp, self).parse_command_line(argv)

        if self.extra_args:
            arg0 = self.extra_args[0]
            f = os.path.abspath(arg0)
            self.argv.remove(arg0)
            if not os.path.exists(f):
                self.log.critical("No such file or directory: %s", f)
                self.exit(1)

            # Use config here, to ensure that it takes higher priority than
            # anything that comes from the profile.
            c = Config()
            if os.path.isdir(f):
                c.NotebookApp.notebook_dir = f
            elif os.path.isfile(f):
                c.NotebookApp.file_to_run = f
            self.update_config(c)

    def init_configurables(self):
        self.kernel_spec_manager = self.kernel_spec_manager_class(
            parent=self, )
        self.kernel_manager = self.kernel_manager_class(
            parent=self,
            log=self.log,
            connection_dir=self.runtime_dir,
            kernel_spec_manager=self.kernel_spec_manager,
        )
        self.contents_manager = self.contents_manager_class(
            parent=self,
            log=self.log,
        )
        self.session_manager = self.session_manager_class(
            parent=self,
            log=self.log,
            kernel_manager=self.kernel_manager,
            contents_manager=self.contents_manager,
        )
        self.cluster_manager = self.cluster_manager_class(
            parent=self,
            log=self.log,
        )

        self.config_manager = self.config_manager_class(
            parent=self,
            log=self.log,
            config_dir=self.config_dir,
        )

    def init_logging(self):
        # This prevents double log messages because tornado uses a root logger that
        # self.log is a child of. The logging module dispatches log messages to a logger
        # and all of its ancestors until propagate is set to False.
        self.log.propagate = False

        for log in app_log, access_log, gen_log:
            # consistent log output name (NotebookApp instead of tornado.access, etc.)
            log.name = self.log.name
        # hook up tornado 3's loggers to our app handlers
        logger = logging.getLogger('tornado')
        logger.propagate = True
        logger.parent = self.log
        logger.setLevel(self.log.level)

    def init_webapp(self):
        """initialize tornado webapp and httpserver"""
        self.tornado_settings['allow_origin'] = self.allow_origin
        if self.allow_origin_pat:
            self.tornado_settings['allow_origin_pat'] = re.compile(
                self.allow_origin_pat)
        self.tornado_settings['allow_credentials'] = self.allow_credentials
        # ensure default_url starts with base_url
        if not self.default_url.startswith(self.base_url):
            self.default_url = url_path_join(self.base_url, self.default_url)

        self.web_app = NotebookWebApplication(
            self, self.kernel_manager, self.contents_manager,
            self.cluster_manager, self.session_manager,
            self.kernel_spec_manager, self.config_manager, self.log,
            self.base_url, self.default_url, self.tornado_settings,
            self.jinja_environment_options)
        ssl_options = self.ssl_options
        if self.certfile:
            ssl_options['certfile'] = self.certfile
        if self.keyfile:
            ssl_options['keyfile'] = self.keyfile
        if not ssl_options:
            # None indicates no SSL config
            ssl_options = None
        else:
            # Disable SSLv3, since its use is discouraged.
            ssl_options['ssl_version'] = ssl.PROTOCOL_TLSv1
        self.login_handler_class.validate_security(self,
                                                   ssl_options=ssl_options)
        self.http_server = httpserver.HTTPServer(self.web_app,
                                                 ssl_options=ssl_options,
                                                 xheaders=self.trust_xheaders)

        success = None
        for port in random_ports(self.port, self.port_retries + 1):
            try:
                self.http_server.listen(port, self.ip)
            except socket.error as e:
                if e.errno == errno.EADDRINUSE:
                    self.log.info(
                        'The port %i is already in use, trying another random port.'
                        % port)
                    continue
                elif e.errno in (errno.EACCES,
                                 getattr(errno, 'WSAEACCES', errno.EACCES)):
                    self.log.warn("Permission to listen on port %i denied" %
                                  port)
                    continue
                else:
                    raise
            else:
                self.port = port
                success = True
                break
        if not success:
            self.log.critical(
                'ERROR: the notebook server could not be started because '
                'no available port could be found.')
            self.exit(1)

    @property
    def display_url(self):
        ip = self.ip if self.ip else '[all ip addresses on your system]'
        return self._url(ip)

    @property
    def connection_url(self):
        ip = self.ip if self.ip else 'localhost'
        return self._url(ip)

    def _url(self, ip):
        proto = 'https' if self.certfile else 'http'
        return "%s://%s:%i%s" % (proto, ip, self.port, self.base_url)

    def init_terminals(self):
        try:
            from .terminal import initialize
            initialize(self.web_app)
            self.web_app.settings['terminals_available'] = True
        except ImportError as e:
            log = self.log.debug if sys.platform == 'win32' else self.log.warn
            log("Terminals not available (error was %s)", e)

    def init_signal(self):
        if not sys.platform.startswith('win'):
            signal.signal(signal.SIGINT, self._handle_sigint)
        signal.signal(signal.SIGTERM, self._signal_stop)
        if hasattr(signal, 'SIGUSR1'):
            # Windows doesn't support SIGUSR1
            signal.signal(signal.SIGUSR1, self._signal_info)
        if hasattr(signal, 'SIGINFO'):
            # only on BSD-based systems
            signal.signal(signal.SIGINFO, self._signal_info)

    def _handle_sigint(self, sig, frame):
        """SIGINT handler spawns confirmation dialog"""
        # register more forceful signal handler for ^C^C case
        signal.signal(signal.SIGINT, self._signal_stop)
        # request confirmation dialog in bg thread, to avoid
        # blocking the App
        thread = threading.Thread(target=self._confirm_exit)
        thread.daemon = True
        thread.start()

    def _restore_sigint_handler(self):
        """callback for restoring original SIGINT handler"""
        signal.signal(signal.SIGINT, self._handle_sigint)

    def _confirm_exit(self):
        """confirm shutdown on ^C
        
        A second ^C, or answering 'y' within 5s will cause shutdown,
        otherwise original SIGINT handler will be restored.
        
        This doesn't work on Windows.
        """
        info = self.log.info
        info('interrupted')
        print(self.notebook_info())
        sys.stdout.write("Shut down this notebook server (y/[n])? ")
        sys.stdout.flush()
        r, w, x = select.select([sys.stdin], [], [], 5)
        if r:
            line = sys.stdin.readline()
            if line.lower().startswith('y') and 'n' not in line.lower():
                self.log.critical("Shutdown confirmed")
                ioloop.IOLoop.current().stop()
                return
        else:
            print("No answer for 5s:", end=' ')
        print("resuming operation...")
        # no answer, or answer is no:
        # set it back to original SIGINT handler
        # use IOLoop.add_callback because signal.signal must be called
        # from main thread
        ioloop.IOLoop.current().add_callback(self._restore_sigint_handler)

    def _signal_stop(self, sig, frame):
        self.log.critical("received signal %s, stopping", sig)
        ioloop.IOLoop.current().stop()

    def _signal_info(self, sig, frame):
        print(self.notebook_info())

    def init_components(self):
        """Check the components submodule, and warn if it's unclean"""
        status = submodule.check_submodule_status()
        if status == 'missing':
            self.log.warn(
                "components submodule missing, running `git submodule update`")
            submodule.update_submodules(submodule.repo_parent())
        elif status == 'unclean':
            self.log.warn(
                "components submodule unclean, you may see 404s on static/components"
            )
            self.log.warn(
                "run `setup.py submodule` or `git submodule update` to update")

    def init_kernel_specs(self):
        """Check that the IPython kernel is present, if available"""
        try:
            self.kernel_spec_manager.get_kernel_spec(NATIVE_KERNEL_NAME)
        except NoSuchKernel:
            try:
                import ipykernel
            except ImportError:
                self.log.warn("IPython kernel not available")
            else:
                self.log.warn("Installing IPython kernel spec")
                self.kernel_spec_manager.install_native_kernel_spec(user=True)

    def init_server_extensions(self):
        """Load any extensions specified by config.

        Import the module, then call the load_jupyter_server_extension function,
        if one exists.
        
        The extension API is experimental, and may change in future releases.
        """
        for modulename in self.server_extensions:
            try:
                mod = importlib.import_module(modulename)
                func = getattr(mod, 'load_jupyter_server_extension', None)
                if func is not None:
                    func(self)
            except Exception:
                if self.reraise_server_extension_failures:
                    raise
                self.log.warn("Error loading server extension %s",
                              modulename,
                              exc_info=True)
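    # Hedged sketch of a matching extension module (module and file name are
    # illustrative). Listing it in NotebookApp.server_extensions makes the loop
    # above import it and call this hook with the running app:
    #
    #     # myextension.py
    #     def load_jupyter_server_extension(nbapp):
    #         nbapp.log.info("myextension loaded; base_url=%s", nbapp.base_url)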

    @catch_config_error
    def initialize(self, argv=None):
        super(NotebookApp, self).initialize(argv)
        self.init_logging()
        self.init_configurables()
        self.init_components()
        self.init_webapp()
        self.init_kernel_specs()
        self.init_terminals()
        self.init_signal()
        self.init_server_extensions()

    def cleanup_kernels(self):
        """Shutdown all kernels.
        
        The kernels will shut down by themselves when this process no longer exists,
        but explicit shutdown allows the KernelManagers to clean up the connection files.
        """
        self.log.info('Shutting down kernels')
        self.kernel_manager.shutdown_all()

    def notebook_info(self):
        "Return the current working directory and the server url information"
        info = self.contents_manager.info_string() + "\n"
        info += "%d active kernels \n" % len(self.kernel_manager._kernels)
        return info + "The IPython Notebook is running at: %s" % self.display_url

    def server_info(self):
        """Return a JSONable dict of information about this server."""
        return {
            'url': self.connection_url,
            'hostname': self.ip if self.ip else 'localhost',
            'port': self.port,
            'secure': bool(self.certfile),
            'base_url': self.base_url,
            'notebook_dir': os.path.abspath(self.notebook_dir),
            'pid': os.getpid()
        }

    def write_server_info_file(self):
        """Write the result of server_info() to the JSON file info_file."""
        with open(self.info_file, 'w') as f:
            json.dump(self.server_info(), f, indent=2)

    def remove_server_info_file(self):
        """Remove the nbserver-<pid>.json file created for this server.
        
        Ignores the error raised when the file has already been removed.
        """
        try:
            os.unlink(self.info_file)
        except OSError as e:
            if e.errno != errno.ENOENT:
                raise

    def start(self):
        """ Start the IPython Notebook server app, after initialization
        
        This method takes no arguments so all configuration and initialization
        must be done prior to calling this method."""
        if self.subapp is not None:
            return self.subapp.start()

        info = self.log.info
        for line in self.notebook_info().split("\n"):
            info(line)
        info(
            "Use Control-C to stop this server and shut down all kernels (twice to skip confirmation)."
        )

        self.write_server_info_file()

        if self.open_browser or self.file_to_run:
            try:
                browser = webbrowser.get(self.browser or None)
            except webbrowser.Error as e:
                self.log.warn('No web browser found: %s.' % e)
                browser = None

            if self.file_to_run:
                if not os.path.exists(self.file_to_run):
                    self.log.critical("%s does not exist" % self.file_to_run)
                    self.exit(1)

                relpath = os.path.relpath(self.file_to_run, self.notebook_dir)
                uri = url_path_join('notebooks', *relpath.split(os.sep))
            else:
                uri = 'tree'
            if browser:
                b = lambda: browser.open(
                    url_path_join(self.connection_url, uri), new=2)
                threading.Thread(target=b).start()

        self.io_loop = ioloop.IOLoop.current()
        if sys.platform.startswith('win'):
            # add no-op to wake every 5s
            # to handle signals that may be ignored by the inner loop
            pc = ioloop.PeriodicCallback(lambda: None, 5000)
            pc.start()
        try:
            self.io_loop.start()
        except KeyboardInterrupt:
            info("Interrupted...")
        finally:
            self.cleanup_kernels()
            self.remove_server_info_file()

    def stop(self):
        def _stop():
            self.http_server.stop()
            self.io_loop.stop()

        self.io_loop.add_callback(_stop)
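
# Hedged launch sketch: NotebookApp builds on traitlets' Application, whose
# launch_instance() classmethod creates an instance and runs initialize()
# followed by start():
if __name__ == '__main__':
    NotebookApp.launch_instance()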
Example no. 28
class IPKernelApp(BaseIPythonApplication, InteractiveShellApp,
                  ConnectionFileMixin):
    name = 'ipython-kernel'
    aliases = Dict(kernel_aliases)
    flags = Dict(kernel_flags)
    classes = [IPythonKernel, ZMQInteractiveShell, ProfileDir, Session]
    # the kernel class, as an importstring
    kernel_class = Type('ipykernel.ipkernel.IPythonKernel',
                        config=True,
                        klass='ipykernel.kernelbase.Kernel',
                        help="""The Kernel subclass to be used.

    This should allow easy re-use of the IPKernelApp entry point
    to configure and launch kernels other than IPython's own.
    """)
    kernel = Any()
    poller = Any()  # don't restrict this even though current pollers are all Threads
    heartbeat = Instance(Heartbeat, allow_none=True)
    ports = Dict()

    subcommands = {
        'install': ('ipykernel.kernelspec.InstallIPythonKernelSpecApp',
                    'Install the IPython kernel'),
    }

    # connection info:
    connection_dir = Unicode()

    def _connection_dir_default(self):
        return jupyter_runtime_dir()

    @property
    def abs_connection_file(self):
        if os.path.basename(self.connection_file) == self.connection_file:
            return os.path.join(self.connection_dir, self.connection_file)
        else:
            return self.connection_file

    # streams, etc.
    no_stdout = Bool(False,
                     config=True,
                     help="redirect stdout to the null device")
    no_stderr = Bool(False,
                     config=True,
                     help="redirect stderr to the null device")
    outstream_class = DottedObjectName(
        'ipykernel.iostream.OutStream',
        config=True,
        help="The importstring for the OutStream factory")
    displayhook_class = DottedObjectName(
        'ipykernel.displayhook.ZMQDisplayHook',
        config=True,
        help="The importstring for the DisplayHook factory")

    # polling
    parent_handle = Integer(
        int(os.environ.get('JPY_PARENT_PID') or 0),
        config=True,
        help="""kill this process if its parent dies.  On Windows, the argument
        specifies the HANDLE of the parent process, otherwise it is simply boolean.
        """)
    interrupt = Integer(int(os.environ.get('JPY_INTERRUPT_EVENT') or 0),
                        config=True,
                        help="""ONLY USED ON WINDOWS
        Interrupt this process when the parent is signaled.
        """)

    def init_crash_handler(self):
        # Install minimal exception handling
        sys.excepthook = FormattedTB(mode='Verbose',
                                     color_scheme='NoColor',
                                     ostream=sys.__stdout__)

    def init_poller(self):
        if sys.platform == 'win32':
            if self.interrupt or self.parent_handle:
                self.poller = ParentPollerWindows(self.interrupt,
                                                  self.parent_handle)
        elif self.parent_handle:
            self.poller = ParentPollerUnix()

    def _bind_socket(self, s, port):
        iface = '%s://%s' % (self.transport, self.ip)
        if self.transport == 'tcp':
            if port <= 0:
                port = s.bind_to_random_port(iface)
            else:
                s.bind("tcp://%s:%i" % (self.ip, port))
        elif self.transport == 'ipc':
            if port <= 0:
                port = 1
                path = "%s-%i" % (self.ip, port)
                while os.path.exists(path):
                    port = port + 1
                    path = "%s-%i" % (self.ip, port)
            else:
                path = "%s-%i" % (self.ip, port)
            s.bind("ipc://%s" % path)
        return port

    def write_connection_file(self):
        """write connection info to JSON file"""
        cf = self.abs_connection_file
        self.log.debug("Writing connection file: %s", cf)
        write_connection_file(cf,
                              ip=self.ip,
                              key=self.session.key,
                              transport=self.transport,
                              shell_port=self.shell_port,
                              stdin_port=self.stdin_port,
                              hb_port=self.hb_port,
                              iopub_port=self.iopub_port,
                              control_port=self.control_port)

    def cleanup_connection_file(self):
        cf = self.abs_connection_file
        self.log.debug("Cleaning up connection file: %s", cf)
        try:
            os.remove(cf)
        except (IOError, OSError):
            pass

        self.cleanup_ipc_files()

    def init_connection_file(self):
        if not self.connection_file:
            self.connection_file = "kernel-%s.json" % os.getpid()
        try:
            self.connection_file = filefind(self.connection_file,
                                            ['.', self.connection_dir])
        except IOError:
            self.log.debug("Connection file not found: %s",
                           self.connection_file)
            # This means I own it, and I'll create it in this directory:
            ensure_dir_exists(os.path.dirname(self.abs_connection_file), 0o700)
            # Also, I will clean it up:
            atexit.register(self.cleanup_connection_file)
            return
        try:
            self.load_connection_file()
        except Exception:
            self.log.error("Failed to load connection file: %r",
                           self.connection_file,
                           exc_info=True)
            self.exit(1)
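    # Hedged client-side sketch: another process can attach to the connection
    # file written below; jupyter_client's BlockingKernelClient is the assumed
    # client API (the file name is illustrative):
    #
    #     from jupyter_client import BlockingKernelClient
    #     kc = BlockingKernelClient(connection_file='kernel-1234.json')
    #     kc.load_connection_file()
    #     kc.start_channels()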

    def init_sockets(self):
        # Create a context, a session, and the kernel sockets.
        self.log.info("Starting the kernel at pid: %i", os.getpid())
        context = zmq.Context.instance()
        # Uncomment this to try closing the context.
        # atexit.register(context.term)

        self.shell_socket = context.socket(zmq.ROUTER)
        self.shell_socket.linger = 1000
        self.shell_port = self._bind_socket(self.shell_socket, self.shell_port)
        self.log.debug("shell ROUTER Channel on port: %i" % self.shell_port)

        self.iopub_socket = context.socket(zmq.PUB)
        self.iopub_socket.linger = 1000
        self.iopub_port = self._bind_socket(self.iopub_socket, self.iopub_port)
        self.log.debug("iopub PUB Channel on port: %i" % self.iopub_port)

        self.stdin_socket = context.socket(zmq.ROUTER)
        self.stdin_socket.linger = 1000
        self.stdin_port = self._bind_socket(self.stdin_socket, self.stdin_port)
        self.log.debug("stdin ROUTER Channel on port: %i" % self.stdin_port)

        self.control_socket = context.socket(zmq.ROUTER)
        self.control_socket.linger = 1000
        self.control_port = self._bind_socket(self.control_socket,
                                              self.control_port)
        self.log.debug("control ROUTER Channel on port: %i" %
                       self.control_port)

    def init_heartbeat(self):
        """start the heart beating"""
        # heartbeat doesn't share context, because it mustn't be blocked
        # by the GIL, which is accessed by libzmq when freeing zero-copy messages
        hb_ctx = zmq.Context()
        self.heartbeat = Heartbeat(hb_ctx,
                                   (self.transport, self.ip, self.hb_port))
        self.hb_port = self.heartbeat.port
        self.log.debug("Heartbeat REP Channel on port: %i" % self.hb_port)
        self.heartbeat.start()

    def log_connection_info(self):
        """display connection info, and store ports"""
        basename = os.path.basename(self.connection_file)
        if basename == self.connection_file or \
            os.path.dirname(self.connection_file) == self.connection_dir:
            # use shortname
            tail = basename
        else:
            tail = self.connection_file
        lines = [
            "To connect another client to this kernel, use:",
            "    --existing %s" % tail,
        ]
        # log connection info
        # info-level, so often not shown.
        # frontends should use the %connect_info magic
        # to see the connection info
        for line in lines:
            self.log.info(line)
        # also raw print to the terminal if no parent_handle (`ipython kernel`)
        if not self.parent_handle:
            io.rprint(_ctrl_c_message)
            for line in lines:
                io.rprint(line)

        self.ports = dict(shell=self.shell_port,
                          iopub=self.iopub_port,
                          stdin=self.stdin_port,
                          hb=self.hb_port,
                          control=self.control_port)

    def init_blackhole(self):
        """redirects stdout/stderr to devnull if necessary"""
        if self.no_stdout or self.no_stderr:
            blackhole = open(os.devnull, 'w')
            if self.no_stdout:
                sys.stdout = sys.__stdout__ = blackhole
            if self.no_stderr:
                sys.stderr = sys.__stderr__ = blackhole

    def init_io(self):
        """Redirect input streams and set a display hook."""
        if self.outstream_class:
            outstream_factory = import_item(str(self.outstream_class))
            sys.stdout = outstream_factory(self.session, self.iopub_socket,
                                           u'stdout')
            sys.stderr = outstream_factory(self.session, self.iopub_socket,
                                           u'stderr')
        if self.displayhook_class:
            displayhook_factory = import_item(str(self.displayhook_class))
            sys.displayhook = displayhook_factory(self.session,
                                                  self.iopub_socket)

    def init_signal(self):
        signal.signal(signal.SIGINT, signal.SIG_IGN)

    def init_kernel(self):
        """Create the Kernel object itself"""
        shell_stream = ZMQStream(self.shell_socket)
        control_stream = ZMQStream(self.control_socket)

        kernel_factory = self.kernel_class.instance

        kernel = kernel_factory(
            parent=self,
            session=self.session,
            shell_streams=[shell_stream, control_stream],
            iopub_socket=self.iopub_socket,
            stdin_socket=self.stdin_socket,
            log=self.log,
            profile_dir=self.profile_dir,
            user_ns=self.user_ns,
        )
        kernel.record_ports(self.ports)
        self.kernel = kernel

    def init_gui_pylab(self):
        """Enable GUI event loop integration, taking pylab into account."""

        # Provide a wrapper for :meth:`InteractiveShellApp.init_gui_pylab`
        # to ensure that any exception is printed straight to stderr.
        # Normally _showtraceback associates the reply with an execution,
        # which means frontends will never draw it, as this exception
        # is not associated with any execute request.

        shell = self.shell
        _showtraceback = shell._showtraceback
        try:
            # replace error-sending traceback with stderr
            def print_tb(etype, evalue, stb):
                print("GUI event loop or pylab initialization failed",
                      file=io.stderr)
                print(shell.InteractiveTB.stb2text(stb), file=io.stderr)

            shell._showtraceback = print_tb
            InteractiveShellApp.init_gui_pylab(self)
        finally:
            shell._showtraceback = _showtraceback

    def init_shell(self):
        self.shell = getattr(self.kernel, 'shell', None)
        if self.shell:
            self.shell.configurables.append(self)

    def init_extensions(self):
        super(IPKernelApp, self).init_extensions()
        # BEGIN HARDCODED WIDGETS HACK
        # Ensure ipywidgets extension is loaded if available
        extension_man = self.shell.extension_manager
        if 'ipywidgets' not in extension_man.loaded:
            try:
                extension_man.load_extension('ipywidgets')
            except ImportError:
                self.log.debug(
                    'ipywidgets package not installed. Widgets will not be available.'
                )
        # END HARDCODED WIDGETS HACK

    @catch_config_error
    def initialize(self, argv=None):
        super(IPKernelApp, self).initialize(argv)
        if self.subapp is not None:
            return
        self.init_blackhole()
        self.init_connection_file()
        self.init_poller()
        self.init_sockets()
        self.init_heartbeat()
        # writing/displaying connection info must be *after* init_sockets/heartbeat
        self.log_connection_info()
        self.write_connection_file()
        self.init_io()
        self.init_signal()
        self.init_kernel()
        # shell init steps
        self.init_path()
        self.init_shell()
        if self.shell:
            self.init_gui_pylab()
            self.init_extensions()
            self.init_code()
        # flush stdout/stderr, so that anything written to these streams during
        # initialization does not get associated with the first execution request
        sys.stdout.flush()
        sys.stderr.flush()

    def start(self):
        if self.subapp is not None:
            return self.subapp.start()

        if self.poller is not None:
            self.poller.start()
        self.kernel.start()
        try:
            ioloop.IOLoop.instance().start()
        except KeyboardInterrupt:
            pass
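
# Hedged launch sketch: this mirrors how `ipython kernel` enters the app via
# the Application machinery (initialize() then start()):
if __name__ == '__main__':
    IPKernelApp.launch_instance()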
Example no. 29
class Layer(Widget, InteractMixin):
    _view_name = Unicode('LeafletLayerView').tag(sync=True)
    _model_name = Unicode('LeafletLayerModel').tag(sync=True)
    _view_module = Unicode('jupyter-leaflet').tag(sync=True)
    _model_module = Unicode('jupyter-leaflet').tag(sync=True)

    _view_module_version = Unicode(EXTENSION_VERSION).tag(sync=True)
    _model_module_version = Unicode(EXTENSION_VERSION).tag(sync=True)

    name = Unicode('').tag(sync=True)
    base = Bool(False).tag(sync=True)
    bottom = Bool(False).tag(sync=True)
    popup = Instance(Widget, allow_none=True,
                     default_value=None).tag(sync=True, **widget_serialization)
    popup_min_width = Int(50).tag(sync=True)
    popup_max_width = Int(300).tag(sync=True)
    popup_max_height = Int(default_value=None, allow_none=True).tag(sync=True)

    options = List(trait=Unicode()).tag(sync=True)

    def __init__(self, **kwargs):
        super(Layer, self).__init__(**kwargs)
        self.on_msg(self._handle_mouse_events)

    @default('options')
    def _default_options(self):
        return [name for name in self.traits(o=True)]

    # Event handling
    _click_callbacks = Instance(CallbackDispatcher, ())
    _dblclick_callbacks = Instance(CallbackDispatcher, ())
    _mousedown_callbacks = Instance(CallbackDispatcher, ())
    _mouseup_callbacks = Instance(CallbackDispatcher, ())
    _mouseover_callbacks = Instance(CallbackDispatcher, ())
    _mouseout_callbacks = Instance(CallbackDispatcher, ())

    def _handle_mouse_events(self, _, content, buffers):
        event_type = content.get('type', '')
        if event_type == 'click':
            self._click_callbacks(**content)
        if event_type == 'dblclick':
            self._dblclick_callbacks(**content)
        if event_type == 'mousedown':
            self._mousedown_callbacks(**content)
        if event_type == 'mouseup':
            self._mouseup_callbacks(**content)
        if event_type == 'mouseover':
            self._mouseover_callbacks(**content)
        if event_type == 'mouseout':
            self._mouseout_callbacks(**content)

    def on_click(self, callback, remove=False):
        self._click_callbacks.register_callback(callback, remove=remove)

    def on_dblclick(self, callback, remove=False):
        self._dblclick_callbacks.register_callback(callback, remove=remove)

    def on_mousedown(self, callback, remove=False):
        self._mousedown_callbacks.register_callback(callback, remove=remove)

    def on_mouseup(self, callback, remove=False):
        self._mouseup_callbacks.register_callback(callback, remove=remove)

    def on_mouseover(self, callback, remove=False):
        self._mouseover_callbacks.register_callback(callback, remove=remove)

    def on_mouseout(self, callback, remove=False):
        self._mouseout_callbacks.register_callback(callback, remove=remove)
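
# Hedged usage sketch: concrete layers such as Marker subclass Layer in the
# real ipyleaflet package, so the event hooks above apply to them; each
# callback receives the raw message content as keyword arguments:
#
#     from ipyleaflet import Map, Marker
#     m = Map(center=(52.2, 0.1), zoom=9)
#     marker = Marker(location=(52.2, 0.1))
#     m.add_layer(marker)
#     marker.on_click(lambda **event: print(event.get('type')))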
Example no. 30
class LocalProcessSpawner(Spawner):
    """
    A Spawner that uses `subprocess.Popen` to start single-user servers as local processes.

    Requires local UNIX users matching the authenticated users to exist.
    Does not work on Windows.

    This is the default spawner for JupyterHub.
    """

    INTERRUPT_TIMEOUT = Integer(10,
                                help="""
        Seconds to wait for single-user server process to halt after SIGINT.

        If the process has not exited cleanly after this many seconds, a SIGTERM is sent.
        """).tag(config=True)

    TERM_TIMEOUT = Integer(5,
                           help="""
        Seconds to wait for single-user server process to halt after SIGTERM.

        If the process does not exit cleanly after this many seconds of SIGTERM, a SIGKILL is sent.
        """).tag(config=True)

    KILL_TIMEOUT = Integer(5,
                           help="""
        Seconds to wait for process to halt after SIGKILL before giving up.

        If the process does not exit cleanly after this many seconds of SIGKILL, it becomes a zombie
        process. The hub process will log a warning and then give up.
        """).tag(config=True)

    proc = Instance(Popen,
                    allow_none=True,
                    help="""
        The process representing the single-user server process spawned for the current user.

        Is None if no process has been spawned yet.
        """)
    pid = Integer(0,
                  help="""
        The process id (pid) of the single-user server process spawned for the current user.
        """)

    def make_preexec_fn(self, name):
        """
        Return a function that can be used to set the user id of the spawned process to the user with name `name`.

        This function can be safely passed to `preexec_fn` of `Popen`
        """
        return set_user_setuid(name)

    def load_state(self, state):
        """Restore state about spawned single-user server after a hub restart.

        Local processes only need the process id.
        """
        super(LocalProcessSpawner, self).load_state(state)
        if 'pid' in state:
            self.pid = state['pid']

    def get_state(self):
        """Save state that is needed to restore this spawner instance after a hub restore.

        Local processes only need the process id.
        """
        state = super(LocalProcessSpawner, self).get_state()
        if self.pid:
            state['pid'] = self.pid
        return state

    def clear_state(self):
        """Clear stored state about this spawner (pid)"""
        super(LocalProcessSpawner, self).clear_state()
        self.pid = 0

    def user_env(self, env):
        """Augment environment of spawned process with user specific env variables."""
        env['USER'] = self.user.name
        home = pwd.getpwnam(self.user.name).pw_dir
        shell = pwd.getpwnam(self.user.name).pw_shell
        # These will be empty if undefined,
        # in which case don't set the env:
        if home:
            env['HOME'] = home
        if shell:
            env['SHELL'] = shell
        return env

    def get_env(self):
        """Get the complete set of environment variables to be set in the spawned process."""
        env = super().get_env()
        env = self.user_env(env)
        return env

    @gen.coroutine
    def start(self):
        """Start the single-user server."""
        self.port = random_port()
        cmd = []
        env = self.get_env()

        cmd.extend(self.cmd)
        cmd.extend(self.get_args())

        self.log.info("Spawning %s", ' '.join(pipes.quote(s) for s in cmd))
        try:
            self.proc = Popen(
                cmd,
                env=env,
                preexec_fn=self.make_preexec_fn(self.user.name),
                start_new_session=True,  # don't forward signals
            )
        except PermissionError:
            # use which to get abspath
            script = shutil.which(cmd[0]) or cmd[0]
            self.log.error(
                "Permission denied trying to run %r. Does %s have access to this file?",
                script,
                self.user.name,
            )
            raise

        self.pid = self.proc.pid

        if self.__class__ is not LocalProcessSpawner:
            # subclasses may not pass through return value of super().start,
            # relying on deprecated 0.6 way of setting ip, port,
            # so keep a redundant copy here for now.
            # A deprecation warning will be shown if the subclass
            # does not return ip, port.
            if self.ip:
                self.user.server.ip = self.ip
            self.user.server.port = self.port
        return (self.ip or '127.0.0.1', self.port)

    @gen.coroutine
    def poll(self):
        """Poll the spawned process to see if it is still running.

        If the process is still running, we return None. If it is not running,
        we return the exit code of the process if we have access to it, or 0 otherwise.
        """
        # if we started the process, poll with Popen
        if self.proc is not None:
            status = self.proc.poll()
            if status is not None:
                # clear state if the process is done
                self.clear_state()
            return status

        # if we resumed from stored state,
        # we don't have the Popen handle anymore, so rely on self.pid
        if not self.pid:
            # no pid, not running
            self.clear_state()
            return 0

        # send signal 0 to check if PID exists
        # this doesn't work on Windows, but that's okay because we don't support Windows.
        alive = yield self._signal(0)
        if not alive:
            self.clear_state()
            return 0
        else:
            return None

    @gen.coroutine
    def _signal(self, sig):
        """Send given signal to a single-user server's process.

        Returns True if the process still exists, False otherwise.

        The hub process is assumed to have enough privileges to do this (e.g. root).
        """
        try:
            os.kill(self.pid, sig)
        except OSError as e:
            if e.errno == errno.ESRCH:
                return False  # process is gone
            else:
                raise
        return True  # process exists

    @gen.coroutine
    def stop(self, now=False):
        """Stop the single-user server process for the current user.

        If `now` is set to True, do not wait for the process to die.
        Otherwise, escalate from SIGINT to SIGTERM to SIGKILL, waiting at each
        stage for the process to exit.
        """
        if not now:
            status = yield self.poll()
            if status is not None:
                return
            self.log.debug("Interrupting %i", self.pid)
            yield self._signal(signal.SIGINT)
            yield self.wait_for_death(self.INTERRUPT_TIMEOUT)

        # clean shutdown failed, use TERM
        status = yield self.poll()
        if status is not None:
            return
        self.log.debug("Terminating %i", self.pid)
        yield self._signal(signal.SIGTERM)
        yield self.wait_for_death(self.TERM_TIMEOUT)

        # TERM failed, use KILL
        status = yield self.poll()
        if status is not None:
            return
        self.log.debug("Killing %i", self.pid)
        yield self._signal(signal.SIGKILL)
        yield self.wait_for_death(self.KILL_TIMEOUT)

        status = yield self.poll()
        if status is None:
            # it all failed, zombie process
            self.log.warning("Process %i never died", self.pid)
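
# Hedged config sketch: the SIGINT -> SIGTERM -> SIGKILL escalation in stop()
# is tuned via the timeouts defined above, e.g. in jupyterhub_config.py
# (values are illustrative):
#
#     c.LocalProcessSpawner.INTERRUPT_TIMEOUT = 30
#     c.LocalProcessSpawner.TERM_TIMEOUT = 10
#     c.LocalProcessSpawner.KILL_TIMEOUT = 10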