Example #1
class BaseConverter(LoggingConfigurable):

    notebooks = List([])
    assignments = Dict({})
    writer = Instance(FilesWriter)
    exporter = Instance(Exporter)
    exporter_class = Type(NotebookExporter, klass=Exporter)
    preprocessors = List([])

    force = Bool(
        False,
        help="Whether to overwrite existing assignments/submissions").tag(
            config=True)

    permissions = Integer(help=dedent("""
            Permissions to set on files output by nbgrader. The default is
            generally read-only (444), with the exception of nbgrader
            generate_assignment and nbgrader generate_feedback, in which case
            the user also has write permission.
            """)).tag(config=True)

    @default("permissions")
    def _permissions_default(self):
        return 664 if self.coursedir.groupshared else 444

    coursedir = Instance(CourseDirectory, allow_none=True)

    def __init__(self, coursedir=None, **kwargs):
        self.coursedir = coursedir
        super(BaseConverter, self).__init__(**kwargs)
        if self.parent and hasattr(self.parent, "logfile"):
            self.logfile = self.parent.logfile
        else:
            self.logfile = None

        c = Config()
        c.Exporter.default_preprocessors = []
        self.update_config(c)

    def start(self):
        self.init_notebooks()
        self.writer = FilesWriter(parent=self, config=self.config)
        self.exporter = self.exporter_class(parent=self, config=self.config)
        for pp in self.preprocessors:
            self.exporter.register_preprocessor(pp)
        currdir = os.getcwd()
        os.chdir(self.coursedir.root)
        try:
            self.convert_notebooks()
        finally:
            os.chdir(currdir)

    @default("classes")
    def _classes_default(self):
        classes = super(BaseConverter, self)._classes_default()
        classes.append(FilesWriter)
        classes.append(Exporter)
        for pp in self.preprocessors:
            if len(pp.class_traits(config=True)) > 0:
                classes.append(pp)
        return classes

    @property
    def _input_directory(self):
        raise NotImplementedError

    @property
    def _output_directory(self):
        raise NotImplementedError

    def _format_source(self, assignment_id, student_id, escape=False):
        return self.coursedir.format_path(self._input_directory,
                                          student_id,
                                          assignment_id,
                                          escape=escape)

    def _format_dest(self, assignment_id, student_id, escape=False):
        return self.coursedir.format_path(self._output_directory,
                                          student_id,
                                          assignment_id,
                                          escape=escape)

    def init_notebooks(self):
        self.assignments = {}
        self.notebooks = []
        assignment_glob = self._format_source(self.coursedir.assignment_id,
                                              self.coursedir.student_id)
        for assignment in glob.glob(assignment_glob):
            notebook_glob = os.path.join(assignment,
                                         self.coursedir.notebook_id + ".ipynb")
            found = glob.glob(notebook_glob)
            if len(found) == 0:
                self.log.warning("No notebooks were matched by '%s'",
                                 notebook_glob)
                continue
            self.assignments[assignment] = found

        if len(self.assignments) == 0:
            msg = "No notebooks were matched by '%s'" % assignment_glob
            self.log.error(msg)

            assignment_glob2 = self._format_source("*",
                                                   self.coursedir.student_id)
            found = glob.glob(assignment_glob2)
            if found:
                # Normally it is a bad idea to put imports in the middle of
                # a function, but we do this here because otherwise fuzzywuzzy
                # prints an annoying message about python-Levenshtein every
                # time nbgrader is run.
                from fuzzywuzzy import fuzz
                scores = sorted([(fuzz.ratio(assignment_glob, x), x)
                                 for x in found])
                self.log.error("Did you mean: %s", scores[-1][1])

            raise NbGraderException(msg)

    def init_single_notebook_resources(self, notebook_filename):
        regexp = re.escape(os.path.sep).join([
            self._format_source("(?P<assignment_id>.*)",
                                "(?P<student_id>.*)",
                                escape=True), r"(?P<notebook_id>.*)\.ipynb"
        ])

        m = re.match(regexp, notebook_filename)
        if m is None:
            msg = "Could not match '%s' with regexp '%s'" % (notebook_filename,
                                                             regexp)
            self.log.error(msg)
            raise NbGraderException(msg)

        gd = m.groupdict()

        self.log.debug("Student: %s", gd['student_id'])
        self.log.debug("Assignment: %s", gd['assignment_id'])
        self.log.debug("Notebook: %s", gd['notebook_id'])

        resources = {}
        resources['unique_key'] = gd['notebook_id']
        resources['output_files_dir'] = '%s_files' % gd['notebook_id']

        resources['nbgrader'] = {}
        resources['nbgrader']['student'] = gd['student_id']
        resources['nbgrader']['assignment'] = gd['assignment_id']
        resources['nbgrader']['notebook'] = gd['notebook_id']
        resources['nbgrader']['db_url'] = self.coursedir.db_url

        return resources

    def write_single_notebook(self, output, resources):
        # configure the writer build directory
        self.writer.build_directory = self._format_dest(
            resources['nbgrader']['assignment'],
            resources['nbgrader']['student'])

        # write out the results
        self.writer.write(output,
                          resources,
                          notebook_name=resources['unique_key'])

    def init_destination(self, assignment_id, student_id):
        """Initialize the destination for an assignment. Returns whether the
        assignment should actually be processed or not (i.e. whether the
        initialization was successful).

        """
        if self.coursedir.student_id_exclude:
            exclude_ids = self.coursedir.student_id_exclude.split(',')
            if student_id in exclude_ids:
                return False

        dest = os.path.normpath(self._format_dest(assignment_id, student_id))

        # the destination doesn't exist, so we haven't processed it
        if self.coursedir.notebook_id == "*":
            if not os.path.exists(dest):
                return True
        else:
            # if any of the notebooks don't exist, then we want to process them
            for notebook in self.notebooks:
                filename = os.path.splitext(os.path.basename(
                    notebook))[0] + self.exporter.file_extension
                path = os.path.join(dest, filename)
                if not os.path.exists(path):
                    return True

        # if we have specified --force, then always remove existing stuff
        if self.force:
            if self.coursedir.notebook_id == "*":
                self.log.warning(
                    "Removing existing assignment: {}".format(dest))
                rmtree(dest)
            else:
                for notebook in self.notebooks:
                    filename = os.path.splitext(os.path.basename(
                        notebook))[0] + self.exporter.file_extension
                    path = os.path.join(dest, filename)
                    if os.path.exists(path):
                        self.log.warning(
                            "Removing existing notebook: {}".format(path))
                        remove(path)
            return True

        src = self._format_source(assignment_id, student_id)
        new_timestamp = self.coursedir.get_existing_timestamp(src)
        old_timestamp = self.coursedir.get_existing_timestamp(dest)

        # if --force hasn't been specified, but the source assignment is newer,
        # then we want to overwrite it
        if new_timestamp is not None and old_timestamp is not None and new_timestamp > old_timestamp:
            if self.coursedir.notebook_id == "*":
                self.log.warning(
                    "Updating existing assignment: {}".format(dest))
                rmtree(dest)
            else:
                for notebook in self.notebooks:
                    filename = os.path.splitext(os.path.basename(
                        notebook))[0] + self.exporter.file_extension
                    path = os.path.join(dest, filename)
                    if os.path.exists(path):
                        self.log.warning(
                            "Updating existing notebook: {}".format(path))
                        remove(path)
            return True

        # otherwise, we should skip the assignment
        self.log.info("Skipping existing assignment: {}".format(dest))
        return False

    def init_assignment(self, assignment_id, student_id):
        """Initializes resources/dependencies/etc. that are common to all
        notebooks in an assignment.

        """
        source = self._format_source(assignment_id, student_id)
        dest = self._format_dest(assignment_id, student_id)

        # detect other files in the source directory
        for filename in find_all_files(source,
                                       self.coursedir.ignore + ["*.ipynb"]):
            # Make sure folder exists.
            path = os.path.join(dest, os.path.relpath(filename, source))
            if not os.path.exists(os.path.dirname(path)):
                os.makedirs(os.path.dirname(path))
            if os.path.exists(path):
                remove(path)
            self.log.info("Copying %s -> %s", filename, path)
            shutil.copy(filename, path)

    def set_permissions(self, assignment_id, student_id):
        self.log.info("Setting destination file permissions to %s",
                      self.permissions)
        dest = os.path.normpath(self._format_dest(assignment_id, student_id))
        permissions = int(str(self.permissions), 8)
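        # e.g. the default of 444 becomes int("444", 8) == 0o444 (read-only):
        # the digits of self.permissions are reinterpreted as an octal mode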
        for dirname, _, filenames in os.walk(dest):
            for filename in filenames:
                os.chmod(os.path.join(dirname, filename), permissions)
            # If groupshared, set dir permissions - see comment below.
            st_mode = os.stat(dirname).st_mode
            if self.coursedir.groupshared and st_mode & 0o2770 != 0o2770:
                try:
                    os.chmod(dirname, (st_mode | 0o2770) & 0o2777)
                except PermissionError:
                    self.log.warning(
                        "Could not update permissions of %s to make it groupshared",
                        dirname)
        # If groupshared, set write permissions on directories.  Directories
        # are created within ipython_genutils.path.ensure_dir_exists via
        # nbconvert.writer, (unless there are supplementary files) with a
        # default mode of 755 and there is no way to pass the mode arguments
        # all the way to there!  So we have to walk and fix.
        if self.coursedir.groupshared:
            # Root may be created in this step, and is not included above.
            rootdir = self.coursedir.format_path(self._output_directory, '.',
                                                 '.')
            # Add 2770 to existing dir permissions (don't unconditionally override)
            st_mode = os.stat(rootdir).st_mode
            if st_mode & 0o2770 != 0o2770:
                try:
                    os.chmod(rootdir, (st_mode | 0o2770) & 0o2777)
                except PermissionError:
                    self.log.warning(
                        "Could not update permissions of %s to make it groupshared",
                        rootdir)

    def convert_single_notebook(self, notebook_filename):
        """Convert a single notebook.

        Performs the following steps:
            1. Initialize notebook resources
            2. Export the notebook to a particular format
            3. Write the exported notebook to file

        """
        self.log.info("Converting notebook %s", notebook_filename)
        resources = self.init_single_notebook_resources(notebook_filename)
        output, resources = self.exporter.from_filename(notebook_filename,
                                                        resources=resources)
        self.write_single_notebook(output, resources)

    def convert_notebooks(self):
        errors = []

        def _handle_failure(gd):
            dest = os.path.normpath(
                self._format_dest(gd['assignment_id'], gd['student_id']))
            if self.coursedir.notebook_id == "*":
                if os.path.exists(dest):
                    self.log.warning(
                        "Removing failed assignment: {}".format(dest))
                    rmtree(dest)
            else:
                for notebook in self.notebooks:
                    filename = os.path.splitext(os.path.basename(
                        notebook))[0] + self.exporter.file_extension
                    path = os.path.join(dest, filename)
                    if os.path.exists(path):
                        self.log.warning(
                            "Removing failed notebook: {}".format(path))
                        remove(path)

        for assignment in sorted(self.assignments.keys()):
            # initialize the list of notebooks and the exporter
            self.notebooks = sorted(self.assignments[assignment])

            # parse out the assignment and student ids
            regexp = self._format_source("(?P<assignment_id>.*)",
                                         "(?P<student_id>.*)",
                                         escape=True)
            m = re.match(regexp, assignment)
            if m is None:
                msg = "Could not match '%s' with regexp '%s'" % (assignment,
                                                                 regexp)
                self.log.error(msg)
                raise NbGraderException(msg)
            gd = m.groupdict()

            try:
                # determine whether we actually even want to process this submission
                should_process = self.init_destination(gd['assignment_id'],
                                                       gd['student_id'])
                if not should_process:
                    continue

                # initialize the destination
                self.init_assignment(gd['assignment_id'], gd['student_id'])

                # convert all the notebooks
                for notebook_filename in self.notebooks:
                    self.convert_single_notebook(notebook_filename)

                # set assignment permissions
                self.set_permissions(gd['assignment_id'], gd['student_id'])

            except UnresponsiveKernelError:
                self.log.error(
                    "While processing assignment %s, the kernel became "
                    "unresponsive and we could not interrupt it. This probably "
                    "means that the students' code has an infinite loop that "
                    "consumes a lot of memory or something similar. nbgrader "
                    "doesn't know how to deal with this problem, so you will "
                    "have to manually edit the students' code (for example, to "
                    "just throw an error rather than enter an infinite loop). ",
                    assignment)
                errors.append((gd['assignment_id'], gd['student_id']))
                _handle_failure(gd)

            except sqlalchemy.exc.OperationalError:
                _handle_failure(gd)
                self.log.error(traceback.format_exc())
                msg = (
                    "There was an error accessing the nbgrader database. This "
                    "may occur if you recently upgraded nbgrader. To resolve "
                    "the issue, first BACK UP your database and then run the "
                    "command `nbgrader db upgrade`.")
                self.log.error(msg)
                raise NbGraderException(msg)

            except SchemaTooOldError:
                _handle_failure(gd)
                msg = (
                    "One or more notebooks in the assignment use an old version \n"
                    "of the nbgrader metadata format. Please **back up your class files \n"
                    "directory** and then update the metadata using:\n\nnbgrader update .\n"
                )
                self.log.error(msg)
                raise NbGraderException(msg)

            except SchemaTooNewError:
                _handle_failure(gd)
                msg = (
                    "One or more notebooks in the assignment use an newer version \n"
                    "of the nbgrader metadata format. Please update your version of \n"
                    "nbgrader to the latest version to be able to use this notebook.\n"
                )
                self.log.error(msg)
                raise NbGraderException(msg)

            except KeyboardInterrupt:
                _handle_failure(gd)
                self.log.error("Canceled")
                raise

            except Exception:
                self.log.error("There was an error processing assignment: %s",
                               assignment)
                self.log.error(traceback.format_exc())
                errors.append((gd['assignment_id'], gd['student_id']))
                _handle_failure(gd)

        if len(errors) > 0:
            for assignment_id, student_id in errors:
                self.log.error(
                    "There was an error processing assignment '{}' for student '{}'"
                    .format(assignment_id, student_id))

            if self.logfile:
                msg = (
                    "Please see the error log ({}) for details on the specific "
                    "errors on the above failures.".format(self.logfile))
            else:
                msg = (
                    "Please see the the above traceback for details on the specific "
                    "errors on the above failures.")

            self.log.error(msg)
            raise NbGraderException(msg)
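
BaseConverter is abstract: subclasses must supply `_input_directory` and
`_output_directory`. A minimal sketch of such a subclass (the class name and
directory names here are illustrative, not nbgrader's real converters):

class ExampleConverter(BaseConverter):
    """Hypothetical converter that reads from 'source' and writes to 'release'."""

    @property
    def _input_directory(self):
        # illustrative; nbgrader's real converters take these from CourseDirectory
        return "source"

    @property
    def _output_directory(self):
        return "release"

# Usage sketch (assumes a configured CourseDirectory):
#   converter = ExampleConverter(coursedir=CourseDirectory(root="/path/to/course"))
#   converter.start()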
Example #2
class BaseIPythonApplication(Application):

    name = Unicode(u'ipython')
    description = Unicode(u'IPython: an enhanced interactive Python shell.')
    version = Unicode(release.version)

    aliases = Dict(base_aliases)
    flags = Dict(base_flags)
    classes = List([ProfileDir])

    # enable `load_subconfig('cfg.py', profile='name')`
    python_config_loader_class = ProfileAwareConfigLoader

    # Track whether the config_file has changed,
    # because some logic happens only if we aren't using the default.
    config_file_specified = Set()

    config_file_name = Unicode()

    def _config_file_name_default(self):
        return self.name.replace('-', '_') + u'_config.py'

    def _config_file_name_changed(self, name, old, new):
        if new != old:
            self.config_file_specified.add(new)

    # The directory that contains IPython's builtin profiles.
    builtin_profile_dir = Unicode(
        os.path.join(get_ipython_package_dir(), u'config', u'profile',
                     u'default'))

    config_file_paths = List(Unicode())

    def _config_file_paths_default(self):
        return [py3compat.getcwd()]

    extra_config_file = Unicode(config=True,
                                help="""Path to an extra config file to load.
    
    If specified, load this config file in addition to any other IPython config.
    """)

    def _extra_config_file_changed(self, name, old, new):
        try:
            self.config_files.remove(old)
        except ValueError:
            pass
        self.config_file_specified.add(new)
        self.config_files.append(new)

    profile = Unicode(u'default',
                      config=True,
                      help="""The IPython profile to use.""")

    def _profile_changed(self, name, old, new):
        self.builtin_profile_dir = os.path.join(get_ipython_package_dir(),
                                                u'config', u'profile', new)

    ipython_dir = Unicode(config=True,
                          help="""
        The name of the IPython directory. This directory is used for logging
        configuration (through profiles), history storage, etc. The default
        is usually $HOME/.ipython. This option can also be specified through
        the environment variable IPYTHONDIR.
        """)

    def _ipython_dir_default(self):
        d = get_ipython_dir()
        self._ipython_dir_changed('ipython_dir', d, d)
        return d

    _in_init_profile_dir = False
    profile_dir = Instance(ProfileDir, allow_none=True)

    def _profile_dir_default(self):
        # avoid recursion
        if self._in_init_profile_dir:
            return
        # profile_dir requested early, force initialization
        self.init_profile_dir()
        return self.profile_dir

    overwrite = Bool(
        False,
        config=True,
        help="""Whether to overwrite existing config files when copying""")
    auto_create = Bool(
        False,
        config=True,
        help="""Whether to create profile dir if it doesn't exist""")

    config_files = List(Unicode())

    def _config_files_default(self):
        return [self.config_file_name]

    copy_config_files = Bool(
        False,
        config=True,
        help="""Whether to install the default config files into the profile dir.
        If a new profile is being created, and IPython contains config files for that
        profile, then they will be staged into the new directory.  Otherwise,
        default config files will be automatically generated.
        """)

    verbose_crash = Bool(
        False,
        config=True,
        help=
        """Create a massive crash report when IPython encounters what may be an
        internal error.  The default is to append a short message to the
        usual traceback""")

    # The class to use as the crash handler.
    crash_handler_class = Type(crashhandler.CrashHandler)

    @catch_config_error
    def __init__(self, **kwargs):
        super(BaseIPythonApplication, self).__init__(**kwargs)
        # ensure current working directory exists
        try:
            py3compat.getcwd()
        except OSError:
            # exit if cwd doesn't exist
            self.log.error("Current working directory doesn't exist.")
            self.exit(1)

    #-------------------------------------------------------------------------
    # Various stages of Application creation
    #-------------------------------------------------------------------------

    deprecated_subcommands = {}

    def initialize_subcommand(self, subc, argv=None):
        if subc in self.deprecated_subcommands:
            self.log.warning(
                "Subcommand `ipython {sub}` is deprecated and will be removed "
                "in future versions.".format(sub=subc))
            self.log.warning("You likely want to use `jupyter {sub}` in the "
                             "future".format(sub=subc))
        return super(BaseIPythonApplication,
                     self).initialize_subcommand(subc, argv)

    def init_crash_handler(self):
        """Create a crash handler, typically setting sys.excepthook to it."""
        self.crash_handler = self.crash_handler_class(self)
        sys.excepthook = self.excepthook

        def unset_crashhandler():
            sys.excepthook = sys.__excepthook__

        atexit.register(unset_crashhandler)

    def excepthook(self, etype, evalue, tb):
        """this is sys.excepthook after init_crashhandler
        
        set self.verbose_crash=True to use our full crashhandler, instead of
        a regular traceback with a short message (crash_handler_lite)
        """

        if self.verbose_crash:
            return self.crash_handler(etype, evalue, tb)
        else:
            return crashhandler.crash_handler_lite(etype, evalue, tb)

    def _ipython_dir_changed(self, name, old, new):
        if old is not Undefined:
            str_old = py3compat.cast_bytes_py2(os.path.abspath(old),
                                               sys.getfilesystemencoding())
            if str_old in sys.path:
                sys.path.remove(str_old)
        str_path = py3compat.cast_bytes_py2(os.path.abspath(new),
                                            sys.getfilesystemencoding())
        sys.path.append(str_path)
        ensure_dir_exists(new)
        readme = os.path.join(new, 'README')
        readme_src = os.path.join(get_ipython_package_dir(), u'config',
                                  u'profile', 'README')
        if not os.path.exists(readme) and os.path.exists(readme_src):
            shutil.copy(readme_src, readme)
        for d in ('extensions', 'nbextensions'):
            path = os.path.join(new, d)
            try:
                ensure_dir_exists(path)
            except OSError as e:
                # this will not be EEXIST
                self.log.error("couldn't create path %s: %s", path, e)
        self.log.debug("IPYTHONDIR set to: %s" % new)

    def load_config_file(self, suppress_errors=True):
        """Load the config file.

        By default, errors in loading config are handled, and a warning
        printed on screen. For testing, the suppress_errors option is set
        to False, so errors will make tests fail.
        """
        self.log.debug("Searching path %s for config files",
                       self.config_file_paths)
        base_config = 'ipython_config.py'
        self.log.debug("Attempting to load config file: %s" % base_config)
        try:
            Application.load_config_file(self,
                                         base_config,
                                         path=self.config_file_paths)
        except ConfigFileNotFound:
            # ignore errors loading parent
            self.log.debug("Config file %s not found", base_config)

        for config_file_name in self.config_files:
            if not config_file_name or config_file_name == base_config:
                continue
            self.log.debug("Attempting to load config file: %s" %
                           config_file_name)
            try:
                Application.load_config_file(self,
                                             config_file_name,
                                             path=self.config_file_paths)
            except ConfigFileNotFound:
                # Only warn if the default config file was NOT being used.
                if config_file_name in self.config_file_specified:
                    msg = self.log.warning
                else:
                    msg = self.log.debug
                msg("Config file not found, skipping: %s", config_file_name)
            except Exception:
                # For testing purposes.
                if not suppress_errors:
                    raise
                self.log.warning("Error loading config file: %s" %
                                 config_file_name,
                                 exc_info=True)

    def init_profile_dir(self):
        """initialize the profile dir"""
        self._in_init_profile_dir = True
        if self.profile_dir is not None:
            # already ran
            return
        if 'ProfileDir.location' not in self.config:
            # location not specified, find by profile name
            try:
                p = ProfileDir.find_profile_dir_by_name(
                    self.ipython_dir, self.profile, self.config)
            except ProfileDirError:
                # not found, maybe create it (always create default profile)
                if self.auto_create or self.profile == 'default':
                    try:
                        p = ProfileDir.create_profile_dir_by_name(
                            self.ipython_dir, self.profile, self.config)
                    except ProfileDirError:
                        self.log.fatal("Could not create profile: %r" %
                                       self.profile)
                        self.exit(1)
                    else:
                        self.log.info("Created profile dir: %r" % p.location)
                else:
                    self.log.fatal("Profile %r not found." % self.profile)
                    self.exit(1)
            else:
                self.log.debug("Using existing profile dir: %r" % p.location)
        else:
            location = self.config.ProfileDir.location
            # location is fully specified
            try:
                p = ProfileDir.find_profile_dir(location, self.config)
            except ProfileDirError:
                # not found, maybe create it
                if self.auto_create:
                    try:
                        p = ProfileDir.create_profile_dir(
                            location, self.config)
                    except ProfileDirError:
                        self.log.fatal(
                            "Could not create profile directory: %r" %
                            location)
                        self.exit(1)
                    else:
                        self.log.debug("Creating new profile dir: %r" %
                                       location)
                else:
                    self.log.fatal("Profile directory %r not found." %
                                   location)
                    self.exit(1)
            else:
                self.log.info("Using existing profile dir: %r" % location)
            # if profile_dir is specified explicitly, set profile name
            dir_name = os.path.basename(p.location)
            if dir_name.startswith('profile_'):
                self.profile = dir_name[8:]

        self.profile_dir = p
        self.config_file_paths.append(p.location)
        self._in_init_profile_dir = False

    def init_config_files(self):
        """[optionally] copy default config files into profile dir."""
        self.config_file_paths.extend(SYSTEM_CONFIG_DIRS)
        # copy config files
        path = self.builtin_profile_dir
        if self.copy_config_files:
            src = self.profile

            cfg = self.config_file_name
            if path and os.path.exists(os.path.join(path, cfg)):
                self.log.warning(
                    "Staging %r from %s into %r [overwrite=%s]" %
                    (cfg, src, self.profile_dir.location, self.overwrite))
                self.profile_dir.copy_config_file(cfg,
                                                  path=path,
                                                  overwrite=self.overwrite)
            else:
                self.stage_default_config_file()
        else:
            # Still stage *bundled* config files, but not generated ones
            # This is necessary for `ipython profile=sympy` to load the profile
            # on the first go
            files = glob.glob(os.path.join(path, '*.py'))
            for fullpath in files:
                cfg = os.path.basename(fullpath)
                if self.profile_dir.copy_config_file(cfg,
                                                     path=path,
                                                     overwrite=False):
                    # file was copied
                    self.log.warning(
                        "Staging bundled %s from %s into %r" %
                        (cfg, self.profile, self.profile_dir.location))

    def stage_default_config_file(self):
        """auto generate default config file, and stage it into the profile."""
        s = self.generate_config_file()
        fname = os.path.join(self.profile_dir.location, self.config_file_name)
        if self.overwrite or not os.path.exists(fname):
            self.log.warning("Generating default config file: %r" % (fname))
            with open(fname, 'w') as f:
                f.write(s)

    @catch_config_error
    def initialize(self, argv=None):
        # don't hook up crash handler before parsing command-line
        self.parse_command_line(argv)
        self.init_crash_handler()
        if self.subapp is not None:
            # stop here if subapp is taking over
            return
        cl_config = self.config
        self.init_profile_dir()
        self.init_config_files()
        self.load_config_file()
        # enforce cl-opts override configfile opts:
        self.update_config(cl_config)
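
A minimal sketch of how an application built on BaseIPythonApplication is
typically launched (`MyApp` is hypothetical; `launch_instance` comes from the
traitlets `Application` base class and calls `initialize()` then `start()`):

class MyApp(BaseIPythonApplication):
    name = Unicode(u'myapp')

    def start(self):
        # runs after initialize() has parsed argv, found the profile dir,
        # and loaded the config files
        self.log.info("Using profile dir: %s", self.profile_dir.location)

if __name__ == '__main__':
    MyApp.launch_instance()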
Example #3
class KernelClient(ConnectionFileMixin):
    """Communicates with a single kernel on any host via zmq channels.

    There are five channels associated with each kernel:

    * shell: for request/reply calls to the kernel.
    * iopub: for the kernel to publish results to frontends.
    * hb: for monitoring the kernel's heartbeat.
    * stdin: for frontends to reply to raw_input calls in the kernel.
    * control: for kernel management calls to the kernel.

    The messages that can be sent on these channels are exposed as methods of the
    client (KernelClient.execute, complete, history, etc.). These methods only
    send the message, they don't wait for a reply. To get results, use e.g.
    :meth:`get_shell_msg` to fetch messages from the shell channel.
    """

    # The PyZMQ Context to use for communication with the kernel.
    context = Instance(zmq.Context)

    def _context_default(self):
        return zmq.Context()

    # The classes to use for the various channels
    shell_channel_class = Type(ChannelABC)
    iopub_channel_class = Type(ChannelABC)
    stdin_channel_class = Type(ChannelABC)
    hb_channel_class = Type(HBChannelABC)
    control_channel_class = Type(ChannelABC)

    # Protected traits
    _shell_channel = Any()
    _iopub_channel = Any()
    _stdin_channel = Any()
    _hb_channel = Any()
    _control_channel = Any()

    # flag for whether execute requests should be allowed to call raw_input:
    allow_stdin = True

    #--------------------------------------------------------------------------
    # Channel proxy methods
    #--------------------------------------------------------------------------

    def get_shell_msg(self, *args, **kwargs):
        """Get a message from the shell channel"""
        return self.shell_channel.get_msg(*args, **kwargs)

    def get_iopub_msg(self, *args, **kwargs):
        """Get a message from the iopub channel"""
        return self.iopub_channel.get_msg(*args, **kwargs)

    def get_stdin_msg(self, *args, **kwargs):
        """Get a message from the stdin channel"""
        return self.stdin_channel.get_msg(*args, **kwargs)

    def get_control_msg(self, *args, **kwargs):
        """Get a message from the control channel"""
        return self.control_channel.get_msg(*args, **kwargs)

    #--------------------------------------------------------------------------
    # Channel management methods
    #--------------------------------------------------------------------------

    def start_channels(self,
                       shell=True,
                       iopub=True,
                       stdin=True,
                       hb=True,
                       control=True):
        """Starts the channels for this kernel.

        This will create the channels if they do not exist and then start
        them (their activity runs in a thread). If port numbers of 0 are
        being used (random ports) then you must first call
        :meth:`start_kernel`. If the channels have been stopped and you
        call this, :class:`RuntimeError` will be raised.
        """
        if iopub:
            self.iopub_channel.start()
        if shell:
            self.shell_channel.start()
        if stdin:
            self.stdin_channel.start()
            self.allow_stdin = True
        else:
            self.allow_stdin = False
        if hb:
            self.hb_channel.start()
        if control:
            self.control_channel.start()

    def stop_channels(self):
        """Stops all the running channels for this kernel.

        This stops their event loops and joins their threads.
        """
        if self.shell_channel.is_alive():
            self.shell_channel.stop()
        if self.iopub_channel.is_alive():
            self.iopub_channel.stop()
        if self.stdin_channel.is_alive():
            self.stdin_channel.stop()
        if self.hb_channel.is_alive():
            self.hb_channel.stop()
        if self.control_channel.is_alive():
            self.control_channel.stop()

    @property
    def channels_running(self):
        """Are any of the channels created and running?"""
        return (self.shell_channel.is_alive() or self.iopub_channel.is_alive()
                or self.stdin_channel.is_alive() or self.hb_channel.is_alive()
                or self.control_channel.is_alive())

    ioloop = None  # Overridden in subclasses that use pyzmq event loop

    @property
    def shell_channel(self):
        """Get the shell channel object for this kernel."""
        if self._shell_channel is None:
            url = self._make_url('shell')
            self.log.debug("connecting shell channel to %s", url)
            socket = self.connect_shell(identity=self.session.bsession)
            self._shell_channel = self.shell_channel_class(
                socket, self.session, self.ioloop)
        return self._shell_channel

    @property
    def iopub_channel(self):
        """Get the iopub channel object for this kernel."""
        if self._iopub_channel is None:
            url = self._make_url('iopub')
            self.log.debug("connecting iopub channel to %s", url)
            socket = self.connect_iopub()
            self._iopub_channel = self.iopub_channel_class(
                socket, self.session, self.ioloop)
        return self._iopub_channel

    @property
    def stdin_channel(self):
        """Get the stdin channel object for this kernel."""
        if self._stdin_channel is None:
            url = self._make_url('stdin')
            self.log.debug("connecting stdin channel to %s", url)
            socket = self.connect_stdin(identity=self.session.bsession)
            self._stdin_channel = self.stdin_channel_class(
                socket, self.session, self.ioloop)
        return self._stdin_channel

    @property
    def hb_channel(self):
        """Get the hb channel object for this kernel."""
        if self._hb_channel is None:
            url = self._make_url('hb')
            self.log.debug("connecting heartbeat channel to %s", url)
            self._hb_channel = self.hb_channel_class(self.context,
                                                     self.session, url)
        return self._hb_channel

    @property
    def control_channel(self):
        """Get the control channel object for this kernel."""
        if self._control_channel is None:
            url = self._make_url('control')
            self.log.debug("connecting control channel to %s", url)
            socket = self.connect_control(identity=self.session.bsession)
            self._control_channel = self.control_channel_class(
                socket, self.session, self.ioloop)
        return self._control_channel

    def is_alive(self):
        """Is the kernel process still running?"""
        from .manager import KernelManager
        if isinstance(self.parent, KernelManager):
            # This KernelClient was created by a KernelManager,
            # we can ask the parent KernelManager:
            return self.parent.is_alive()
        if self._hb_channel is not None:
            # We don't have access to the KernelManager,
            # so we use the heartbeat.
            return self._hb_channel.is_beating()
        else:
            # no heartbeat and not local, we can't tell if it's running,
            # so naively return True
            return True

    # Methods to send specific messages on channels
    def execute(self,
                code,
                silent=False,
                store_history=True,
                user_expressions=None,
                allow_stdin=None,
                stop_on_error=True):
        """Execute code in the kernel.

        Parameters
        ----------
        code : str
            A string of code in the kernel's language.

        silent : bool, optional (default False)
            If set, the kernel will execute the code as quietly as possible, and
            will force store_history to be False.

        store_history : bool, optional (default True)
            If set, the kernel will store command history.  This is forced
            to be False if silent is True.

        user_expressions : dict, optional
            A dict mapping names to expressions to be evaluated in the user's
            dict. The expression values are returned as strings formatted using
            :func:`repr`.

        allow_stdin : bool, optional (default self.allow_stdin)
            Flag for whether the kernel can send stdin requests to frontends.

            Some frontends (e.g. the Notebook) do not support stdin requests.
            If raw_input is called from code executed from such a frontend, a
            StdinNotImplementedError will be raised.

        stop_on_error : bool, optional (default True)
            Whether to abort the execution queue if an exception is encountered.

        Returns
        -------
        The msg_id of the message sent.
        """
        if user_expressions is None:
            user_expressions = {}
        if allow_stdin is None:
            allow_stdin = self.allow_stdin

        # Don't waste network traffic if inputs are invalid
        if not isinstance(code, str):
            raise ValueError('code %r must be a string' % code)
        validate_string_dict(user_expressions)

        # Create class for content/msg creation. Related to, but possibly
        # not in Session.
        content = dict(code=code,
                       silent=silent,
                       store_history=store_history,
                       user_expressions=user_expressions,
                       allow_stdin=allow_stdin,
                       stop_on_error=stop_on_error)
        msg = self.session.msg('execute_request', content)
        self.shell_channel.send(msg)
        return msg['header']['msg_id']

    def complete(self, code, cursor_pos=None):
        """Tab complete text in the kernel's namespace.

        Parameters
        ----------
        code : str
            The context in which completion is requested.
            Can be anything between a variable name and an entire cell.
        cursor_pos : int, optional
            The position of the cursor in the block of code where the completion was requested.
            Default: ``len(code)``

        Returns
        -------
        The msg_id of the message sent.
        """
        if cursor_pos is None:
            cursor_pos = len(code)
        content = dict(code=code, cursor_pos=cursor_pos)
        msg = self.session.msg('complete_request', content)
        self.shell_channel.send(msg)
        return msg['header']['msg_id']

    def inspect(self, code, cursor_pos=None, detail_level=0):
        """Get metadata information about an object in the kernel's namespace.

        It is up to the kernel to determine the appropriate object to inspect.

        Parameters
        ----------
        code : str
            The context in which info is requested.
            Can be anything between a variable name and an entire cell.
        cursor_pos : int, optional
            The position of the cursor in the block of code where the info was requested.
            Default: ``len(code)``
        detail_level : int, optional
            The level of detail for the introspection (0-2)

        Returns
        -------
        The msg_id of the message sent.
        """
        if cursor_pos is None:
            cursor_pos = len(code)
        content = dict(
            code=code,
            cursor_pos=cursor_pos,
            detail_level=detail_level,
        )
        msg = self.session.msg('inspect_request', content)
        self.shell_channel.send(msg)
        return msg['header']['msg_id']

    def history(self,
                raw=True,
                output=False,
                hist_access_type='range',
                **kwargs):
        """Get entries from the kernel's history list.

        Parameters
        ----------
        raw : bool
            If True, return the raw input.
        output : bool
            If True, then return the output as well.
        hist_access_type : str
            'range' (fill in session, start and stop params), 'tail' (fill in n)
             or 'search' (fill in pattern param).

        session : int
            For a range request, the session from which to get lines. Session
            numbers are positive integers; negative ones count back from the
            current session.
        start : int
            The first line number of a history range.
        stop : int
            The final (excluded) line number of a history range.

        n : int
            The number of lines of history to get for a tail request.

        pattern : str
            The glob-syntax pattern for a search request.

        Returns
        -------
        The ID of the message sent.
        """
        if hist_access_type == 'range':
            kwargs.setdefault('session', 0)
            kwargs.setdefault('start', 0)
        content = dict(raw=raw,
                       output=output,
                       hist_access_type=hist_access_type,
                       **kwargs)
        msg = self.session.msg('history_request', content)
        self.shell_channel.send(msg)
        return msg['header']['msg_id']

    def kernel_info(self):
        """Request kernel info

        Returns
        -------
        The msg_id of the message sent
        """
        msg = self.session.msg('kernel_info_request')
        self.shell_channel.send(msg)
        return msg['header']['msg_id']

    def comm_info(self, target_name=None):
        """Request comm info

        Returns
        -------
        The msg_id of the message sent
        """
        if target_name is None:
            content = {}
        else:
            content = dict(target_name=target_name)
        msg = self.session.msg('comm_info_request', content)
        self.shell_channel.send(msg)
        return msg['header']['msg_id']

    def _handle_kernel_info_reply(self, msg):
        """handle kernel info reply

        sets protocol adaptation version. This might
        be run from a separate thread.
        """
        adapt_version = int(msg['content']['protocol_version'].split('.')[0])
        if adapt_version != major_protocol_version:
            self.session.adapt_version = adapt_version

    def is_complete(self, code):
        """Ask the kernel whether some code is complete and ready to execute."""
        msg = self.session.msg('is_complete_request', {'code': code})
        self.shell_channel.send(msg)
        return msg['header']['msg_id']

    def input(self, string):
        """Send a string of raw input to the kernel.

        This should only be called in response to the kernel sending an
        ``input_request`` message on the stdin channel.
        """
        content = dict(value=string)
        msg = self.session.msg('input_reply', content)
        self.stdin_channel.send(msg)

    def shutdown(self, restart=False):
        """Request an immediate kernel shutdown on the control channel.

        Upon receipt of the (empty) reply, client code can safely assume that
        the kernel has shut down and it's safe to forcefully terminate it if
        it's still alive.

        The kernel will send the reply via a function registered with Python's
        atexit module, ensuring it's truly done as the kernel is done with all
        normal operation.

        Returns
        -------
        The msg_id of the message sent
        """
        # Send quit message to kernel. Once we implement kernel-side setattr,
        # this should probably be done that way, but for now this will do.
        msg = self.session.msg('shutdown_request', {'restart': restart})
        self.control_channel.send(msg)
        return msg['header']['msg_id']
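
A hedged end-to-end sketch of driving a kernel through this client API, using
the blocking client returned by jupyter_client's `start_new_kernel` helper
(the timeout values here are arbitrary):

from jupyter_client.manager import start_new_kernel

km, kc = start_new_kernel(kernel_name='python3')  # manager + started client
try:
    msg_id = kc.execute("1 + 1")
    reply = kc.get_shell_msg(timeout=10)          # the execute_reply
    assert reply['parent_header']['msg_id'] == msg_id
    print(reply['content']['status'])             # 'ok' on success
finally:
    kc.stop_channels()
    km.shutdown_kernel()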
Example #4
class IPythonKernel(KernelBase):
    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC',
                     allow_none=True)
    shell_class = Type(ZMQInteractiveShell)

    user_module = Any()
    def _user_module_changed(self, name, old, new):
        if self.shell is not None:
            self.shell.user_module = new

    user_ns = Instance(dict, args=None, allow_none=True)
    def _user_ns_changed(self, name, old, new):
        if self.shell is not None:
            self.shell.user_ns = new
            self.shell.init_user_ns()

    # A reference to the Python builtin 'raw_input' function.
    # (i.e., __builtin__.raw_input for Python 2.7, builtins.input for Python 3)
    _sys_raw_input = Any()
    _sys_eval_input = Any()

    def __init__(self, **kwargs):
        super(IPythonKernel, self).__init__(**kwargs)

        # Initialize the InteractiveShell subclass
        self.shell = self.shell_class.instance(parent=self,
            profile_dir = self.profile_dir,
            user_module = self.user_module,
            user_ns     = self.user_ns,
            kernel      = self,
        )
        self.shell.displayhook.session = self.session
        self.shell.displayhook.pub_socket = self.iopub_socket
        self.shell.displayhook.topic = self._topic('execute_result')
        self.shell.display_pub.session = self.session
        self.shell.display_pub.pub_socket = self.iopub_socket

        self.comm_manager = CommManager(parent=self, kernel=self)

        self.shell.configurables.append(self.comm_manager)
        comm_msg_types = [ 'comm_open', 'comm_msg', 'comm_close' ]
        for msg_type in comm_msg_types:
            self.shell_handlers[msg_type] = getattr(self.comm_manager, msg_type)

    help_links = List([
        {
            'text': "Python",
            'url': "http://docs.python.org/%i.%i" % sys.version_info[:2],
        },
        {
            'text': "IPython",
            'url': "http://ipython.org/documentation.html",
        },
        {
            'text': "NumPy",
            'url': "http://docs.scipy.org/doc/numpy/reference/",
        },
        {
            'text': "SciPy",
            'url': "http://docs.scipy.org/doc/scipy/reference/",
        },
        {
            'text': "Matplotlib",
            'url': "http://matplotlib.org/contents.html",
        },
        {
            'text': "SymPy",
            'url': "http://docs.sympy.org/latest/index.html",
        },
        {
            'text': "pandas",
            'url': "http://pandas.pydata.org/pandas-docs/stable/",
        },
        {
            'text': "Python Brackets",
            'url': "http://docs.python-brackets.org",
        },
    ]).tag(config=True)

    # Kernel info fields
    implementation = 'ibpython'
    implementation_version = release.version
    language_info = {
        'name': 'ibpython',
        'version': sys.version.split()[0],
        'mimetype': 'text/x-python',
        'codemirror_mode': {
            'name': 'ipython',
            'version': sys.version_info[0]
        },
        'pygments_lexer': 'ipython%d' % (3 if PY3 else 2),
        'nbconvert_exporter': 'python',
        'file_extension': '.bpy'
    }
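    # These fields are reported to frontends in kernel_info_reply: e.g.
    # codemirror_mode drives syntax highlighting and file_extension is used
    # when exporting a notebook as a script.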

    @property
    def banner(self):
        return brackets.shell.BANNER

    def start(self):
        self.shell.exit_now = False
        super(IPythonKernel, self).start()

    def set_parent(self, ident, parent):
        """Overridden from parent to tell the display hook and output streams
        about the parent message.
        """
        super(IPythonKernel, self).set_parent(ident, parent)
        self.shell.set_parent(parent)

    def init_metadata(self, parent):
        """Initialize metadata.

        Run at the beginning of each execution request.
        """
        md = super(IPythonKernel, self).init_metadata(parent)
        # FIXME: remove deprecated ipyparallel-specific code
        # This is required for ipyparallel < 5.0
        md.update({
            'dependencies_met' : True,
            'engine' : self.ident,
        })
        return md

    def finish_metadata(self, parent, metadata, reply_content):
        """Finish populating metadata.

        Run after completing an execution request.
        """
        # FIXME: remove deprecated ipyparallel-specific code
        # This is required by ipyparallel < 5.0
        metadata['status'] = reply_content['status']
        if reply_content['status'] == 'error' and reply_content['ename'] == 'UnmetDependency':
            metadata['dependencies_met'] = False

        return metadata

    def _forward_input(self, allow_stdin=False):
        """Forward raw_input and getpass to the current frontend.

        via input_request
        """
        self._allow_stdin = allow_stdin

        if PY3:
            self._sys_raw_input = builtin_mod.input
            builtin_mod.input = self.raw_input
        else:
            self._sys_raw_input = builtin_mod.raw_input
            self._sys_eval_input = builtin_mod.input
            builtin_mod.raw_input = self.raw_input
            builtin_mod.input = lambda prompt='': eval(self.raw_input(prompt))
        self._save_getpass = getpass.getpass
        getpass.getpass = self.getpass

    def _restore_input(self):
        """Restore raw_input, getpass"""
        if PY3:
            builtin_mod.input = self._sys_raw_input
        else:
            builtin_mod.raw_input = self._sys_raw_input
            builtin_mod.input = self._sys_eval_input

        getpass.getpass = self._save_getpass

    @property
    def execution_count(self):
        return self.shell.execution_count

    @execution_count.setter
    def execution_count(self, value):
        # Ignore the incrementing done by KernelBase, in favour of our shell's
        # execution counter.
        pass

    def do_execute(self, code, silent, store_history=True,
                   user_expressions=None, allow_stdin=False):
        shell = self.shell # we'll need this a lot here

        self._forward_input(allow_stdin)

        reply_content = {}

        try:
            code = brackets.translate(code)
            res = shell.run_cell(code, store_history=store_history, silent=silent)
        finally:
            self._restore_input()

        if res.error_before_exec is not None:
            err = res.error_before_exec
        else:
            err = res.error_in_exec

        if res.success:
            reply_content[u'status'] = u'ok'
        else:
            reply_content[u'status'] = u'error'

            reply_content.update({
                u'traceback': shell._last_traceback or [],
                u'ename': unicode_type(type(err).__name__),
                u'evalue': safe_unicode(err),
            })

            # FIXME: deprecated piece for ipyparallel (remove in 5.0):
            e_info = dict(engine_uuid=self.ident, engine_id=self.int_id,
                          method='execute')
            reply_content['engine_info'] = e_info


        # Return the execution counter so clients can display prompts
        reply_content['execution_count'] = shell.execution_count - 1

        if 'traceback' in reply_content:
            self.log.info("Exception in execute request:\n%s", '\n'.join(reply_content['traceback']))


        # At this point, we can tell whether the main code execution succeeded
        # or not.  If it did, we proceed to evaluate user_expressions
        if reply_content['status'] == 'ok':
            reply_content[u'user_expressions'] = \
                         shell.user_expressions(user_expressions or {})
        else:
            # If there was an error, don't even try to compute expressions
            reply_content[u'user_expressions'] = {}

        # Payloads should be retrieved regardless of outcome, so we can both
        # recover partial output (that could have been generated early in a
        # block, before an error) and always clear the payload system.
        reply_content[u'payload'] = shell.payload_manager.read_payload()
        # Be aggressive about clearing the payload because we don't want
        # it to sit in memory until the next execute_request comes in.
        shell.payload_manager.clear_payload()

        return reply_content

    def do_complete(self, code, cursor_pos):
        if _use_experimental_60_completion:
            return self._experimental_do_complete(code, cursor_pos)

        # FIXME: IPython completers currently assume single line,
        # but completion messages give multi-line context
        # For now, extract line from cell, based on cursor_pos:
        if cursor_pos is None:
            cursor_pos = len(code)
        line, offset = line_at_cursor(code, cursor_pos)
        line_cursor = cursor_pos - offset

        txt, matches = self.shell.complete('', line, line_cursor)
        return {'matches' : matches,
                'cursor_end' : cursor_pos,
                'cursor_start' : cursor_pos - len(txt),
                'metadata' : {},
                'status' : 'ok'}

    def _experimental_do_complete(self, code, cursor_pos):
        """
        Experimental completions from IPython, using Jedi.
        """
        if cursor_pos is None:
            cursor_pos = len(code)
        with provisionalcompleter():
            raw_completions = self.shell.Completer.completions(code, cursor_pos)
            completions = list(rectify_completions(code, raw_completions))

            comps = []
            for comp in completions:
                comps.append(dict(
                            start=comp.start,
                            end=comp.end,
                            text=comp.text,
                            type=comp.type,
                ))

        if completions:
            s = completions[0].start
            e = completions[0].end
            matches = [c.text for c in completions]
        else:
            s = cursor_pos
            e = cursor_pos
            matches = []

        return {'matches': matches,
                'cursor_end': e,
                'cursor_start': s,
                'metadata': {_EXPERIMENTAL_KEY_NAME: comps},
                'status': 'ok'}



    def do_inspect(self, code, cursor_pos, detail_level=0):
        name = token_at_cursor(code, cursor_pos)
        info = self.shell.object_inspect(name)

        reply_content = {'status' : 'ok'}
        reply_content['data'] = data = {}
        reply_content['metadata'] = {}
        reply_content['found'] = info['found']
        if info['found']:
            info_text = self.shell.object_inspect_text(
                name,
                detail_level=detail_level,
            )
            data['text/plain'] = info_text

        return reply_content

    def do_history(self, hist_access_type, output, raw, session=0, start=0,
                   stop=None, n=None, pattern=None, unique=False):
        if hist_access_type == 'tail':
            hist = self.shell.history_manager.get_tail(n, raw=raw, output=output,
                                                            include_latest=True)

        elif hist_access_type == 'range':
            hist = self.shell.history_manager.get_range(session, start, stop,
                                                        raw=raw, output=output)

        elif hist_access_type == 'search':
            hist = self.shell.history_manager.search(
                pattern, raw=raw, output=output, n=n, unique=unique)
        else:
            hist = []

        return {
            'status': 'ok',
            'history' : list(hist),
        }

    def do_shutdown(self, restart):
        self.shell.exit_now = True
        return dict(status='ok', restart=restart)

    def do_is_complete(self, code):
        status, indent_spaces = self.shell.input_transformer_manager.check_complete(code)
        r = {'status': status}

        if status == 'incomplete':
            r['indent'] = ' ' * indent_spaces

        try:
            # if bracket translation fails, treat the cell as incomplete
            brackets.translate(code)
        except Exception:
            r = {'status': 'incomplete', 'indent': '    '}

        return r

    def do_apply(self, content, bufs, msg_id, reply_metadata):
        from .serialize import serialize_object, unpack_apply_message
        shell = self.shell
        try:
            working = shell.user_ns

            prefix = "_"+str(msg_id).replace("-","")+"_"

            f,args,kwargs = unpack_apply_message(bufs, working, copy=False)

            fname = prefix+"f"
            argname = prefix+"args"
            kwargname = prefix+"kwargs"
            resultname = prefix+"result"

            ns = {fname: f, argname: args, kwargname: kwargs, resultname: None}
            working.update(ns)
            code = "%s = %s(*%s,**%s)" % (resultname, fname, argname, kwargname)
            try:
                exec(code, shell.user_global_ns, shell.user_ns)
                result = working.get(resultname)
            finally:
                for key in ns:
                    working.pop(key)

            result_buf = serialize_object(result,
                buffer_threshold=self.session.buffer_threshold,
                item_threshold=self.session.item_threshold,
            )

        except BaseException as e:
            # invoke IPython traceback formatting
            shell.showtraceback()
            reply_content = {
                u'traceback': shell._last_traceback or [],
                u'ename': unicode_type(type(e).__name__),
                u'evalue': safe_unicode(e),
            }
            # FIXME: deprecated piece for ipyparallel (remove in 5.0):
            e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method='apply')
            reply_content['engine_info'] = e_info

            self.send_response(self.iopub_socket, u'error', reply_content,
                                ident=self._topic('error'))
            self.log.info("Exception in apply request:\n%s", '\n'.join(reply_content['traceback']))
            result_buf = []
            reply_content['status'] = 'error'
        else:
            reply_content = {'status' : 'ok'}

        return reply_content, result_buf

    def do_clear(self):
        self.shell.reset(False)
        return dict(status='ok')
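
For reference, here is a minimal sketch (hypothetical values, no session or ZMQ machinery involved) of the reply shapes do_execute builds above, for the success and error paths:

# Hypothetical reply payloads mirroring the dicts built in do_execute above.
ok_reply = {
    'status': 'ok',
    'execution_count': 3,
    'user_expressions': {},
    'payload': [],
}
error_reply = {
    'status': 'error',
    'ename': 'ZeroDivisionError',
    'evalue': 'division by zero',
    'traceback': ['Traceback (most recent call last): ...'],
    'execution_count': 4,
    'user_expressions': {},
    'payload': [],
}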
Example #5
class NbConvertApp(JupyterApp):
    """Application used to convert from notebook file type (``*.ipynb``)"""

    version = __version__
    name = 'jupyter-nbconvert'
    aliases = nbconvert_aliases
    flags = nbconvert_flags

    def _log_level_default(self):
        return logging.INFO

    classes = List()

    def _classes_default(self):
        classes = [NbConvertBase]
        for pkg in (exporters, preprocessors, writers, postprocessors):
            for name in dir(pkg):
                cls = getattr(pkg, name)
                if isinstance(cls, type) and issubclass(cls, Configurable):
                    classes.append(cls)

        return classes

    description = Unicode(
        u"""This application is used to convert notebook files (*.ipynb)
        to various other formats.

        WARNING: THE COMMANDLINE INTERFACE MAY CHANGE IN FUTURE RELEASES.""")

    output_base = Unicode('',
                          config=True,
                          help='''Overwrite the base name used for output files.
            Can only be used when converting one notebook at a time.
            ''')

    use_output_suffix = Bool(
        True,
        config=True,
        help="""Whether to apply a suffix prior to the extension (only relevant
            when converting to notebook format). The suffix is determined by
            the exporter, and is usually '.nbconvert'.""")

    examples = Unicode(u"""
        The simplest way to use nbconvert is
        
        > jupyter nbconvert mynotebook.ipynb
        
        which will convert mynotebook.ipynb to the default format (probably HTML).
        
        You can specify the export format with `--to`.
        Options include {0}
        
        > jupyter nbconvert --to latex mynotebook.ipynb

        Both HTML and LaTeX support multiple output templates. LaTeX includes
        'base', 'article' and 'report'.  HTML includes 'basic' and 'full'. You
        can specify the flavor of the format used.

        > jupyter nbconvert --to html --template basic mynotebook.ipynb
        
        You can also pipe the output to stdout, rather than a file
        
        > jupyter nbconvert mynotebook.ipynb --stdout

        PDF is generated via LaTeX

        > jupyter nbconvert mynotebook.ipynb --to pdf
        
        You can get (and serve) a Reveal.js-powered slideshow
        
        > jupyter nbconvert myslides.ipynb --to slides --post serve
        
        Multiple notebooks can be given at the command line in a couple of 
        different ways:
  
        > jupyter nbconvert notebook*.ipynb
        > jupyter nbconvert notebook1.ipynb notebook2.ipynb
        
        or you can specify the notebooks list in a config file, containing::
        
            c.NbConvertApp.notebooks = ["my_notebook.ipynb"]
        
        > jupyter nbconvert --config mycfg.py
        """.format(get_export_names()))

    # Writer specific variables
    writer = Instance('nbconvert.writers.base.WriterBase',
                      help="""Instance of the writer class used to write the 
                      results of the conversion.""",
                      allow_none=True)
    writer_class = DottedObjectName('FilesWriter',
                                    config=True,
                                    help="""Writer class used to write the 
                                    results of the conversion""")
    writer_aliases = {
        'fileswriter': 'nbconvert.writers.files.FilesWriter',
        'debugwriter': 'nbconvert.writers.debug.DebugWriter',
        'stdoutwriter': 'nbconvert.writers.stdout.StdoutWriter'
    }
    writer_factory = Type(allow_none=True)

    def _writer_class_changed(self, name, old, new):
        if new.lower() in self.writer_aliases:
            new = self.writer_aliases[new.lower()]
        self.writer_factory = import_item(new)

    # Post-processor specific variables
    postprocessor = Instance(
        'nbconvert.postprocessors.base.PostProcessorBase',
        help="""Instance of the PostProcessor class used to write the
                      results of the conversion.""",
        allow_none=True)

    postprocessor_class = DottedOrNone(
        config=True,
        help="""PostProcessor class used to write the
                                    results of the conversion""")
    postprocessor_aliases = {
        'serve': 'nbconvert.postprocessors.serve.ServePostProcessor'
    }
    postprocessor_factory = Type(None, allow_none=True)

    def _postprocessor_class_changed(self, name, old, new):
        if new.lower() in self.postprocessor_aliases:
            new = self.postprocessor_aliases[new.lower()]
        if new:
            self.postprocessor_factory = import_item(new)

    export_format = Unicode(
        'html',
        allow_none=False,
        config=True,
        help="""The export format to be used, either one of the built-in formats,
        or a dotted object name that represents the import path for an
        `Exporter` class""")

    notebooks = List([],
                     config=True,
                     help="""List of notebooks to convert.
                     Wildcards are supported.
                     Filenames passed positionally will be added to the list.
                     """)
    from_stdin = Bool(False,
                      config=True,
                      help="read a single notebook from stdin.")

    @catch_config_error
    def initialize(self, argv=None):
        self.init_syspath()
        super(NbConvertApp, self).initialize(argv)
        self.init_notebooks()
        self.init_writer()
        self.init_postprocessor()

    def init_syspath(self):
        """
        Add the cwd to the sys.path ($PYTHONPATH)
        """
        sys.path.insert(0, os.getcwd())

    def init_notebooks(self):
        """Construct the list of notebooks.
        If notebooks are passed on the command-line,
        they override notebooks specified in config files.
        Glob each notebook to replace notebook patterns with filenames.
        """

        # Specifying notebooks on the command-line overrides (rather than adds)
        # the notebook list
        if self.extra_args:
            patterns = self.extra_args
        else:
            patterns = self.notebooks

        # Use glob to replace all the notebook patterns with filenames.
        filenames = []
        for pattern in patterns:

            # Use glob to find matching filenames.  Allow the user to convert
            # notebooks without having to type the extension.
            globbed_files = glob.glob(pattern)
            globbed_files.extend(glob.glob(pattern + '.ipynb'))
            if not globbed_files:
                self.log.warning("pattern %r matched no files", pattern)

            for filename in globbed_files:
                if filename not in filenames:
                    filenames.append(filename)
        self.notebooks = filenames

    def init_writer(self):
        """
        Initialize the writer (which is stateless)
        """
        self._writer_class_changed(None, self.writer_class, self.writer_class)
        self.writer = self.writer_factory(parent=self)
        if hasattr(self.writer,
                   'build_directory') and self.writer.build_directory != '':
            self.use_output_suffix = False

    def init_postprocessor(self):
        """
        Initialize the postprocessor (which is stateless)
        """
        self._postprocessor_class_changed(None, self.postprocessor_class,
                                          self.postprocessor_class)
        if self.postprocessor_factory:
            self.postprocessor = self.postprocessor_factory(parent=self)

    def start(self):
        """
        Run after initialization is complete
        """
        super(NbConvertApp, self).start()
        self.convert_notebooks()

    def init_single_notebook_resources(self, notebook_filename):
        """Step 1: Initialize resources

        This initializes the resources dictionary for a single notebook. This
        method should return the resources dictionary, and MUST include the
        following keys:

            - config_dir: the location of the Jupyter config directory
            - unique_key: the notebook name
            - output_files_dir: a directory where output files (not including
              the notebook itself) should be saved

        """

        basename = os.path.basename(notebook_filename)
        notebook_name = basename[:basename.rfind('.')]
        if self.output_base:
            # strip duplicate extension from output_base, to avoid basename.ext.ext
            if getattr(self.exporter, 'file_extension', False):
                base, ext = os.path.splitext(self.output_base)
                if ext == self.exporter.file_extension:
                    self.output_base = base
            notebook_name = self.output_base

        self.log.debug("Notebook name is '%s'", notebook_name)

        # first initialize the resources we want to use
        resources = {}
        resources['config_dir'] = self.config_dir
        resources['unique_key'] = notebook_name
        resources['output_files_dir'] = '%s_files' % notebook_name

        return resources

    def export_single_notebook(self,
                               notebook_filename,
                               resources,
                               input_buffer=None):
        """Step 2: Export the notebook

        Exports the notebook to a particular format according to the specified
        exporter. This function returns the output and (possibly modified)
        resources from the exporter.

        notebook_filename: a filename
        input_buffer: a readable file-like object returning unicode;
            if not None, notebook_filename is ignored
        """
        try:
            if input_buffer is not None:
                output, resources = self.exporter.from_file(
                    input_buffer, resources=resources)
            else:
                output, resources = self.exporter.from_filename(
                    notebook_filename, resources=resources)
        except ConversionException:
            self.log.error("Error while converting '%s'",
                           notebook_filename,
                           exc_info=True)
            self.exit(1)

        return output, resources

    def write_single_notebook(self, output, resources):
        """Step 3: Write the notebook to file

        This writes output from the exporter to file using the specified writer.
        It returns the results from the writer.

        """
        if 'unique_key' not in resources:
            raise KeyError(
                "unique_key MUST be specified in the resources, but it is not")

        notebook_name = resources['unique_key']
        if self.use_output_suffix and not self.output_base:
            notebook_name += resources.get('output_suffix', '')

        write_results = self.writer.write(output,
                                          resources,
                                          notebook_name=notebook_name)
        return write_results

    def postprocess_single_notebook(self, write_results):
        """Step 4: Postprocess the notebook

        This postprocesses the notebook after it has been written, taking as an
        argument the results of writing the notebook to file. This only does
        anything if a postprocessor has actually been specified.

        """
        # Post-process if post processor has been defined.
        if hasattr(self, 'postprocessor') and self.postprocessor:
            self.postprocessor(write_results)

    def convert_single_notebook(self, notebook_filename, input_buffer=None):
        """Convert a single notebook. Performs the following steps:

            1. Initialize notebook resources
            2. Export the notebook to a particular format
            3. Write the exported notebook to file
            4. (Maybe) postprocess the written file
            
            If input_buffer is not None, conversion is done using the buffer as
            the source, and the output file's base name is taken from the
            notebook_filename argument.
        """
        if input_buffer is None:
            self.log.info("Converting notebook %s to %s", notebook_filename,
                          self.export_format)
        else:
            self.log.info("Converting notebook into %s", self.export_format)

        resources = self.init_single_notebook_resources(notebook_filename)
        output, resources = self.export_single_notebook(
            notebook_filename, resources, input_buffer=input_buffer)
        write_results = self.write_single_notebook(output, resources)
        self.postprocess_single_notebook(write_results)

    def convert_notebooks(self):
        """
        Convert the notebooks in the self.notebooks traitlet
        """
        # check that the output base isn't specified if there is more than
        # one notebook to convert
        if self.output_base != '' and len(self.notebooks) > 1:
            self.log.error("""
                UsageError: --output flag or `NbConvertApp.output_base` config option
                cannot be used when converting multiple notebooks.
                """)
            self.exit(1)

        # initialize the exporter
        cls = get_exporter(self.export_format)
        self.exporter = cls(config=self.config)

        # no notebooks to convert!
        if len(self.notebooks) == 0 and not self.from_stdin:
            self.print_help()
            sys.exit(-1)

        # convert each notebook
        if not self.from_stdin:
            for notebook_filename in self.notebooks:
                self.convert_single_notebook(notebook_filename)
        else:
            input_buffer = unicode_stdin_stream()
            # default name when conversion from stdin
            self.convert_single_notebook("notebook.ipynb",
                                         input_buffer=input_buffer)
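
Assuming nbconvert is installed, the application above can also be driven programmatically instead of through the `jupyter nbconvert` command line; a minimal sketch (the notebook filename is a placeholder):

# Minimal sketch: drive NbConvertApp from Python with CLI-style arguments.
from nbconvert.nbconvertapp import NbConvertApp

app = NbConvertApp()
app.initialize(['--to', 'html', 'mynotebook.ipynb'])  # same argv as the CLI
app.start()  # globs the notebook list and runs convert_notebooks()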
Example #6
class DatsciSpawner(WrapSpawner):
    """DatsciSpawner - custom classes on the wrapper to provide several spawn
       profiles meeting different business needs.
    """

    profiles = List(
        trait=Tuple(Unicode(), Unicode(), Type(Spawner), Dict()),
        default_value=[('Default', 'kube0', KubeSpawner, {
            'start_timeout': 15,
            'http_timeout': 10
        })],
        minlen=1,
        config=True,
        help="""List of profiles to offer for selection. Signature is:
            List(Tuple( Unicode, Unicode, Type(Spawner), Dict )) corresponding to
            profile display name, unique key, Spawner class, dictionary of spawner config options.

            The first three values will be exposed in the input_template as {display}, {key}, and {type}"""
    )

    child_profile = Unicode()

    form_template = Unicode(
        """<label for="profile">Select a job profile:</label>
        <select class="form-control" name="profile" required autofocus>
        {input_template}
        </select>
        """,
        config=True,
        help=
        """Template to use to construct options_form text. {input_template} is replaced with
            the result of formatting input_template against each item in the profiles list."""
    )

    first_template = Unicode(
        'selected',
        config=True,
        help="Text to substitute as {first} in input_template")

    input_template = Unicode(
        """
        <option value="{key}" {first}>{display}</option>""",
        config=True,
        help=
        """Template to construct {input_template} in form_template. This text will be formatted
            against each item in the profiles list, in order, using the following key names:
            ( display, key, type ) for the first three items in the tuple, and additionally
            first = "checked" (taken from first_template) for the first item in the list, so that
            the first item starts selected.""")

    options_form = Unicode()

    def _options_form_default(self):
        temp_keys = [
            dict(display=p[0], key=p[1], type=p[2], first='')
            for p in self.profiles
        ]
        temp_keys[0]['first'] = self.first_template
        text = ''.join([self.input_template.format(**tk) for tk in temp_keys])
        return self.form_template.format(input_template=text)

    def options_from_form(self, formdata):
        # Default to first profile if somehow none is provided
        return dict(profile=formdata.get('profile', [self.profiles[0][1]])[0])

    def select_profile(self, profile):
        # Select matching profile, or do nothing (leaving previous or default config in place)
        for p in self.profiles:
            if p[1] == profile:
                self.child_class = p[2]
                self.child_config = p[3]
                break

    def construct_child(self):
        self.child_profile = self.user_options.get('profile', "")
        self.select_profile(self.child_profile)
        super().construct_child()

    def load_child_class(self, state):
        try:
            self.child_profile = state['profile']
        except KeyError:
            raise KeyError(
                'The JupyterHub database may be outdated; please reset it. In the default configuration, just delete jupyterhub.sqlite.'
            )
        self.select_profile(self.child_profile)

    def get_state(self):
        state = super().get_state()
        state['profile'] = self.child_profile
        return state

    def clear_state(self):
        super().clear_state()
        self.child_profile = ''
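
A hypothetical jupyterhub_config.py excerpt showing how the profiles trait above might be populated; the display names, keys, and KubeSpawner options are illustrative:

# Each tuple is (display name, unique key, Spawner class, spawner config dict).
from kubespawner import KubeSpawner

c.DatsciSpawner.profiles = [
    ('Small (2 CPU)', 'small', KubeSpawner, {'cpu_limit': 2.0}),
    ('Large (8 CPU)', 'large', KubeSpawner, {'cpu_limit': 8.0}),
]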
Example #7
class ContentsManager(LoggingConfigurable):
    """Base class for serving files and directories.

    This serves any text or binary file,
    as well as directories,
    with special handling for JSON notebook documents.

    Most APIs take a path argument,
    which is always an API-style unicode path,
    and always refers to a directory.

    - unicode, not url-escaped
    - '/'-separated
    - leading and trailing '/' will be stripped
    - if unspecified, path defaults to '',
      indicating the root path.

    """
    
    root_dir = Unicode('/', config=True)

    notary = Instance(sign.NotebookNotary)
    def _notary_default(self):
        return sign.NotebookNotary(parent=self)

    hide_globs = List(Unicode(), [
            u'__pycache__', '*.pyc', '*.pyo',
            '.DS_Store', '*.so', '*.dylib', '*~',
        ], config=True, help="""
        Glob patterns to hide in file and directory listings.
    """)

    untitled_notebook = Unicode(_("Untitled"), config=True,
        help="The base name used when creating untitled notebooks."
    )

    untitled_file = Unicode("untitled", config=True,
        help="The base name used when creating untitled files."
    )

    untitled_directory = Unicode("Untitled Folder", config=True,
        help="The base name used when creating untitled directories."
    )

    pre_save_hook = Any(None, config=True, allow_none=True,
        help="""Python callable or importstring thereof

        To be called on a contents model prior to save.

        This can be used to process the structure,
        such as removing notebook outputs or other side effects that
        should not be saved.

        It will be called as (all arguments passed by keyword)::

            hook(path=path, model=model, contents_manager=self)

        - model: the model to be saved. Includes file contents.
          Modifying this dict will affect the file that is stored.
        - path: the API path of the save destination
        - contents_manager: this ContentsManager instance
        """
    )

    @validate('pre_save_hook')
    def _validate_pre_save_hook(self, proposal):
        value = proposal['value']
        if isinstance(value, string_types):
            # import the proposed value, not the previously-stored trait value
            value = import_item(value)
        if not callable(value):
            raise TraitError("pre_save_hook must be callable")
        return value

    def run_pre_save_hook(self, model, path, **kwargs):
        """Run the pre-save hook if defined, and log errors"""
        if self.pre_save_hook:
            try:
                self.log.debug("Running pre-save hook on %s", path)
                self.pre_save_hook(model=model, path=path, contents_manager=self, **kwargs)
            except Exception:
                self.log.error("Pre-save hook failed on %s", path, exc_info=True)

    checkpoints_class = Type(Checkpoints, config=True)
    checkpoints = Instance(Checkpoints, config=True)
    checkpoints_kwargs = Dict(config=True)

    @default('checkpoints')
    def _default_checkpoints(self):
        return self.checkpoints_class(**self.checkpoints_kwargs)

    @default('checkpoints_kwargs')
    def _default_checkpoints_kwargs(self):
        return dict(
            parent=self,
            log=self.log,
        )

    files_handler_class = Type(
        FilesHandler, klass=RequestHandler, allow_none=True, config=True,
        help="""handler class to use when serving raw file requests.

        Default is a fallback that talks to the ContentsManager API,
        which may be inefficient, especially for large files.

        Local files-based ContentsManagers can use a StaticFileHandler subclass,
        which will be much more efficient.

        Access to these files should be authenticated.
        """
    )

    files_handler_params = Dict(
        config=True,
        help="""Extra parameters to pass to files_handler_class.

        For example, StaticFileHandlers generally expect a `path` argument
        specifying the root directory from which to serve files.
        """
    )

    def get_extra_handlers(self):
        """Return additional handlers

        Default: self.files_handler_class on /files/.*
        """
        handlers = []
        if self.files_handler_class:
            handlers.append(
                (r"/files/(.*)", self.files_handler_class, self.files_handler_params)
            )
        return handlers

    # ContentsManager API part 1: methods that must be
    # implemented in subclasses.

    def dir_exists(self, path):
        """Does a directory exist at the given path?

        Like os.path.isdir

        Override this method in subclasses.

        Parameters
        ----------
        path : string
            The path to check

        Returns
        -------
        exists : bool
            Whether the path does indeed exist.
        """
        raise NotImplementedError

    def is_hidden(self, path):
        """Is path a hidden directory or file?

        Parameters
        ----------
        path : string
            The path to check. This is an API path (`/` separated,
            relative to root dir).

        Returns
        -------
        hidden : bool
            Whether the path is hidden.

        """
        raise NotImplementedError

    def file_exists(self, path=''):
        """Does a file exist at the given path?

        Like os.path.isfile

        Override this method in subclasses.

        Parameters
        ----------
        path : string
            The API path of a file to check for.

        Returns
        -------
        exists : bool
            Whether the file exists.
        """
        raise NotImplementedError('must be implemented in a subclass')

    def exists(self, path):
        """Does a file or directory exist at the given path?

        Like os.path.exists

        Parameters
        ----------
        path : string
            The API path of a file or directory to check for.

        Returns
        -------
        exists : bool
            Whether the target exists.
        """
        return self.file_exists(path) or self.dir_exists(path)

    def get(self, path, content=True, type=None, format=None):
        """Get a file or directory model."""
        raise NotImplementedError('must be implemented in a subclass')

    def save(self, model, path):
        """
        Save a file or directory model to path.

        Should return the saved model with no content.  Save implementations
        should call self.run_pre_save_hook(model=model, path=path) prior to
        writing any data.
        """
        raise NotImplementedError('must be implemented in a subclass')

    def delete_file(self, path):
        """Delete the file or directory at path."""
        raise NotImplementedError('must be implemented in a subclass')

    def rename_file(self, old_path, new_path):
        """Rename a file or directory."""
        raise NotImplementedError('must be implemented in a subclass')

    # ContentsManager API part 2: methods that have useable default
    # implementations, but can be overridden in subclasses.

    def delete(self, path):
        """Delete a file/directory and any associated checkpoints."""
        path = path.strip('/')
        if not path:
            raise HTTPError(400, "Can't delete root")
        self.delete_file(path)
        self.checkpoints.delete_all_checkpoints(path)

    def rename(self, old_path, new_path):
        """Rename a file and any checkpoints associated with that file."""
        self.rename_file(old_path, new_path)
        self.checkpoints.rename_all_checkpoints(old_path, new_path)

    def update(self, model, path):
        """Update the file's path

        For use in PATCH requests, to enable renaming a file without
        re-uploading its contents. Only used for renaming at the moment.
        """
        path = path.strip('/')
        new_path = model.get('path', path).strip('/')
        if path != new_path:
            self.rename(path, new_path)
        model = self.get(new_path, content=False)
        return model

    def info_string(self):
        return "Serving contents"

    def get_kernel_path(self, path, model=None):
        """Return the API path for the kernel
        
        KernelManagers can turn this value into a filesystem path,
        or ignore it altogether.

        The default value here will start kernels in the directory of the
        notebook server. FileContentsManager overrides this to use the
        directory containing the notebook.
        """
        return ''

    def increment_filename(self, filename, path='', insert=''):
        """Increment a filename until it is unique.

        Parameters
        ----------
        filename : unicode
            The name of a file, including extension
        path : unicode
            The API path of the target's directory
        insert : unicode
            Text inserted between the base name and the counter, e.g. '-Copy'

        Returns
        -------
        name : unicode
            A filename that is unique, based on the input filename.
        """
        path = path.strip('/')
        basename, ext = os.path.splitext(filename)
        for i in itertools.count():
            if i:
                insert_i = '{}{}'.format(insert, i)
            else:
                insert_i = ''
            name = u'{basename}{insert}{ext}'.format(basename=basename,
                insert=insert_i, ext=ext)
            if not self.exists(u'{}/{}'.format(path, name)):
                break
        return name

    def validate_notebook_model(self, model):
        """Add failed-validation message to model"""
        try:
            validate_nb(model['content'])
        except ValidationError as e:
            model['message'] = u'Notebook validation failed: {}:\n{}'.format(
                e.message, json.dumps(e.instance, indent=1, default=lambda obj: '<UNKNOWN>'),
            )
        return model
    
    def new_untitled(self, path='', type='', ext=''):
        """Create a new untitled file or directory in path
        
        path must be a directory
        
        File extension can be specified.
        
        Use `new` to create files with a fully specified path (including filename).
        """
        path = path.strip('/')
        if not self.dir_exists(path):
            raise HTTPError(404, 'No such directory: %s' % path)
        
        model = {}
        if type:
            model['type'] = type
        
        if ext == '.ipynb':
            model.setdefault('type', 'notebook')
        else:
            model.setdefault('type', 'file')
        
        insert = ''
        if model['type'] == 'directory':
            untitled = self.untitled_directory
            insert = ' '
        elif model['type'] == 'notebook':
            untitled = self.untitled_notebook
            ext = '.ipynb'
        elif model['type'] == 'file':
            untitled = self.untitled_file
        else:
            raise HTTPError(400, "Unexpected model type: %r" % model['type'])
        
        name = self.increment_filename(untitled + ext, path, insert=insert)
        path = u'{0}/{1}'.format(path, name)
        return self.new(model, path)
    
    def new(self, model=None, path=''):
        """Create a new file or directory and return its model with no content.
        
        To create a new untitled entity in a directory, use `new_untitled`.
        """
        path = path.strip('/')
        if model is None:
            model = {}
        
        if path.endswith('.ipynb'):
            model.setdefault('type', 'notebook')
        else:
            model.setdefault('type', 'file')
        
        # no content, not a directory, so fill out new-file model
        if 'content' not in model and model['type'] != 'directory':
            if model['type'] == 'notebook':
                model['content'] = new_notebook()
                model['format'] = 'json'
            else:
                model['content'] = ''
                model['type'] = 'file'
                model['format'] = 'text'
        
        model = self.save(model, path)
        return model

    def copy(self, from_path, to_path=None):
        """Copy an existing file and return its new model.

        If to_path is not specified, it will be the parent directory of from_path.
        If to_path is a directory, the filename will be incremented as
        `from_path-Copy#.ext`.

        from_path must be a full path to a file.
        """
        path = from_path.strip('/')
        if to_path is not None:
            to_path = to_path.strip('/')

        if '/' in path:
            from_dir, from_name = path.rsplit('/', 1)
        else:
            from_dir = ''
            from_name = path
        
        model = self.get(path)
        model.pop('path', None)
        model.pop('name', None)
        if model['type'] == 'directory':
            raise HTTPError(400, "Can't copy directories")
        
        if to_path is None:
            to_path = from_dir
        if self.dir_exists(to_path):
            name = copy_pat.sub(u'.', from_name)
            to_name = self.increment_filename(name, to_path, insert='-Copy')
            to_path = u'{0}/{1}'.format(to_path, to_name)
        
        model = self.save(model, to_path)
        return model

    def log_info(self):
        self.log.info(self.info_string())

    def trust_notebook(self, path):
        """Explicitly trust a notebook

        Parameters
        ----------
        path : string
            The path of a notebook
        """
        model = self.get(path)
        nb = model['content']
        self.log.warning("Trusting notebook %s", path)
        self.notary.mark_cells(nb, True)
        self.check_and_sign(nb, path)

    def check_and_sign(self, nb, path=''):
        """Check for trusted cells, and sign the notebook.

        Called as a part of saving notebooks.

        Parameters
        ----------
        nb : dict
            The notebook dict
        path : string
            The notebook's path (for logging)
        """
        if self.notary.check_cells(nb):
            self.notary.sign(nb)
        else:
            self.log.warning("Notebook %s is not trusted", path)

    def mark_trusted_cells(self, nb, path=''):
        """Mark cells as trusted if the notebook signature matches.

        Called as a part of loading notebooks.

        Parameters
        ----------
        nb : dict
            The notebook object (in current nbformat)
        path : string
            The notebook's path (for logging)
        """
        trusted = self.notary.check_signature(nb)
        if not trusted:
            self.log.warning("Notebook %s is not trusted", path)
        self.notary.mark_cells(nb, trusted)

    def should_list(self, name):
        """Should this file/directory name be displayed in a listing?"""
        return not any(fnmatch(name, glob) for glob in self.hide_globs)

    # Part 3: Checkpoints API
    def create_checkpoint(self, path):
        """Create a checkpoint."""
        return self.checkpoints.create_checkpoint(self, path)

    def restore_checkpoint(self, checkpoint_id, path):
        """
        Restore a checkpoint.
        """
        self.checkpoints.restore_checkpoint(self, checkpoint_id, path)

    def list_checkpoints(self, path):
        return self.checkpoints.list_checkpoints(path)

    def delete_checkpoint(self, checkpoint_id, path):
        return self.checkpoints.delete_checkpoint(checkpoint_id, path)
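
The pre_save_hook trait documented above is easiest to see with a concrete hook. A minimal sketch (the function name is illustrative) that strips code-cell outputs before a notebook is written to disk:

# A minimal pre-save hook matching the keyword signature documented above.
def scrub_output_pre_save(model, path, contents_manager, **kwargs):
    """Strip outputs from code cells so they are never written to disk."""
    if model['type'] != 'notebook':
        return
    for cell in model['content']['cells']:
        if cell['cell_type'] == 'code':
            cell['outputs'] = []
            cell['execution_count'] = None

# e.g. in jupyter_notebook_config.py:
# c.ContentsManager.pre_save_hook = scrub_output_pre_save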
Example #8
class TerminalInteractiveShell(InteractiveShell):
    space_for_menu = Integer(6, help='Number of lines at the bottom of the screen '
                                     'to reserve for the completion menu'
                            ).tag(config=True)

    def _space_for_menu_changed(self, old, new):
        self._update_layout()

    pt_cli = None
    debugger_history = None
    _pt_app = None

    simple_prompt = Bool(_use_simple_prompt,
        help="""Use `raw_input` for the REPL, without completion and prompt colors.

            Useful when controlling IPython as a subprocess, and piping STDIN/OUT/ERR. Known uses include
            IPython's own testing machinery, and emacs' inferior-shell integration through elpy.

            This mode defaults to `True` if the `IPY_TEST_SIMPLE_PROMPT`
            environment variable is set, or the current terminal is not a tty.

            """
            ).tag(config=True)

    @property
    def debugger_cls(self):
        return Pdb if self.simple_prompt else TerminalPdb

    confirm_exit = Bool(True,
        help="""
        Set to confirm when you try to exit IPython with an EOF (Control-D
        in Unix, Control-Z/Enter in Windows). By typing 'exit' or 'quit',
        you can force a direct exit without any confirmation.""",
    ).tag(config=True)

    editing_mode = Unicode('emacs',
        help="Shortcut style to use at the prompt. 'vi' or 'emacs'.",
    ).tag(config=True)

    mouse_support = Bool(False,
        help="Enable mouse support in the prompt"
    ).tag(config=True)

    highlighting_style = Union([Unicode('legacy'), Type(klass=Style)],
        help="""The name or class of a Pygments style to use for syntax
        highlighting: \n %s""" % ', '.join(get_all_styles())
    ).tag(config=True)


    @observe('highlighting_style')
    @observe('colors')
    def _highlighting_style_changed(self, change):
        self.refresh_style()

    def refresh_style(self):
        self._style = self._make_style_from_name_or_cls(self.highlighting_style)


    highlighting_style_overrides = Dict(
        help="Override highlighting format for specific tokens"
    ).tag(config=True)

    true_color = Bool(False,
        help=("Use 24bit colors instead of 256 colors in prompt highlighting. "
              "If your terminal supports true color, the following command "
              "should print 'TRUECOLOR' in orange: "
              "printf \"\\x1b[38;2;255;100;0mTRUECOLOR\\x1b[0m\\n\"")
    ).tag(config=True)

    editor = Unicode(get_default_editor(),
        help="Set the editor used by IPython (default to $EDITOR/vi/notepad)."
    ).tag(config=True)

    prompts_class = Type(Prompts, help='Class used to generate Prompt token for prompt_toolkit').tag(config=True)

    prompts = Instance(Prompts)

    @default('prompts')
    def _prompts_default(self):
        return self.prompts_class(self)

    @observe('prompts')
    def _(self, change):
        self._update_layout()

    @default('displayhook_class')
    def _displayhook_class_default(self):
        return RichPromptDisplayHook

    term_title = Bool(True,
        help="Automatically set the terminal title"
    ).tag(config=True)

    display_completions = Enum(('column', 'multicolumn','readlinelike'),
        help= ( "Options for displaying tab completions, 'column', 'multicolumn', and "
                "'readlinelike'. These options are for `prompt_toolkit`, see "
                "`prompt_toolkit` documentation for more information."
                ),
        default_value='multicolumn').tag(config=True)

    highlight_matching_brackets = Bool(True,
        help="Highlight matching brackets.",
    ).tag(config=True)

    extra_open_editor_shortcuts = Bool(False,
        help="Enable vi (v) or Emacs (C-X C-E) shortcuts to open an external editor. "
             "This is in addition to the F2 binding, which is always enabled."
    ).tag(config=True)

    @observe('term_title')
    def init_term_title(self, change=None):
        # Enable or disable the terminal title.
        if self.term_title:
            toggle_set_term_title(True)
            set_term_title('IPython: ' + abbrev_cwd())
        else:
            toggle_set_term_title(False)

    def init_display_formatter(self):
        super(TerminalInteractiveShell, self).init_display_formatter()
        # terminal only supports plain text
        self.display_formatter.active_types = ['text/plain']
        # disable `_ipython_display_`
        self.display_formatter.ipython_display_formatter.enabled = False

    def init_prompt_toolkit_cli(self):
        if self.simple_prompt:
            # Fall back to plain non-interactive output for tests.
            # This is very limited, and only accepts a single line.
            def prompt():
                isp = self.input_splitter
                prompt_text = "".join(x[1] for x in self.prompts.in_prompt_tokens())
                prompt_continuation = "".join(x[1] for x in self.prompts.continuation_prompt_tokens())
                while isp.push_accepts_more():
                    line = cast_unicode_py2(input(prompt_text))
                    isp.push(line)
                    prompt_text = prompt_continuation
                return isp.source_reset()
            self.prompt_for_code = prompt
            return

        # Set up keyboard shortcuts
        kbmanager = KeyBindingManager.for_prompt(
            enable_open_in_editor=self.extra_open_editor_shortcuts,
        )
        register_ipython_shortcuts(kbmanager.registry, self)

        # Pre-populate history from IPython's history database
        history = InMemoryHistory()
        last_cell = u""
        for __, ___, cell in self.history_manager.get_tail(self.history_load_length,
                                                        include_latest=True):
            # Ignore blank lines and consecutive duplicates
            cell = cell.rstrip()
            if cell and (cell != last_cell):
                history.append(cell)
                last_cell = cell

        self._style = self._make_style_from_name_or_cls(self.highlighting_style)
        self.style = DynamicStyle(lambda: self._style)

        editing_mode = getattr(EditingMode, self.editing_mode.upper())

        def patch_stdout(**kwargs):
            return self.pt_cli.patch_stdout_context(**kwargs)

        self._pt_app = create_prompt_application(
                            editing_mode=editing_mode,
                            key_bindings_registry=kbmanager.registry,
                            history=history,
                            completer=IPythonPTCompleter(shell=self,
                                                    patch_stdout=patch_stdout),
                            enable_history_search=True,
                            style=self.style,
                            mouse_support=self.mouse_support,
                            **self._layout_options()
        )
        self._eventloop = create_eventloop(self.inputhook)
        self.pt_cli = CommandLineInterface(
            self._pt_app, eventloop=self._eventloop,
            output=create_output(true_color=self.true_color))

    def _make_style_from_name_or_cls(self, name_or_cls):
        """
        Small wrapper that makes an IPython-compatible style from a style name.

        We need this to add styles for the prompt, among other things.
        """
        style_overrides = {}
        if name_or_cls == 'legacy':
            legacy = self.colors.lower()
            if legacy == 'linux':
                style_cls = get_style_by_name('monokai')
                style_overrides = _style_overrides_linux
            elif legacy == 'lightbg':
                style_overrides = _style_overrides_light_bg
                style_cls = get_style_by_name('pastie')
            elif legacy == 'neutral':
                # The default theme needs to be visible on both a dark background
                # and a light background, because we can't tell what the terminal
                # looks like. These tweaks to the default theme help with that.
                style_cls = get_style_by_name('default')
                style_overrides.update({
                    Token.Number: '#007700',
                    Token.Operator: 'noinherit',
                    Token.String: '#BB6622',
                    Token.Name.Function: '#2080D0',
                    Token.Name.Class: 'bold #2080D0',
                    Token.Name.Namespace: 'bold #2080D0',
                    Token.Prompt: '#009900',
                    Token.PromptNum: '#00ff00 bold',
                    Token.OutPrompt: '#990000',
                    Token.OutPromptNum: '#ff0000 bold',
                })

                # Hack: Due to limited color support on the Windows console
                # the prompt colors will be wrong without this
                if os.name == 'nt':
                    style_overrides.update({
                        Token.Prompt: '#ansidarkgreen',
                        Token.PromptNum: '#ansigreen bold',
                        Token.OutPrompt: '#ansidarkred',
                        Token.OutPromptNum: '#ansired bold',
                    })
            elif legacy == 'nocolor':
                style_cls = _NoStyle
                style_overrides = {}
            else:
                raise ValueError('Got unknown colors: ', legacy)
        else:
            if isinstance(name_or_cls, string_types):
                style_cls = get_style_by_name(name_or_cls)
            else:
                style_cls = name_or_cls
            style_overrides = {
                Token.Prompt: '#009900',
                Token.PromptNum: '#00ff00 bold',
                Token.OutPrompt: '#990000',
                Token.OutPromptNum: '#ff0000 bold',
            }
        style_overrides.update(self.highlighting_style_overrides)
        style = PygmentsStyle.from_defaults(pygments_style_cls=style_cls,
                                            style_dict=style_overrides)

        return style

    def _layout_options(self):
        """
        Return the current layout options for this TerminalInteractiveShell.
        """
        return {
                'lexer':IPythonPTLexer(),
                'reserve_space_for_menu':self.space_for_menu,
                'get_prompt_tokens':self.prompts.in_prompt_tokens,
                'get_continuation_tokens':self.prompts.continuation_prompt_tokens,
                'multiline':True,
                'display_completions_in_columns': (self.display_completions == 'multicolumn'),

                # Highlight matching brackets, but only when this setting is
                # enabled, and only when the DEFAULT_BUFFER has the focus.
                'extra_input_processors': [ConditionalProcessor(
                        processor=HighlightMatchingBracketProcessor(chars='[](){}'),
                        filter=HasFocus(DEFAULT_BUFFER) & ~IsDone() &
                            Condition(lambda cli: self.highlight_matching_brackets))],
                }

    def _update_layout(self):
        """
        Ask for a recomputation of the application layout, if, for example,
        some configuration options have changed.
        """
        if self._pt_app:
            self._pt_app.layout = create_prompt_layout(**self._layout_options())

    def prompt_for_code(self):
        document = self.pt_cli.run(
            pre_run=self.pre_prompt, reset_current_buffer=True)
        return document.text

    def enable_win_unicode_console(self):
        if sys.version_info >= (3, 6):
            # Since PEP 528, Python uses the unicode APIs for the Windows
            # console by default, so WUC shouldn't be needed.
            return

        import win_unicode_console

        if PY3:
            win_unicode_console.enable()
        else:
            # https://github.com/ipython/ipython/issues/9768
            from win_unicode_console.streams import (TextStreamWrapper,
                                 stdout_text_transcoded, stderr_text_transcoded)

            class LenientStrStreamWrapper(TextStreamWrapper):
                def write(self, s):
                    if isinstance(s, bytes):
                        s = s.decode(self.encoding, 'replace')

                    self.base.write(s)

            stdout_text_str = LenientStrStreamWrapper(stdout_text_transcoded)
            stderr_text_str = LenientStrStreamWrapper(stderr_text_transcoded)

            win_unicode_console.enable(stdout=stdout_text_str,
                                       stderr=stderr_text_str)

    def init_io(self):
        if sys.platform not in {'win32', 'cli'}:
            return

        self.enable_win_unicode_console()

        import colorama
        colorama.init()

        # For some reason we make these wrappers around stdout/stderr.
        # For now, we need to reset them so all output gets coloured.
        # https://github.com/ipython/ipython/issues/8669
        # io.std* are deprecated, but don't show our own deprecation warnings
        # during initialization of the deprecated API.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', DeprecationWarning)
            io.stdout = io.IOStream(sys.stdout)
            io.stderr = io.IOStream(sys.stderr)

    def init_magics(self):
        super(TerminalInteractiveShell, self).init_magics()
        self.register_magics(TerminalMagics)

    def init_alias(self):
        # The parent class defines aliases that can be safely used with any
        # frontend.
        super(TerminalInteractiveShell, self).init_alias()

        # Now define aliases that only make sense on the terminal, because they
        # need direct access to the console in a way that we can't emulate in
        # GUI or web frontend
        if os.name == 'posix':
            for cmd in ['clear', 'more', 'less', 'man']:
                self.alias_manager.soft_define_alias(cmd, cmd)


    def __init__(self, *args, **kwargs):
        super(TerminalInteractiveShell, self).__init__(*args, **kwargs)
        self.init_prompt_toolkit_cli()
        self.init_term_title()
        self.keep_running = True

        self.debugger_history = InMemoryHistory()

    def ask_exit(self):
        self.keep_running = False

    rl_next_input = None

    def pre_prompt(self):
        if self.rl_next_input:
            # We can't set the buffer here, because it will be reset just after
            # this. Adding a callable to pre_run_callables does what we need
            # after the buffer is reset.
            s = cast_unicode_py2(self.rl_next_input)
            def set_doc():
                self.pt_cli.application.buffer.document = Document(s)
            if hasattr(self.pt_cli, 'pre_run_callables'):
                self.pt_cli.pre_run_callables.append(set_doc)
            else:
                # Older version of prompt_toolkit; it's OK to set the document
                # directly here.
                set_doc()
            self.rl_next_input = None

    def interact(self, display_banner=DISPLAY_BANNER_DEPRECATED):

        if display_banner is not DISPLAY_BANNER_DEPRECATED:
            warn('interact `display_banner` argument is deprecated since IPython 5.0. Call `show_banner()` if needed.', DeprecationWarning, stacklevel=2)

        self.keep_running = True
        while self.keep_running:
            print(self.separate_in, end='')

            try:
                code = self.prompt_for_code()
            except EOFError:
                if (not self.confirm_exit) \
                        or self.ask_yes_no('Do you really want to exit ([y]/n)?','y','n'):
                    self.ask_exit()

            else:
                if code:
                    self.run_cell(code, store_history=True)

    def mainloop(self, display_banner=DISPLAY_BANNER_DEPRECATED):
        # An extra layer of protection in case someone mashing Ctrl-C breaks
        # out of our internal code.
        if display_banner is not DISPLAY_BANNER_DEPRECATED:
            warn('mainloop `display_banner` argument is deprecated since IPython 5.0. Call `show_banner()` if needed.', DeprecationWarning, stacklevel=2)
        while True:
            try:
                self.interact()
                break
            except KeyboardInterrupt as e:
                print("\n%s escaped interact()\n" % type(e).__name__)
            finally:
                # An interrupt during the eventloop will mess up the
                # internal state of the prompt_toolkit library.
                # Stopping the eventloop fixes this, see
                # https://github.com/ipython/ipython/pull/9867
                if hasattr(self, '_eventloop'):
                    self._eventloop.stop()

    _inputhook = None
    def inputhook(self, context):
        if self._inputhook is not None:
            self._inputhook(context)

    active_eventloop = None
    def enable_gui(self, gui=None):
        if gui:
            self.active_eventloop, self._inputhook =\
                get_inputhook_name_and_func(gui)
        else:
            self.active_eventloop = self._inputhook = None

    # Run !system commands directly, not through pipes, so terminal programs
    # work correctly.
    system = InteractiveShell.system_raw

    def auto_rewrite_input(self, cmd):
        """Overridden from the parent class to use fancy rewriting prompt"""
        if not self.show_rewritten_input:
            return

        tokens = self.prompts.rewrite_prompt_tokens()
        if self.pt_cli:
            self.pt_cli.print_tokens(tokens)
            print(cmd)
        else:
            prompt = ''.join(s for t, s in tokens)
            print(prompt, cmd, sep='')

    _prompts_before = None
    def switch_doctest_mode(self, mode):
        """Switch prompts to classic for %doctest_mode"""
        if mode:
            self._prompts_before = self.prompts
            self.prompts = ClassicPrompts(self)
        elif self._prompts_before:
            self.prompts = self._prompts_before
            self._prompts_before = None
        self._update_layout()
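    # A hedged usage sketch: InteractiveShell.set_next_input() stores text in
    # rl_next_input, and pre_prompt() above pushes it into the prompt_toolkit
    # buffer when the next prompt is drawn. Assuming an active shell:
    #
    #     shell = get_ipython()
    #     shell.set_next_input("print('pre-filled')")
    #     # the next prompt starts with this text as editable input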
Example #9
class IPythonKernel(KernelBase):
    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC',
                     allow_none=True)
    shell_class = Type(ZMQInteractiveShell)

    use_experimental_completions = Bool(
        True,
        help=
        "Set this flag to False to deactivate the use of experimental IPython completion APIs.",
    ).tag(config=True)

    debugpy_stream = Instance(ZMQStream, allow_none=True)

    user_module = Any()

    @observe('user_module')
    @observe_compat
    def _user_module_changed(self, change):
        if self.shell is not None:
            self.shell.user_module = change['new']

    user_ns = Instance(dict, args=None, allow_none=True)

    @observe('user_ns')
    @observe_compat
    def _user_ns_changed(self, change):
        if self.shell is not None:
            self.shell.user_ns = change['new']
            self.shell.init_user_ns()

    # A reference to the Python builtin 'raw_input' function.
    # (i.e., __builtin__.raw_input for Python 2.7, builtins.input for Python 3)
    _sys_raw_input = Any()
    _sys_eval_input = Any()

    def __init__(self, **kwargs):
        super(IPythonKernel, self).__init__(**kwargs)

        # Initialize the Debugger
        self.debugger = Debugger(self.log, self.debugpy_stream,
                                 self._publish_debug_event,
                                 self.debug_shell_socket, self.session)

        # Initialize the InteractiveShell subclass
        self.shell = self.shell_class.instance(
            parent=self,
            profile_dir=self.profile_dir,
            user_module=self.user_module,
            user_ns=self.user_ns,
            kernel=self,
            compiler_class=XCachingCompiler,
        )
        self.shell.displayhook.session = self.session
        self.shell.displayhook.pub_socket = self.iopub_socket
        self.shell.displayhook.topic = self._topic('execute_result')
        self.shell.display_pub.session = self.session
        self.shell.display_pub.pub_socket = self.iopub_socket

        self.comm_manager = CommManager(parent=self, kernel=self)

        self.shell.configurables.append(self.comm_manager)
        comm_msg_types = ['comm_open', 'comm_msg', 'comm_close']
        for msg_type in comm_msg_types:
            self.shell_handlers[msg_type] = getattr(self.comm_manager,
                                                    msg_type)

        if _use_appnope() and self._darwin_app_nap:
            # Disable app-nap as the kernel is not a gui but can have guis
            import appnope
            appnope.nope()

    help_links = List([
        {
            'text': "Python Reference",
            'url': "https://docs.python.org/%i.%i" % sys.version_info[:2],
        },
        {
            'text': "IPython Reference",
            'url': "https://ipython.org/documentation.html",
        },
        {
            'text': "NumPy Reference",
            'url': "https://docs.scipy.org/doc/numpy/reference/",
        },
        {
            'text': "SciPy Reference",
            'url': "https://docs.scipy.org/doc/scipy/reference/",
        },
        {
            'text': "Matplotlib Reference",
            'url': "https://matplotlib.org/contents.html",
        },
        {
            'text': "SymPy Reference",
            'url': "http://docs.sympy.org/latest/index.html",
        },
        {
            'text': "pandas Reference",
            'url': "https://pandas.pydata.org/pandas-docs/stable/",
        },
    ]).tag(config=True)

    # Kernel info fields
    implementation = 'ipython'
    implementation_version = release.version
    language_info = {
        'name': 'python',
        'version': sys.version.split()[0],
        'mimetype': 'text/x-python',
        'codemirror_mode': {
            'name': 'ipython',
            'version': sys.version_info[0]
        },
        'pygments_lexer': 'ipython%d' % 3,
        'nbconvert_exporter': 'python',
        'file_extension': '.py'
    }

    def dispatch_debugpy(self, msg):
        # The first frame is the socket id, we can drop it
        frame = msg[1].bytes.decode('utf-8')
        self.log.debug("Debugpy received: %s", frame)
        self.debugger.tcp_client.receive_dap_frame(frame)

    @property
    def banner(self):
        return self.shell.banner

    def start(self):
        self.shell.exit_now = False
        if self.debugpy_stream is None:
            self.log.warning(
                "debugpy_stream undefined, debugging will not be enabled")
        else:
            self.debugpy_stream.on_recv(self.dispatch_debugpy, copy=False)
        super(IPythonKernel, self).start()

    def set_parent(self, ident, parent, channel='shell'):
        """Overridden from parent to tell the display hook and output streams
        about the parent message.
        """
        super(IPythonKernel, self).set_parent(ident, parent, channel)
        if channel == 'shell':
            self.shell.set_parent(parent)

    def init_metadata(self, parent):
        """Initialize metadata.

        Run at the beginning of each execution request.
        """
        md = super(IPythonKernel, self).init_metadata(parent)
        # FIXME: remove deprecated ipyparallel-specific code
        # This is required for ipyparallel < 5.0
        md.update({
            'dependencies_met': True,
            'engine': self.ident,
        })
        return md

    def finish_metadata(self, parent, metadata, reply_content):
        """Finish populating metadata.

        Run after completing an execution request.
        """
        # FIXME: remove deprecated ipyparallel-specific code
        # This is required by ipyparallel < 5.0
        metadata['status'] = reply_content['status']
        if reply_content['status'] == 'error' and reply_content[
                'ename'] == 'UnmetDependency':
            metadata['dependencies_met'] = False

        return metadata

    def _forward_input(self, allow_stdin=False):
        """Forward raw_input and getpass to the current frontend.

        via input_request
        """
        self._allow_stdin = allow_stdin

        self._sys_raw_input = builtins.input
        builtins.input = self.raw_input

        self._save_getpass = getpass.getpass
        getpass.getpass = self.getpass

    def _restore_input(self):
        """Restore raw_input, getpass"""
        builtins.input = self._sys_raw_input

        getpass.getpass = self._save_getpass

    @property
    def execution_count(self):
        return self.shell.execution_count

    @execution_count.setter
    def execution_count(self, value):
        # Ignore the incrementing done by KernelBase, in favour of our shell's
        # execution counter.
        pass

    @contextmanager
    def _cancel_on_sigint(self, future):
        """ContextManager for capturing SIGINT and cancelling a future

        SIGINT raises in the event loop when running async code,
        but we want it to halt a coroutine.

        Ideally, it would raise KeyboardInterrupt,
        but this turns it into a CancelledError.
        At least it gets a decent traceback to the user.
        """
        sigint_future = asyncio.Future()

        # whichever future finishes first,
        # cancel the other one
        def cancel_unless_done(f, _ignored):
            if f.cancelled() or f.done():
                return
            f.cancel()

        # when sigint finishes,
        # abort the coroutine with CancelledError
        sigint_future.add_done_callback(partial(cancel_unless_done, future))
        # when the main future finishes,
        # stop watching for SIGINT events
        future.add_done_callback(partial(cancel_unless_done, sigint_future))

        def handle_sigint(*args):
            def set_sigint_result():
                if sigint_future.cancelled() or sigint_future.done():
                    return
                sigint_future.set_result(1)

            # use add_callback for thread safety
            self.io_loop.add_callback(set_sigint_result)

        # set the custom sigint hander during this context
        save_sigint = signal.signal(signal.SIGINT, handle_sigint)
        try:
            yield
        finally:
            # restore the previous sigint handler
            signal.signal(signal.SIGINT, save_sigint)
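    # A minimal sketch of using the context manager above (assumes a running
    # asyncio loop and an io_loop with add_callback, as in ipykernel). A
    # SIGINT during the await resolves sigint_future, cancelling the future:
    #
    #     fut = asyncio.ensure_future(asyncio.sleep(60))
    #     with self._cancel_on_sigint(fut):
    #         try:
    #             await fut
    #         except asyncio.CancelledError:
    #             pass  # surfaced to the user instead of KeyboardInterrupt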

    async def do_execute(self,
                         code,
                         silent,
                         store_history=True,
                         user_expressions=None,
                         allow_stdin=False):
        shell = self.shell  # we'll need this a lot here

        self._forward_input(allow_stdin)

        reply_content = {}
        if hasattr(shell, 'run_cell_async') and hasattr(
                shell, 'should_run_async'):
            run_cell = shell.run_cell_async
            should_run_async = shell.should_run_async
        else:
            should_run_async = lambda cell: False

            # older IPython,
            # use blocking run_cell and wrap it in coroutine
            async def run_cell(*args, **kwargs):
                return shell.run_cell(*args, **kwargs)

        try:

            # default case: runner is asyncio and asyncio is already running
            # TODO: this should check every case for "are we inside the runner",
            # not just asyncio
            preprocessing_exc_tuple = None
            try:
                transformed_cell = self.shell.transform_cell(code)
            except Exception:
                transformed_cell = code
                preprocessing_exc_tuple = sys.exc_info()

            if (_asyncio_runner and shell.loop_runner is _asyncio_runner
                    and asyncio.get_event_loop().is_running()
                    and should_run_async(
                        code,
                        transformed_cell=transformed_cell,
                        preprocessing_exc_tuple=preprocessing_exc_tuple)):
                coro = run_cell(
                    code,
                    store_history=store_history,
                    silent=silent,
                    transformed_cell=transformed_cell,
                    preprocessing_exc_tuple=preprocessing_exc_tuple)
                coro_future = asyncio.ensure_future(coro)

                with self._cancel_on_sigint(coro_future):
                    res = None
                    try:
                        res = await coro_future
                    finally:
                        shell.events.trigger('post_execute')
                        if not silent:
                            shell.events.trigger('post_run_cell', res)
            else:
                # runner isn't already running,
                # make synchronous call,
                # letting shell dispatch to loop runners
                res = shell.run_cell(code,
                                     store_history=store_history,
                                     silent=silent)
        finally:
            self._restore_input()

        if res.error_before_exec is not None:
            err = res.error_before_exec
        else:
            err = res.error_in_exec

        if res.success:
            reply_content['status'] = 'ok'
        else:
            reply_content['status'] = 'error'

            reply_content.update({
                'traceback': shell._last_traceback or [],
                'ename': str(type(err).__name__),
                'evalue': str(err),
            })

            # FIXME: deprecated piece for ipyparallel (remove in 5.0):
            e_info = dict(engine_uuid=self.ident,
                          engine_id=self.int_id,
                          method='execute')
            reply_content['engine_info'] = e_info

        # Return the execution counter so clients can display prompts
        reply_content['execution_count'] = shell.execution_count - 1

        if 'traceback' in reply_content:
            self.log.info("Exception in execute request:\n%s",
                          '\n'.join(reply_content['traceback']))

        # At this point, we can tell whether the main code execution succeeded
        # or not.  If it did, we proceed to evaluate user_expressions
        if reply_content['status'] == 'ok':
            reply_content['user_expressions'] = \
                         shell.user_expressions(user_expressions or {})
        else:
            # If there was an error, don't even try to compute expressions
            reply_content['user_expressions'] = {}

        # Payloads should be retrieved regardless of outcome, so we can both
        # recover partial output (that could have been generated early in a
        # block, before an error) and always clear the payload system.
        reply_content['payload'] = shell.payload_manager.read_payload()
        # Be aggressive about clearing the payload because we don't want
        # it to sit in memory until the next execute_request comes in.
        shell.payload_manager.clear_payload()

        return reply_content

    def do_complete(self, code, cursor_pos):
        if _use_experimental_60_completion and self.use_experimental_completions:
            return self._experimental_do_complete(code, cursor_pos)

        # FIXME: IPython completers currently assume single line,
        # but completion messages give multi-line context
        # For now, extract line from cell, based on cursor_pos:
        if cursor_pos is None:
            cursor_pos = len(code)
        line, offset = line_at_cursor(code, cursor_pos)
        line_cursor = cursor_pos - offset

        txt, matches = self.shell.complete('', line, line_cursor)
        return {
            'matches': matches,
            'cursor_end': cursor_pos,
            'cursor_start': cursor_pos - len(txt),
            'metadata': {},
            'status': 'ok'
        }

    async def do_debug_request(self, msg):
        return await self.debugger.process_request(msg)

    def _experimental_do_complete(self, code, cursor_pos):
        """
        Experimental completions from IPython, using Jedi.
        """
        if cursor_pos is None:
            cursor_pos = len(code)
        with _provisionalcompleter():
            raw_completions = self.shell.Completer.completions(
                code, cursor_pos)
            completions = list(_rectify_completions(code, raw_completions))

            comps = []
            for comp in completions:
                comps.append(
                    dict(
                        start=comp.start,
                        end=comp.end,
                        text=comp.text,
                        type=comp.type,
                    ))

        if completions:
            s = completions[0].start
            e = completions[0].end
            matches = [c.text for c in completions]
        else:
            s = cursor_pos
            e = cursor_pos
            matches = []

        return {
            'matches': matches,
            'cursor_end': e,
            'cursor_start': s,
            'metadata': {
                _EXPERIMENTAL_KEY_NAME: comps
            },
            'status': 'ok'
        }

    def do_inspect(self, code, cursor_pos, detail_level=0):
        name = token_at_cursor(code, cursor_pos)

        reply_content = {'status': 'ok'}
        reply_content['data'] = {}
        reply_content['metadata'] = {}
        try:
            reply_content['data'].update(
                self.shell.object_inspect_mime(name,
                                               detail_level=detail_level))
            if not self.shell.enable_html_pager:
                reply_content['data'].pop('text/html')
            reply_content['found'] = True
        except KeyError:
            reply_content['found'] = False

        return reply_content

    def do_history(self,
                   hist_access_type,
                   output,
                   raw,
                   session=0,
                   start=0,
                   stop=None,
                   n=None,
                   pattern=None,
                   unique=False):
        if hist_access_type == 'tail':
            hist = self.shell.history_manager.get_tail(n,
                                                       raw=raw,
                                                       output=output,
                                                       include_latest=True)

        elif hist_access_type == 'range':
            hist = self.shell.history_manager.get_range(session,
                                                        start,
                                                        stop,
                                                        raw=raw,
                                                        output=output)

        elif hist_access_type == 'search':
            hist = self.shell.history_manager.search(pattern,
                                                     raw=raw,
                                                     output=output,
                                                     n=n,
                                                     unique=unique)
        else:
            hist = []

        return {
            'status': 'ok',
            'history': list(hist),
        }

    def do_shutdown(self, restart):
        self.shell.exit_now = True
        return dict(status='ok', restart=restart)

    def do_is_complete(self, code):
        transformer_manager = getattr(self.shell, 'input_transformer_manager',
                                      None)
        if transformer_manager is None:
            # input_splitter attribute is deprecated
            transformer_manager = self.shell.input_splitter
        status, indent_spaces = transformer_manager.check_complete(code)
        r = {'status': status}
        if status == 'incomplete':
            r['indent'] = ' ' * indent_spaces
        return r
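    # Worked example: check_complete("for i in range(3):") returns
    # ('incomplete', 4), so the reply is {'status': 'incomplete',
    # 'indent': '    '}; frontends use `indent` to pre-indent the next line.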

    def do_apply(self, content, bufs, msg_id, reply_metadata):
        from .serialize import serialize_object, unpack_apply_message
        shell = self.shell
        try:
            working = shell.user_ns

            prefix = "_" + str(msg_id).replace("-", "") + "_"

            f, args, kwargs = unpack_apply_message(bufs, working, copy=False)

            fname = getattr(f, '__name__', 'f')

            fname = prefix + "f"
            argname = prefix + "args"
            kwargname = prefix + "kwargs"
            resultname = prefix + "result"

            ns = {fname: f, argname: args, kwargname: kwargs, resultname: None}
            working.update(ns)
            code = "%s = %s(*%s,**%s)" % (resultname, fname, argname,
                                          kwargname)
            try:
                exec(code, shell.user_global_ns, shell.user_ns)
                result = working.get(resultname)
            finally:
                for key in ns:
                    working.pop(key)

            result_buf = serialize_object(
                result,
                buffer_threshold=self.session.buffer_threshold,
                item_threshold=self.session.item_threshold,
            )

        except BaseException as e:
            # invoke IPython traceback formatting
            shell.showtraceback()
            reply_content = {
                'traceback': shell._last_traceback or [],
                'ename': str(type(e).__name__),
                'evalue': safe_unicode(e),
            }
            # FIXME: deprecated piece for ipyparallel (remove in 5.0):
            e_info = dict(engine_uuid=self.ident,
                          engine_id=self.int_id,
                          method='apply')
            reply_content['engine_info'] = e_info

            self.send_response(self.iopub_socket,
                               'error',
                               reply_content,
                               ident=self._topic('error'),
                               channel='shell')
            self.log.info("Exception in apply request:\n%s",
                          '\n'.join(reply_content['traceback']))
            result_buf = []
            reply_content['status'] = 'error'
        else:
            reply_content = {'status': 'ok'}

        return reply_content, result_buf

    def do_clear(self):
        self.shell.reset(False)
        return dict(status='ok')
Example #10
class TerminalInteractiveShell(InteractiveShell):
    space_for_menu = Integer(
        6,
        help="Number of line at the bottom of the screen "
        "to reserve for the completion menu",
    ).tag(config=True)

    pt_app = None
    debugger_history = None

    simple_prompt = Bool(
        _use_simple_prompt,
        help="""Use `raw_input` for the REPL, without completion and prompt colors.

            Useful when controlling IPython as a subprocess, and piping
            STDIN/OUT/ERR. Known uses are: IPython's own testing machinery,
            and emacs' inferior-shell integration through elpy.

            This mode defaults to `True` if the `IPY_TEST_SIMPLE_PROMPT`
            environment variable is set, or the current terminal is not a tty.""",
    ).tag(config=True)

    @property
    def debugger_cls(self):
        return Pdb if self.simple_prompt else TerminalPdb

    confirm_exit = Bool(
        True,
        help="""
        Set to confirm when you try to exit IPython with an EOF (Control-D
        in Unix, Control-Z/Enter in Windows). By typing 'exit' or 'quit',
        you can force a direct exit without any confirmation.""",
    ).tag(config=True)

    editing_mode = Unicode(
        "emacs",
        help="Shortcut style to use at the prompt. 'vi' or 'emacs'.").tag(
            config=True)

    mouse_support = Bool(
        False,
        help=
        "Enable mouse support in the prompt\n(Note: prevents selecting text with the mouse)",
    ).tag(config=True)

    # We don't load the list of styles for the help string, because loading
    # Pygments plugins takes time and can cause unexpected errors.
    highlighting_style = Union(
        [Unicode("legacy"), Type(klass=Style)],
        help="""The name or class of a Pygments style to use for syntax
        highlighting. To see available styles, run `pygmentize -L styles`.""",
    ).tag(config=True)

    @validate("editing_mode")
    def _validate_editing_mode(self, proposal):
        if proposal["value"].lower() == "vim":
            proposal["value"] = "vi"
        elif proposal["value"].lower() == "default":
            proposal["value"] = "emacs"

        if hasattr(EditingMode, proposal["value"].upper()):
            return proposal["value"].lower()

        return self.editing_mode

    @observe("editing_mode")
    def _editing_mode(self, change):
        u_mode = change.new.upper()
        if self.pt_app:
            self.pt_app.editing_mode = u_mode

    @observe("highlighting_style")
    @observe("colors")
    def _highlighting_style_changed(self, change):
        self.refresh_style()

    def refresh_style(self):
        self._style = self._make_style_from_name_or_cls(
            self.highlighting_style)

    highlighting_style_overrides = Dict(
        help="Override highlighting format for specific tokens").tag(
            config=True)

    true_color = Bool(
        False,
        help=("Use 24bit colors instead of 256 colors in prompt highlighting. "
              "If your terminal supports true color, the following command "
              "should print 'TRUECOLOR' in orange: "
              'printf "\\x1b[38;2;255;100;0mTRUECOLOR\\x1b[0m\\n"'),
    ).tag(config=True)

    editor = Unicode(
        get_default_editor(),
        help="Set the editor used by IPython (default to $EDITOR/vi/notepad).",
    ).tag(config=True)

    prompts_class = Type(
        Prompts,
        help="Class used to generate Prompt token for prompt_toolkit").tag(
            config=True)

    prompts = Instance(Prompts)

    @default("prompts")
    def _prompts_default(self):
        return self.prompts_class(self)

    #    @observe('prompts')
    #    def _(self, change):
    #        self._update_layout()

    @default("displayhook_class")
    def _displayhook_class_default(self):
        return RichPromptDisplayHook

    term_title = Bool(
        True, help="Automatically set the terminal title").tag(config=True)

    term_title_format = Unicode(
        "IPython: {cwd}",
        help=
        "Customize the terminal title format.  This is a python format string. "
        + "Available substitutions are: {cwd}.",
    ).tag(config=True)

    display_completions = Enum(
        ("column", "multicolumn", "readlinelike"),
        help=
        ("Options for displaying tab completions, 'column', 'multicolumn', and "
         "'readlinelike'. These options are for `prompt_toolkit`, see "
         "`prompt_toolkit` documentation for more information."),
        default_value="multicolumn",
    ).tag(config=True)

    highlight_matching_brackets = Bool(
        True, help="Highlight matching brackets.").tag(config=True)

    extra_open_editor_shortcuts = Bool(
        False,
        help=
        "Enable vi (v) or Emacs (C-X C-E) shortcuts to open an external editor. "
        "This is in addition to the F2 binding, which is always enabled.",
    ).tag(config=True)

    handle_return = Any(
        None,
        help="Provide an alternative handler to be called when the user presses "
        "Return. This is an advanced option intended for debugging, which "
        "may be changed or removed in later releases.",
    ).tag(config=True)

    enable_history_search = Bool(
        True,
        help="Allows to enable/disable the prompt toolkit history search").tag(
            config=True)

    prompt_includes_vi_mode = Bool(
        True,
        help="Display the current vi mode (when using vi editing mode).").tag(
            config=True)

    @observe("term_title")
    def init_term_title(self, change=None):
        # Enable or disable the terminal title.
        if self.term_title:
            toggle_set_term_title(True)
            set_term_title(self.term_title_format.format(cwd=abbrev_cwd()))
        else:
            toggle_set_term_title(False)

    def init_display_formatter(self):
        super(TerminalInteractiveShell, self).init_display_formatter()
        # terminal only supports plain text
        self.display_formatter.active_types = ["text/plain"]
        # disable `_ipython_display_`
        self.display_formatter.ipython_display_formatter.enabled = False

    def init_prompt_toolkit_cli(self):
        if self.simple_prompt:
            # Fall back to plain non-interactive output for tests.
            # This is very limited.
            def prompt():
                prompt_text = "".join(x[1]
                                      for x in self.prompts.in_prompt_tokens())
                lines = [input(prompt_text)]
                prompt_continuation = "".join(
                    x[1] for x in self.prompts.continuation_prompt_tokens())
                while self.check_complete("\n".join(lines))[0] == "incomplete":
                    lines.append(input(prompt_continuation))
                return "\n".join(lines)

            self.prompt_for_code = prompt
            return

        # Set up keyboard shortcuts
        key_bindings = create_ipython_shortcuts(self)

        # Pre-populate history from IPython's history database
        history = InMemoryHistory()
        last_cell = u""
        for __, ___, cell in self.history_manager.get_tail(
                self.history_load_length, include_latest=True):
            # Ignore blank lines and consecutive duplicates
            cell = cell.rstrip()
            if cell and (cell != last_cell):
                history.append_string(cell)
                last_cell = cell

        self._style = self._make_style_from_name_or_cls(
            self.highlighting_style)
        self.style = DynamicStyle(lambda: self._style)

        editing_mode = getattr(EditingMode, self.editing_mode.upper())

        self.pt_app = PromptSession(
            editing_mode=editing_mode,
            key_bindings=key_bindings,
            history=history,
            completer=IPythonPTCompleter(shell=self),
            enable_history_search=self.enable_history_search,
            style=self.style,
            include_default_pygments_style=False,
            mouse_support=self.mouse_support,
            enable_open_in_editor=self.extra_open_editor_shortcuts,
            color_depth=self.color_depth,
            **self._extra_prompt_options())

    def _make_style_from_name_or_cls(self, name_or_cls):
        """
        Small wrapper that makes an IPython-compatible style from a style name.

        We need this to add styling for the prompt, etc.
        """
        style_overrides = {}
        if name_or_cls == "legacy":
            legacy = self.colors.lower()
            if legacy == "linux":
                style_cls = get_style_by_name("monokai")
                style_overrides = _style_overrides_linux
            elif legacy == "lightbg":
                style_overrides = _style_overrides_light_bg
                style_cls = get_style_by_name("pastie")
            elif legacy == "neutral":
                # The default theme needs to be visible on both a dark background
                # and a light background, because we can't tell what the terminal
                # looks like. These tweaks to the default theme help with that.
                style_cls = get_style_by_name("default")
                style_overrides.update({
                    Token.Number: "#007700",
                    Token.Operator: "noinherit",
                    Token.String: "#BB6622",
                    Token.Name.Function: "#2080D0",
                    Token.Name.Class: "bold #2080D0",
                    Token.Name.Namespace: "bold #2080D0",
                    Token.Prompt: "#009900",
                    Token.PromptNum: "#ansibrightgreen bold",
                    Token.OutPrompt: "#990000",
                    Token.OutPromptNum: "#ansibrightred bold",
                })

                # Hack: Due to limited color support on the Windows console
                # the prompt colors will be wrong without this
                if os.name == "nt":
                    style_overrides.update({
                        Token.Prompt: "#ansidarkgreen",
                        Token.PromptNum: "#ansigreen bold",
                        Token.OutPrompt: "#ansidarkred",
                        Token.OutPromptNum: "#ansired bold",
                    })
            elif legacy == "nocolor":
                style_cls = _NoStyle
                style_overrides = {}
            else:
                raise ValueError("Got unknown colors: ", legacy)
        else:
            if isinstance(name_or_cls, str):
                style_cls = get_style_by_name(name_or_cls)
            else:
                style_cls = name_or_cls
            style_overrides = {
                Token.Prompt: "#009900",
                Token.PromptNum: "#ansibrightgreen bold",
                Token.OutPrompt: "#990000",
                Token.OutPromptNum: "#ansibrightred bold",
            }
        style_overrides.update(self.highlighting_style_overrides)
        style = merge_styles([
            style_from_pygments_cls(style_cls),
            style_from_pygments_dict(style_overrides),
        ])

        return style
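    # Config sketch (e.g. in ipython_config.py); strings are resolved via
    # get_style_by_name(), while a Pygments Style subclass is used directly:
    #
    #     c.TerminalInteractiveShell.highlighting_style = "monokai"
    #     c.TerminalInteractiveShell.highlighting_style_overrides = {
    #         Token.Prompt: "#ff8800",
    #     }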

    @property
    def pt_complete_style(self):
        return {
            "multicolumn": CompleteStyle.MULTI_COLUMN,
            "column": CompleteStyle.COLUMN,
            "readlinelike": CompleteStyle.READLINE_LIKE,
        }[self.display_completions]

    @property
    def color_depth(self):
        return ColorDepth.TRUE_COLOR if self.true_color else None

    def _extra_prompt_options(self):
        """
        Return the current layout option for the current Terminal InteractiveShell
        """
        def get_message():
            return PygmentsTokens(self.prompts.in_prompt_tokens())

        return {
            "complete_in_thread": False,
            "lexer": IPythonPTLexer(),
            "reserve_space_for_menu": self.space_for_menu,
            "message": get_message,
            "prompt_continuation":
            (lambda width, lineno, is_soft_wrap: PygmentsTokens(
                self.prompts.continuation_prompt_tokens(width))),
            "multiline": True,
            "complete_style": self.pt_complete_style,
            # Highlight matching brackets, but only when this setting is
            # enabled, and only when the DEFAULT_BUFFER has the focus.
            "input_processors": [
                ConditionalProcessor(
                    processor=HighlightMatchingBracketProcessor(
                        chars="[](){}"),
                    filter=HasFocus(DEFAULT_BUFFER)
                    & ~IsDone()
                    & Condition(lambda: self.highlight_matching_brackets),
                )
            ],
            "inputhook": self.inputhook,
        }
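    # Note: these options are re-evaluated on every prompt (prompt_for_code
    # below passes them to pt_app.prompt() each time), so changes such as a
    # new space_for_menu value take effect on the next input line.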

    def prompt_for_code(self):
        if self.rl_next_input:
            default = self.rl_next_input
            self.rl_next_input = None
        else:
            default = ""

        with patch_stdout(raw=True):
            text = self.pt_app.prompt(
                default=default,
                #                pre_run=self.pre_prompt,# reset_current_buffer=True,
                **self._extra_prompt_options())
        return text

    def enable_win_unicode_console(self):
        if sys.version_info >= (3, 6):
            # Since PEP 528, Python uses the unicode APIs for the Windows
            # console by default, so WUC shouldn't be needed.
            return

        import win_unicode_console

        win_unicode_console.enable()

    def init_io(self):
        if sys.platform not in {"win32", "cli"}:
            return

        self.enable_win_unicode_console()

        import colorama

        colorama.init()

        # For some reason we make these wrappers around stdout/stderr.
        # For now, we need to reset them so all output gets coloured.
        # https://github.com/ipython/ipython/issues/8669
        # io.std* are deprecated, but don't show our own deprecation warnings
        # during initialization of the deprecated API.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            io.stdout = io.IOStream(sys.stdout)
            io.stderr = io.IOStream(sys.stderr)

    def init_magics(self):
        super(TerminalInteractiveShell, self).init_magics()
        self.register_magics(TerminalMagics)

    def init_alias(self):
        # The parent class defines aliases that can be safely used with any
        # frontend.
        super(TerminalInteractiveShell, self).init_alias()

        # Now define aliases that only make sense on the terminal, because they
        # need direct access to the console in a way that we can't emulate in
        # GUI or web frontend
        if os.name == "posix":
            for cmd in ("clear", "more", "less", "man"):
                self.alias_manager.soft_define_alias(cmd, cmd)

    def __init__(self, *args, **kwargs):
        super(TerminalInteractiveShell, self).__init__(*args, **kwargs)
        self.init_prompt_toolkit_cli()
        self.init_term_title()
        self.keep_running = True

        self.debugger_history = InMemoryHistory()

    def ask_exit(self):
        self.keep_running = False

    rl_next_input = None

    def interact(self, display_banner=DISPLAY_BANNER_DEPRECATED):

        if display_banner is not DISPLAY_BANNER_DEPRECATED:
            warn(
                "interact `display_banner` argument is deprecated since IPython 5.0. Call `show_banner()` if needed.",
                DeprecationWarning,
                stacklevel=2,
            )

        self.keep_running = True
        while self.keep_running:
            print(self.separate_in, end="")

            try:
                code = self.prompt_for_code()
            except EOFError:
                if (not self.confirm_exit) or self.ask_yes_no(
                        "Do you really want to exit ([y]/n)?", "y", "n"):
                    self.ask_exit()

            else:
                if code:
                    self.run_cell(code, store_history=True)

    def mainloop(self, display_banner=DISPLAY_BANNER_DEPRECATED):
        # An extra layer of protection in case someone mashing Ctrl-C breaks
        # out of our internal code.
        if display_banner is not DISPLAY_BANNER_DEPRECATED:
            warn(
                "mainloop `display_banner` argument is deprecated since IPython 5.0. Call `show_banner()` if needed.",
                DeprecationWarning,
                stacklevel=2,
            )
        while True:
            try:
                self.interact()
                break
            except KeyboardInterrupt as e:
                print("\n%s escaped interact()\n" % type(e).__name__)
            finally:
                # An interrupt during the eventloop will mess up the
                # internal state of the prompt_toolkit library.
                # Stopping the eventloop fixes this, see
                # https://github.com/ipython/ipython/pull/9867
                if hasattr(self, "_eventloop"):
                    self._eventloop.stop()

    _inputhook = None

    def inputhook(self, context):
        if self._inputhook is not None:
            self._inputhook(context)

    active_eventloop = None

    def enable_gui(self, gui=None):
        if gui:
            self.active_eventloop, self._inputhook = get_inputhook_name_and_func(
                gui)
        else:
            self.active_eventloop = self._inputhook = None

    # Run !system commands directly, not through pipes, so terminal programs
    # work correctly.
    system = InteractiveShell.system_raw

    def auto_rewrite_input(self, cmd):
        """Overridden from the parent class to use fancy rewriting prompt"""
        if not self.show_rewritten_input:
            return

        tokens = self.prompts.rewrite_prompt_tokens()
        if self.pt_app:
            print_formatted_text(PygmentsTokens(tokens),
                                 end="",
                                 style=self.pt_app.app.style)
            print(cmd)
        else:
            prompt = "".join(s for t, s in tokens)
            print(prompt, cmd, sep="")

    _prompts_before = None

    def switch_doctest_mode(self, mode):
        """Switch prompts to classic for %doctest_mode"""
        if mode:
            self._prompts_before = self.prompts
            self.prompts = ClassicPrompts(self)
        elif self._prompts_before:
            self.prompts = self._prompts_before
            self._prompts_before = None
Example #11
class IPController(BaseParallelApplication):

    name = 'ipcontroller'
    description = _description
    examples = _examples
    classes = [
        ProfileDir,
        Session,
        Hub,
        TaskScheduler,
        HeartMonitor,
        DictDB,
    ] + real_dbs
    _deprecated_classes = ("HubFactory", "IPControllerApp")

    # change default to True
    auto_create = Bool(
        True,
        config=True,
        help="""Whether to create profile dir if it doesn't exist.""")

    reuse_files = Bool(
        False,
        config=True,
        help="""Whether to reuse existing json connection files.
        If False, connection files will be removed on a clean exit.
        """,
    )
    restore_engines = Bool(
        False,
        config=True,
        help="""Reload engine state from JSON file
        """,
    )
    ssh_server = Unicode(
        '',
        config=True,
        help="""ssh url for clients to use when connecting to the Controller
        processes. It should be of the form: [user@]server[:port]. The
        Controller's listening addresses must be accessible from the ssh server""",
    )
    engine_ssh_server = Unicode(
        '',
        config=True,
        help="""ssh url for engines to use when connecting to the Controller
        processes. It should be of the form: [user@]server[:port]. The
        Controller's listening addresses must be accessible from the ssh server""",
    )
    location = Unicode(
        socket.gethostname(),
        config=True,
        help=
        """The external IP or domain name of the Controller, used for disambiguating
        engine and client connections.""",
    )

    use_threads = Bool(
        False,
        config=True,
        help='Use threads instead of processes for the schedulers')

    engine_json_file = Unicode(
        'ipcontroller-engine.json',
        config=True,
        help="JSON filename where engine connection info will be stored.",
    )
    client_json_file = Unicode(
        'ipcontroller-client.json',
        config=True,
        help="JSON filename where client connection info will be stored.",
    )

    @observe('cluster_id')
    def _cluster_id_changed(self, change):
        base = 'ipcontroller'
        if change.new:
            base = f"{base}-{change.new}"
        self.engine_json_file = f"{base}-engine.json"
        self.client_json_file = f"{base}-client.json"

    enable_curve = Bool(
        False,
        config=True,
        help="""Enable CurveZMQ encryption and authentication

        Caution: known to have issues on platforms with getrandom
        """,
    )

    @default("enable_curve")
    def _default_enable_curve(self):
        enabled = os.environ.get("IPP_ENABLE_CURVE", "") == "1"
        if enabled:
            self._ensure_curve_keypair()
            # disable redundant digest-key, CurveZMQ protects against replays
            if 'key' not in self.config.Session:
                self.config.Session.key = b''
        return enabled

    @observe("enable_curve")
    def _enable_curve_changed(self, change):
        if change.new:
            self._ensure_curve_keypair()
            # disable redundant digest-key, CurveZMQ protects against replays
            if 'key' not in self.config.Session:
                self.config.Session.key = b''

    def _ensure_curve_keypair(self):
        if not self.curve_secretkey or not self.curve_publickey:
            self.log.info("Generating new CURVE credentials")
            self.curve_publickey, self.curve_secretkey = zmq.curve_keypair()

    curve_secretkey = Bytes(
        config=True,
        help="The CurveZMQ secret key for the controller",
    )
    curve_publickey = Bytes(
        config=True,
        help="""The CurveZMQ public key for the controller.

        Engines and clients use this for the server key.
        """,
    )

    @default("curve_secretkey")
    def _default_curve_secretkey(self):
        return os.environ.get("IPP_CURVE_SECRETKEY", "").encode("ascii")

    @default("curve_publickey")
    def _default_curve_publickey(self):
        return os.environ.get("IPP_CURVE_PUBLICKEY", "").encode("ascii")

    @validate("curve_publickey", "curve_secretkey")
    def _cast_bytes(self, proposal):
        if isinstance(proposal.value, str):
            return proposal.value.encode("ascii")
        return proposal.value
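    # Example workflow (assumed): running `IPP_ENABLE_CURVE=1 ipcontroller`
    # enables CURVE, generating a keypair unless IPP_CURVE_SECRETKEY /
    # IPP_CURVE_PUBLICKEY are set, and blanks Session.key because CurveZMQ
    # already authenticates messages and protects against replays.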

    # internal
    children = List()
    mq_class = Unicode('zmq.devices.ProcessMonitoredQueue')

    @observe('use_threads')
    def _use_threads_changed(self, change):
        self.mq_class = 'zmq.devices.{}MonitoredQueue'.format(
            'Thread' if change['new'] else 'Process')

    write_connection_files = Bool(
        True,
        help="""Whether to write connection files to disk.
        True in all cases other than runs with `reuse_files=True` *after the first*
        """,
    )

    aliases = Dict(aliases)
    flags = Dict(flags)

    broadcast_scheduler_depth = Integer(
        1,
        config=True,
        help="Depth of spanning tree schedulers",
    )
    number_of_leaf_schedulers = Integer()
    number_of_broadcast_schedulers = Integer()
    number_of_non_leaf_schedulers = Integer()

    @default('number_of_leaf_schedulers')
    def get_number_of_leaf_schedulers(self):
        return 2**self.broadcast_scheduler_depth

    @default('number_of_broadcast_schedulers')
    def get_number_of_broadcast_schedulers(self):
        return 2 * self.number_of_leaf_schedulers - 1

    @default('number_of_non_leaf_schedulers')
    def get_number_of_non_leaf_schedulers(self):
        return self.number_of_broadcast_schedulers - self.number_of_leaf_schedulers
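    # Worked example: with broadcast_scheduler_depth=2 there are 2**2 = 4
    # leaf schedulers, 2*4 - 1 = 7 broadcast schedulers in total, and
    # 7 - 4 = 3 non-leaf (internal) schedulers.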

    ports = PortList(
        Integer(min=1, max=65535),
        config=True,
        help="""
        Pool of ports to use for the controller.

        For example:

            ipcontroller --ports 10101-10120

        This list will be consumed to populate the ports
        to be used when binding controller sockets
        for engines, clients, or internal connections.

        The number of sockets needed depends on scheduler depth,
        but is at least 14 (16 by default)

        If more ports than are defined here are needed,
        random ports will be selected.

        Can be specified as a list or string expressing a range

        See also engine_ports and client_ports
        """,
    )

    engine_ports = PortList(
        config=True,
        help="""
        Pool of ports to use for engine connections

        This list will be consumed to populate the ports
        to be used when binding controller sockets used for engine connections

        Can be specified as a list or string expressing a range

        If this list is exhausted, the common `ports` pool will be consumed.

        See also ports and client_ports
        """,
    )

    client_ports = PortList(
        config=True,
        help="""
        Pool of ports to use for client connections

        This list will be consumed to populate the ports
        to be used when binding controller sockets used for client connections

        Can be specified as a list or string expressing a range

        If this list is empty or exhausted,
        the common `ports` pool will be consumed.

        See also ports and engine_ports
        """,
    )

    # observe consumption of pools
    port_index = engine_port_index = client_port_index = 0

    # port-pairs for schedulers:
    hb = Tuple(
        Integer(),
        Integer(),
        config=True,
        help="""DEPRECATED: use ports""",
    )

    mux = Tuple(
        Integer(),
        Integer(),
        config=True,
        help="""DEPRECATED: use ports""",
    )

    task = Tuple(
        Integer(),
        Integer(),
        config=True,
        help="""DEPRECATED: use ports""",
    )

    control = Tuple(
        Integer(),
        Integer(),
        config=True,
        help="""DEPRECATED: use ports""",
    )

    iopub = Tuple(
        Integer(),
        Integer(),
        config=True,
        help="""DEPRECATED: use ports""",
    )

    @observe("iopub", "control", "task", "mux")
    def _scheduler_ports_assigned(self, change):
        self.log.warning(
            f"Setting {self.__class__.__name__}.{change.name} = {change.new!r} is deprecated in IPython Parallel 7."
            " Use IPController.ports config instead.")
        self.client_ports.append(change.new[0])
        self.engine_ports.append(change.new[1])

    @observe("hb")
    def _hb_ports_assigned(self, change):
        self.log.warning(
            f"Setting {self.__class__.__name__}.{change.name} = {change.new!r} is deprecated in IPython Parallel 7."
            " Use IPController.engine_ports config instead.")
        self.engine_ports.extend(change.new)

    # single ports:
    mon_port = Integer(config=True, help="""DEPRECATED: use ports""")

    notifier_port = Integer(
        config=True, help="""DEPRECATED: use ports""").tag(port_pool="client")

    regport = Integer(config=True,
                      help="DEPRECATED: use ports").tag(port_pool="engine")

    @observe("regport", "notifier_port", "mon_port")
    def _port_assigned(self, change):
        self.log.warning(
            f"Setting {self.__class__.__name__}.{change.name} = {change.new!r} is deprecated in IPython Parallel 7."
            " Use IPController.ports config instead.")
        trait = self.traits()[change.name]
        pool_name = trait.metadata.get("port_pool")
        if pool_name == 'engine':
            self.engine_ports.append(change.new)
        elif pool_name == 'client':
            self.client_ports.append(change.new)
        else:
            self.ports.append(change.new)

    engine_ip = Unicode(
        config=True,
        help=
        "IP on which to listen for engine connections. [default: loopback]",
    )

    def _engine_ip_default(self):
        return localhost()

    engine_transport = Unicode(
        'tcp',
        config=True,
        help="0MQ transport for engine connections. [default: tcp]")

    client_ip = Unicode(
        config=True,
        help=
        "IP on which to listen for client connections. [default: loopback]",
    )
    client_transport = Unicode(
        'tcp',
        config=True,
        help="0MQ transport for client connections. [default : tcp]")

    monitor_ip = Unicode(
        config=True,
        help="IP on which to listen for monitor messages. [default: loopback]",
    )
    monitor_transport = Unicode(
        'tcp',
        config=True,
        help="0MQ transport for monitor messages. [default : tcp]")

    _client_ip_default = _monitor_ip_default = _engine_ip_default

    monitor_url = Unicode('')

    db_class = Union(
        [Unicode(), Type()],
        default_value=DictDB,
        config=True,
        help="""The class to use for the DB backend

        Options include:

        SQLiteDB: SQLite
        MongoDB : use MongoDB
        DictDB  : in-memory storage (fastest, but be mindful of memory growth of the Hub)
        NoDB    : disable database altogether (default)

        """,
    )

    @validate("db_class")
    def _validate_db_class(self, proposal):
        value = proposal.value
        if isinstance(value, str):
            # if it's a string, import it
            value = _db_shortcuts.get(value.lower(), value)
            return import_item(value)
        return value

    registration_timeout = Integer(
        0,
        config=True,
        help="Engine registration timeout in seconds [default: max(30,"
        "10*heartmonitor.period)]",
    )

    @default("registration_timeout")
    def _registration_timeout_default(self):
        # heartmonitor period is in milliseconds;
        # 10 periods, converted to seconds, is 0.01 * period
        return max(30, int(0.01 * HeartMonitor(parent=self).period))

    # not configurable
    db = Instance('ipyparallel.controller.dictdb.BaseDB', allow_none=True)
    heartmonitor = Instance('ipyparallel.controller.heartmonitor.HeartMonitor',
                            allow_none=True)

    ip = Unicode("127.0.0.1",
                 config=True,
                 help="""Set the controller ip for all connections.""")
    transport = Unicode("tcp",
                        config=True,
                        help="""Set the zmq transport for all connections.""")

    @observe('ip')
    def _ip_changed(self, change):
        new = change['new']
        self.engine_ip = new
        self.client_ip = new
        self.monitor_ip = new

    @observe('transport')
    def _transport_changed(self, change):
        new = change['new']
        self.engine_transport = new
        self.client_transport = new
        self.monitor_transport = new

    context = Instance(zmq.Context)

    @default("context")
    def _default_context(self):
        return zmq.Context.instance()

    # connection file contents
    engine_info = Dict()
    client_info = Dict()

    _logged_exhaustion = Dict()

    _random_port_count = Integer(0)

    def next_port(self, pool_name='common'):
        """Consume a port from our port pools"""
        if pool_name == 'client':
            if len(self.client_ports) > self.client_port_index:
                port = self.client_ports[self.client_port_index]
                self.client_port_index += 1
                return port
            elif self.client_ports and not self._logged_exhaustion.get(
                    "client"):
                self._logged_exhaustion['client'] = True
                # only log once
                self.log.warning(
                    f"Exhausted {len(self.client_ports)} client ports")
        elif pool_name == 'engine':
            if len(self.engine_ports) > self.engine_port_index:
                port = self.engine_ports[self.engine_port_index]
                self.engine_port_index += 1
                return port
            elif self.engine_ports and not self._logged_exhaustion.get(
                    "engine"):
                self._logged_exhaustion['engine'] = True
                self.log.warning(
                    f"Exhausted {len(self.engine_ports)} engine ports")

        # drawing from common pool
        if len(self.ports) > self.port_index:
            port = self.ports[self.port_index]
            self.port_index += 1
            return port
        elif self.ports and not self._logged_exhaustion.get("common"):
            self._logged_exhaustion['common'] = True
            self.log.warning(f"Exhausted {len(self.ports)} common ports")

        self._random_port_count += 1
        port = util.select_random_ports(1)[0]
        return port
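    # Consumption sketch: next_port('engine') prefers engine_ports, then the
    # common `ports` pool, and finally falls back to a random port from
    # util.select_random_ports(); exhaustion of each pool is logged once.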

    def construct_url(self, kind: str, channel: str, index=None):
        if kind == 'engine':
            info = self.engine_info
        elif kind == 'client':
            info = self.client_info
        elif kind == 'internal':
            info = self.internal_info
        else:
            raise ValueError(
                f"kind must be 'engine', 'client', or 'internal', not {kind!r}")

        interface = info['interface']
        sep = '-' if interface.partition("://")[0] == 'ipc' else ':'
        port = info[channel]
        if index is not None:
            port = port[index]
        return f"{interface}{sep}{port}"

    def internal_url(self, channel, index=None):
        """return full zmq url for a named internal channel"""
        return self.construct_url('internal', channel, index)

    def bind(self, socket, url):
        """Bind a socket"""
        return util.bind(
            socket,
            url,
            curve_secretkey=self.curve_secretkey
            if self.enable_curve else None,
            curve_publickey=self.curve_publickey
            if self.enable_curve else None,
        )

    def connect(self, socket, url):
        return util.connect(
            socket,
            url,
            curve_serverkey=self.curve_publickey
            if self.enable_curve else None,
            curve_publickey=self.curve_publickey
            if self.enable_curve else None,
            curve_secretkey=self.curve_secretkey
            if self.enable_curve else None,
        )

    def client_url(self, channel, index=None):
        """return full zmq url for a named client channel"""
        return self.construct_url('client', channel, index)

    def engine_url(self, channel, index=None):
        """return full zmq url for a named engine channel"""
        return self.construct_url('engine', channel, index)

    def save_connection_dict(self, fname, cdict):
        """save a connection dict to json file."""
        fname = os.path.join(self.profile_dir.security_dir, fname)
        self.log.info("writing connection info to %s", fname)
        with open(fname, 'w') as f:
            f.write(json.dumps(cdict, indent=2))
        os.chmod(fname, stat.S_IRUSR | stat.S_IWUSR)
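    # stat.S_IRUSR | stat.S_IWUSR is mode 0o600: connection files contain the
    # session key, so they are restricted to owner read/write.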

    def load_config_from_json(self):
        """load config from existing json connector files."""
        c = self.config
        self.log.debug("loading config from JSON")

        # load engine config

        fname = os.path.join(self.profile_dir.security_dir,
                             self.engine_json_file)
        self.log.info("loading connection info from %s", fname)
        with open(fname) as f:
            ecfg = json.loads(f.read())

        # json gives unicode, Session.key wants bytes
        c.Session.key = ecfg['key'].encode('ascii')

        xport, ip = ecfg['interface'].split('://')

        c.IPController.engine_ip = ip
        c.IPController.engine_transport = xport

        self.location = ecfg['location']
        if not self.engine_ssh_server:
            self.engine_ssh_server = ecfg['ssh']

        # load client config

        fname = os.path.join(self.profile_dir.security_dir,
                             self.client_json_file)
        self.log.info("loading connection info from %s", fname)
        with open(fname) as f:
            ccfg = json.loads(f.read())

        for key in ('key', 'registration', 'pack', 'unpack',
                    'signature_scheme'):
            assert ccfg[key] == ecfg[key], (
                "mismatch between engine and client info: %r" % key)

        xport, ip = ccfg['interface'].split('://')

        c.IPController.client_transport = xport
        c.IPController.client_ip = ip
        if not self.ssh_server:
            self.ssh_server = ccfg['ssh']

        self.engine_info = ecfg
        self.client_info = ccfg

    def cleanup_connection_files(self):
        if self.reuse_files:
            self.log.debug("leaving JSON connection files for reuse")
            return
        self.log.debug("cleaning up JSON connection files")
        for f in (self.client_json_file, self.engine_json_file):
            f = os.path.join(self.profile_dir.security_dir, f)
            try:
                os.remove(f)
            except Exception as e:
                self.log.error("Failed to cleanup connection file: %s", e)
            else:
                self.log.debug("removed %s", f)

    def load_secondary_config(self):
        """secondary config, loading from JSON and setting defaults"""
        if self.reuse_files:
            try:
                self.load_config_from_json()
            except (AssertionError, OSError) as e:
                self.log.error("Could not load config from JSON: %s" % e)
            else:
                # successfully loaded config from JSON, and reuse=True
                # no need to write back the same file
                self.write_connection_files = False

    def init_hub(self):
        if self.enable_curve:
            self.log.info(
                "Using CURVE security. Ignore warnings about disabled message signing."
            )

        c = self.config

        ctx = self.context
        loop = self.loop
        if 'TaskScheduler.scheme_name' in self.config:
            scheme = self.config.TaskScheduler.scheme_name
        else:
            from .task_scheduler import TaskScheduler

            scheme = TaskScheduler.scheme_name.default_value

        if self.engine_info:
            registration_port = self.engine_info['registration']
        else:
            registration_port = self.next_port('engine')

        # build connection dicts
        if not self.engine_info:
            self.engine_info = {
                'interface': f"{self.engine_transport}://{self.engine_ip}",
                'registration': registration_port,
                'control': self.next_port('engine'),
                'mux': self.next_port('engine'),
                'task': self.next_port('engine'),
                'iopub': self.next_port('engine'),
                'hb_ping': self.next_port('engine'),
                'hb_pong': self.next_port('engine'),
                BroadcastScheduler.port_name: [
                    self.next_port('engine')
                    for i in range(self.number_of_leaf_schedulers)
                ],
            }

        if not self.client_info:
            self.client_info = {
                'interface': f"{self.client_transport}://{self.client_ip}",
                'registration': registration_port,
                'control': self.next_port('client'),
                'mux': self.next_port('client'),
                'task': self.next_port('client'),
                'task_scheme': scheme,
                'iopub': self.next_port('client'),
                'notification': self.next_port('client'),
                BroadcastScheduler.port_name: self.next_port('client'),
            }
        if self.engine_transport == 'tcp':
            internal_interface = "tcp://127.0.0.1"
        else:
            internal_interface = self.engine_info['interface']

        broadcast_ids = []  # '0', '00', '01', '001', etc.
        # always a leading 0 for the root node
        for d in range(1, self.broadcast_scheduler_depth + 1):
            for i in range(2**d):
                broadcast_ids.append(format(i, f"0{d + 1}b"))
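        # e.g. broadcast_scheduler_depth == 1 yields ['00', '01'];
        # depth 2 adds ['000', '001', '010', '011']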
        self.internal_info = {
            'interface': internal_interface,
            BroadcastScheduler.port_name: {
                broadcast_id: self.next_port()
                for broadcast_id in broadcast_ids
            },
        }
        mon_port = self.next_port()
        self.monitor_url = f"{self.monitor_transport}://{self.monitor_ip}:{mon_port}"

        # debug port pool consumption
        if self.engine_ports:
            self.log.debug(
                f"Used {self.engine_port_index} / {len(self.engine_ports)} engine ports"
            )
        if self.client_ports:
            self.log.debug(
                f"Used {self.client_port_index} / {len(self.client_ports)} client ports"
            )
        if self.ports:
            self.log.debug(
                f"Used {self.port_index} / {len(self.ports)} common ports")
        if self._random_port_count:
            self.log.debug(f"Used {self._random_port_count} random ports")

        self.log.debug("Hub engine addrs: %s", self.engine_info)
        self.log.debug("Hub client addrs: %s", self.client_info)
        self.log.debug("Hub internal addrs: %s", self.internal_info)

        # Registrar socket
        query = ZMQStream(ctx.socket(zmq.ROUTER), loop)
        util.set_hwm(query, 0)
        self.bind(query, self.client_url('registration'))
        self.log.info("Hub listening on %s for registration.",
                      self.client_url('registration'))
        if self.client_ip != self.engine_ip:
            self.bind(query, self.engine_url('registration'))
            self.log.info("Hub listening on %s for registration.",
                          self.engine_url('registration'))

        ### Engine connections ###

        # heartbeat
        hm_config = Config()
        for key in ("Session", "HeartMonitor"):
            if key in self.config:
                hm_config[key] = self.config[key]
        hm_config.Session.key = self.session.key

        self.heartmonitor_process = Process(
            target=start_heartmonitor,
            kwargs=dict(
                ping_url=self.engine_url('hb_ping'),
                pong_url=self.engine_url('hb_pong'),
                monitor_url=disambiguate_url(self.monitor_url),
                config=hm_config,
                log_level=self.log.getEffectiveLevel(),
                curve_publickey=self.curve_publickey,
                curve_secretkey=self.curve_secretkey,
            ),
            daemon=True,
        )

        ### Client connections ###

        # Notifier socket
        notifier = ZMQStream(ctx.socket(zmq.PUB), loop)
        notifier.socket.SNDHWM = 0
        self.bind(notifier, self.client_url('notification'))

        ### build and launch the queues ###

        # monitor socket
        sub = ctx.socket(zmq.SUB)
        sub.RCVHWM = 0
        sub.setsockopt(zmq.SUBSCRIBE, b"")
        self.bind(sub, self.monitor_url)
        # self.bind(sub, 'inproc://monitor')
        sub = ZMQStream(sub, loop)

        # connect the db
        self.log.info(f'Hub using DB backend: {self.db_class.__name__}')
        self.db = self.db_class(session=self.session.session,
                                parent=self,
                                log=self.log)
        time.sleep(0.25)

        # resubmit stream
        resubmit = ZMQStream(ctx.socket(zmq.DEALER), loop)
        url = util.disambiguate_url(self.client_url('task'))
        self.connect(resubmit, url)

        self.hub = Hub(
            loop=loop,
            session=self.session,
            monitor=sub,
            query=query,
            notifier=notifier,
            resubmit=resubmit,
            db=self.db,
            heartmonitor_period=HeartMonitor(parent=self).period,
            engine_info=self.engine_info,
            client_info=self.client_info,
            log=self.log,
            registration_timeout=self.registration_timeout,
            parent=self,
        )

        if self.write_connection_files:
            # save to new json config files
            base = {
                'key': self.session.key.decode('ascii'),
                'curve_serverkey': self.curve_publickey.decode("ascii")
                if self.enable_curve else None,
                'location': self.location,
                'pack': self.session.packer,
                'unpack': self.session.unpacker,
                'signature_scheme': self.session.signature_scheme,
            }

            cdict = {'ssh': self.ssh_server}
            cdict.update(self.client_info)
            cdict.update(base)
            self.save_connection_dict(self.client_json_file, cdict)

            edict = {'ssh': self.engine_ssh_server}
            edict.update(self.engine_info)
            edict.update(base)
            self.save_connection_dict(self.engine_json_file, edict)

        fname = "engines%s.json" % self.cluster_id
        self.hub.engine_state_file = os.path.join(self.profile_dir.log_dir,
                                                  fname)
        if self.restore_engines:
            self.hub._load_engine_state()

    def launch_python_scheduler(self, name, scheduler_args, children):
        if 'Process' in self.mq_class:
            # run the Python scheduler in a Process
            q = Process(
                target=launch_scheduler,
                kwargs=scheduler_args,
                name=name,
                daemon=True,
            )
            children.append(q)
        else:
            # single-threaded Controller
            scheduler_args['in_thread'] = True
            launch_scheduler(**scheduler_args)

    def get_python_scheduler_args(
        self,
        scheduler_name,
        scheduler_class,
        monitor_url,
        identity=None,
        in_addr=None,
        out_addr=None,
    ):
        if identity is not None:
            logname = f"{scheduler_name}-{identity}"
        else:
            logname = scheduler_name
        return {
            'scheduler_class': scheduler_class,
            'in_addr': in_addr or self.client_url(scheduler_name),
            'out_addr': out_addr or self.engine_url(scheduler_name),
            'mon_addr': monitor_url,
            'not_addr': disambiguate_url(self.client_url('notification')),
            'reg_addr': disambiguate_url(self.client_url('registration')),
            'identity': identity
            if identity is not None else bytes(scheduler_name, 'utf8'),
            'logname': logname,
            'loglevel': self.log_level,
            'log_url': self.log_url,
            'config': dict(self.config),
            'curve_secretkey': self.curve_secretkey
            if self.enable_curve else None,
            'curve_publickey': self.curve_publickey
            if self.enable_curve else None,
        }

    def launch_broadcast_schedulers(self, monitor_url, children):
        def launch_in_thread_or_process(scheduler_args, depth, identity):

            if 'Process' in self.mq_class:
                # run the Python scheduler in a Process
                q = Process(
                    target=launch_broadcast_scheduler,
                    kwargs=scheduler_args,
                    name=f"BroadcastScheduler(depth={depth}, id={identity})",
                    daemon=True,
                )
                children.append(q)
            else:
                # single-threaded Controller
                scheduler_args['in_thread'] = True
                launch_broadcast_scheduler(**scheduler_args)

        def recursively_start_schedulers(identity, depth):

            outgoing_id1 = identity + '0'
            outgoing_id2 = identity + '1'
            is_leaf = depth == self.broadcast_scheduler_depth
            is_root = depth == 0

            # FIXME: use localhost, not client ip for internal communication
            # this will still be localhost anyway for the most common cases
            # of localhost or */0.0.0.0
            if is_root:
                in_addr = self.client_url(BroadcastScheduler.port_name)
            else:
                # not root, use internal address

                in_addr = self.internal_url(
                    BroadcastScheduler.port_name,
                    index=identity,
                )

            scheduler_args = self.get_python_scheduler_args(
                BroadcastScheduler.port_name,
                BroadcastScheduler,
                monitor_url,
                identity,
                in_addr=in_addr,
                out_addr='ignored',
            )
            scheduler_args.pop('out_addr')
            # add broadcast args
            scheduler_args.update(
                outgoing_ids=[outgoing_id1, outgoing_id2],
                depth=depth,
                max_depth=self.broadcast_scheduler_depth,
                is_leaf=is_leaf,
            )

            if is_leaf:
                scheduler_args.update(out_addrs=[
                    self.engine_url(
                        BroadcastScheduler.port_name,
                        index=int(identity, 2),
                    )
                ], )
            else:
                scheduler_args.update(out_addrs=[
                    self.internal_url(BroadcastScheduler.port_name,
                                      index=outgoing_id1),
                    self.internal_url(BroadcastScheduler.port_name,
                                      index=outgoing_id2),
                ])
            launch_in_thread_or_process(scheduler_args,
                                        depth=depth,
                                        identity=identity)
            if not is_leaf:
                recursively_start_schedulers(outgoing_id1, depth + 1)
                recursively_start_schedulers(outgoing_id2, depth + 1)

        recursively_start_schedulers(identity='0', depth=0)
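        # e.g. with broadcast_scheduler_depth == 1 this starts the root
        # scheduler '0' plus leaf schedulers '00' and '01', each leaf
        # feeding one engine-facing broadcast port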

    def init_schedulers(self):
        children = self.children
        mq = import_item(str(self.mq_class))
        # ensure session key is shared across sessions
        self.config.Session.key = self.session.key
        ident = self.session.bsession

        def add_auth(q):
            """Add CURVE auth to a monitored queue"""
            if not self.enable_curve:
                return False
            q.setsockopt_in(zmq.CURVE_SERVER, 1)
            q.setsockopt_in(zmq.CURVE_SECRETKEY, self.curve_secretkey)
            q.setsockopt_out(zmq.CURVE_SERVER, 1)
            q.setsockopt_out(zmq.CURVE_SECRETKEY, self.curve_secretkey)
            # monitor is a client
            pub, secret = zmq.curve_keypair()
            q.setsockopt_mon(zmq.CURVE_SERVERKEY, self.curve_publickey)
            q.setsockopt_mon(zmq.CURVE_SECRETKEY, secret)
            q.setsockopt_mon(zmq.CURVE_PUBLICKEY, pub)

        # disambiguate url, in case of *
        monitor_url = disambiguate_url(self.monitor_url)
        # maybe_inproc = 'inproc://monitor' if self.use_threads else monitor_url
        # IOPub relay (in a Process)
        q = mq(zmq.SUB, zmq.PUB, zmq.PUB, b'iopub', b'N/A')
        add_auth(q)
        q.name = "IOPubScheduler"

        q.bind_in(self.engine_url('iopub'))
        q.setsockopt_in(zmq.SUBSCRIBE, b'')
        q.bind_out(self.client_url('iopub'))
        q.setsockopt_out(zmq.IDENTITY, ident + b"_iopub")
        q.connect_mon(monitor_url)
        q.daemon = True
        children.append(q)

        # Multiplexer Queue (in a Process)
        q = mq(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'in', b'out')
        add_auth(q)
        q.name = "DirectScheduler"

        q.bind_in(self.client_url('mux'))
        q.setsockopt_in(zmq.IDENTITY, b'mux_in')
        q.bind_out(self.engine_url('mux'))
        q.setsockopt_out(zmq.IDENTITY, b'mux_out')
        q.connect_mon(monitor_url)
        q.daemon = True
        children.append(q)

        # Control Queue (in a Process)
        q = mq(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'incontrol', b'outcontrol')
        add_auth(q)
        q.name = "ControlScheduler"
        q.bind_in(self.client_url('control'))
        q.setsockopt_in(zmq.IDENTITY, b'control_in')
        q.bind_out(self.engine_url('control'))
        q.setsockopt_out(zmq.IDENTITY, b'control_out')
        q.connect_mon(monitor_url)
        q.daemon = True
        children.append(q)
        if 'TaskScheduler.scheme_name' in self.config:
            scheme = self.config.TaskScheduler.scheme_name
        else:
            scheme = TaskScheduler.scheme_name.default_value
        # Task Queue (in a Process)
        if scheme == 'pure':
            self.log.warning("task::using pure DEALER Task scheduler")
            q = mq(zmq.ROUTER, zmq.DEALER, zmq.PUB, b'intask', b'outtask')
            add_auth(q)
            q.name = "TaskScheduler(pure)"
            # q.setsockopt_out(zmq.HWM, hub.hwm)
            q.bind_in(self.client_url('task'))
            q.setsockopt_in(zmq.IDENTITY, b'task_in')
            q.bind_out(self.engine_url('task'))
            q.setsockopt_out(zmq.IDENTITY, b'task_out')
            q.connect_mon(monitor_url)
            q.daemon = True
            children.append(q)
        elif scheme == 'none':
            self.log.warning("task::using no Task scheduler")

        else:
            self.log.info("task::using Python %s Task scheduler" % scheme)
            self.launch_python_scheduler(
                'TaskScheduler',
                self.get_python_scheduler_args('task', TaskScheduler,
                                               monitor_url),
                children,
            )

        self.launch_broadcast_schedulers(monitor_url, children)

        # set unlimited HWM for all relay devices
        if hasattr(zmq, 'SNDHWM'):
            q = children[0]
            q.setsockopt_in(zmq.RCVHWM, 0)
            q.setsockopt_out(zmq.SNDHWM, 0)

            for q in children[1:]:
                if not hasattr(q, 'setsockopt_in'):
                    continue
                q.setsockopt_in(zmq.SNDHWM, 0)
                q.setsockopt_in(zmq.RCVHWM, 0)
                q.setsockopt_out(zmq.SNDHWM, 0)
                q.setsockopt_out(zmq.RCVHWM, 0)
                q.setsockopt_mon(zmq.SNDHWM, 0)

    def terminate_children(self):
        child_procs = []
        for child in self.children + [self.heartmonitor_process]:
            if isinstance(child, ProcessMonitoredQueue):
                child_procs.append(child.launcher)
            elif isinstance(child, Process):
                child_procs.append(child)
        if child_procs:
            self.log.critical("terminating children...")
            for child in child_procs:
                try:
                    child.terminate()
                except OSError:
                    # already dead
                    pass

    def handle_signal(self, sig, frame):
        self.log.critical("Received signal %i, shutting down", sig)
        self.terminate_children()
        self.loop.add_callback_from_signal(self.loop.stop)

    def init_signal(self):
        for sig in (SIGINT, SIGABRT, SIGTERM):
            signal(sig, self.handle_signal)

    def forward_logging(self):
        if self.log_url:
            self.log.info("Forwarding logging to %s" % self.log_url)
            context = zmq.Context.instance()
            lsock = context.socket(zmq.PUB)
            lsock.connect(self.log_url)
            handler = PUBHandler(lsock)
            handler.root_topic = 'controller'
            handler.setLevel(self.log_level)
            self.log.addHandler(handler)

    @catch_config_error
    def initialize(self, argv=None):
        super().initialize(argv)
        self.forward_logging()
        self.load_secondary_config()
        self.init_hub()
        self.init_schedulers()

    def start(self):
        # Start the subprocesses:
        # children must be started before signals are setup,
        # otherwise signal-handling will fire multiple times
        for child in self.children:
            child.start()
            if hasattr(child, 'launcher'):
                # apply name to actual process/thread for logging
                setattr(child.launcher, 'name', child.name)
            if not self.use_threads:
                process = getattr(child, 'launcher', child)
                self.log.debug(f"Started process {child.name}: {process.pid}")
            else:
                self.log.debug(f"Started thread {child.name}")

        self.heartmonitor_process.start()
        self.log.info(
            f"Heartmonitor beating every {self.hub.heartmonitor_period}ms")

        self.init_signal()

        try:
            self.loop.start()
        except KeyboardInterrupt:
            self.log.critical("Interrupted, Exiting...\n")
        finally:
            self.loop.close(all_fds=True)
            self.cleanup_connection_files()
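
A minimal, self-contained sketch of the relay pattern init_schedulers builds
above, using pyzmq's ProcessMonitoredQueue: messages are forwarded between an
"in" and an "out" socket while a copy of all traffic is published to a monitor
socket. The addresses and identities below are illustrative, not taken from
the source.

import zmq
from zmq.devices import ProcessMonitoredQueue

# ROUTER <-> ROUTER relay, like the DirectScheduler ("mux") queue above
q = ProcessMonitoredQueue(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'in', b'out')
q.bind_in('tcp://127.0.0.1:5555')   # client-facing side (hypothetical port)
q.bind_out('tcp://127.0.0.1:5556')  # engine-facing side (hypothetical port)
q.bind_mon('tcp://127.0.0.1:5557')  # monitor PUB socket
q.setsockopt_in(zmq.IDENTITY, b'mux_in')
q.setsockopt_out(zmq.IDENTITY, b'mux_out')
q.daemon = True
q.start()  # runs the relay in a child process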
Example #12
class KernelSpecManager(LoggingConfigurable):

    kernel_spec_class = Type(
        KernelSpec,
        config=True,
        help="""The kernel spec class.  This is configurable to allow
        subclassing of the KernelSpecManager for customized behavior.
        """,
    )

    ensure_native_kernel = Bool(
        True,
        config=True,
        help="""If there is no Python kernelspec registered and the IPython
        kernel is available, ensure it is added to the spec list.
        """,
    )

    data_dir = Unicode()

    def _data_dir_default(self):
        return jupyter_data_dir()

    user_kernel_dir = Unicode()

    def _user_kernel_dir_default(self):
        return pjoin(self.data_dir, "kernels")

    whitelist = Set(
        config=True,
        help="""Deprecated, use `KernelSpecManager.allowed_kernelspecs`
        """,
    )
    allowed_kernelspecs = Set(
        config=True,
        help="""List of allowed kernel names.

        By default, all installed kernels are allowed.
        """,
    )
    kernel_dirs = List(
        help=
        "List of kernel directories to search. Later ones take priority over earlier."
    )

    _deprecated_aliases = {
        "whitelist": ("allowed_kernelspecs", "7.0"),
    }

    # Method copied from
    # https://github.com/jupyterhub/jupyterhub/blob/d1a85e53dccfc7b1dd81b0c1985d158cc6b61820/jupyterhub/auth.py#L143-L161
    @observe(*list(_deprecated_aliases))
    def _deprecated_trait(self, change):
        """observer for deprecated traits"""
        old_attr = change.name
        new_attr, version = self._deprecated_aliases.get(old_attr)
        new_value = getattr(self, new_attr)
        if new_value != change.new:
            # only warn if different
            # protects backward-compatible config from warnings
            # if they set the same value under both names
            self.log.warning(("{cls}.{old} is deprecated in jupyter_client "
                              "{version}, use {cls}.{new} instead").format(
                                  cls=self.__class__.__name__,
                                  old=old_attr,
                                  new=new_attr,
                                  version=version,
                              ))
            setattr(self, new_attr, change.new)

    def _kernel_dirs_default(self):
        dirs = jupyter_path("kernels")
        # At some point, we should stop adding .ipython/kernels to the path,
        # but the cost to keeping it is very small.
        try:
            from IPython.paths import get_ipython_dir  # type: ignore
        except ImportError:
            try:
                from IPython.utils.path import get_ipython_dir  # type: ignore
            except ImportError:
                # no IPython, no ipython dir
                get_ipython_dir = None
        if get_ipython_dir is not None:
            dirs.append(os.path.join(get_ipython_dir(), "kernels"))
        return dirs

    def find_kernel_specs(self):
        """Returns a dict mapping kernel names to resource directories."""
        d = {}
        for kernel_dir in self.kernel_dirs:
            kernels = _list_kernels_in(kernel_dir)
            for kname, spec in kernels.items():
                if kname not in d:
                    self.log.debug("Found kernel %s in %s", kname, kernel_dir)
                    d[kname] = spec

        if self.ensure_native_kernel and NATIVE_KERNEL_NAME not in d:
            try:
                from ipykernel.kernelspec import RESOURCES  # type: ignore

                self.log.debug(
                    "Native kernel (%s) available from %s",
                    NATIVE_KERNEL_NAME,
                    RESOURCES,
                )
                d[NATIVE_KERNEL_NAME] = RESOURCES
            except ImportError:
                self.log.warning("Native kernel (%s) is not available",
                                 NATIVE_KERNEL_NAME)

        if self.allowed_kernelspecs:
            # filter if there's an allow list
            d = {
                name: spec
                for name, spec in d.items() if name in self.allowed_kernelspecs
            }
        # TODO: Caching?
        return d

    def _get_kernel_spec_by_name(self, kernel_name, resource_dir):
        """Returns a :class:`KernelSpec` instance for a given kernel_name
        and resource_dir.
        """
        kspec = None
        if kernel_name == NATIVE_KERNEL_NAME:
            try:
                from ipykernel.kernelspec import RESOURCES, get_kernel_dict
            except ImportError:
                # It should be impossible to reach this, but let's play it safe
                pass
            else:
                if resource_dir == RESOURCES:
                    kspec = self.kernel_spec_class(resource_dir=resource_dir,
                                                   **get_kernel_dict())
        if not kspec:
            kspec = self.kernel_spec_class.from_resource_dir(resource_dir)

        if not KPF.instance(
                parent=self.parent).is_provisioner_available(kspec):
            raise NoSuchKernel(kernel_name)

        return kspec

    def _find_spec_directory(self, kernel_name):
        """Find the resource directory of a named kernel spec"""
        for kernel_dir in [kd for kd in self.kernel_dirs if os.path.isdir(kd)]:
            files = os.listdir(kernel_dir)
            for f in files:
                path = pjoin(kernel_dir, f)
                if f.lower() == kernel_name and _is_kernel_dir(path):
                    return path

        if kernel_name == NATIVE_KERNEL_NAME:
            try:
                from ipykernel.kernelspec import RESOURCES
            except ImportError:
                pass
            else:
                return RESOURCES

    def get_kernel_spec(self, kernel_name):
        """Returns a :class:`KernelSpec` instance for the given kernel_name.

        Raises :exc:`NoSuchKernel` if the given kernel name is not found.
        """
        if not _is_valid_kernel_name(kernel_name):
            self.log.warning(
                f"Kernelspec name {kernel_name} is invalid: {_kernel_name_description}"
            )

        resource_dir = self._find_spec_directory(kernel_name.lower())
        if resource_dir is None:
            self.log.warning(f"Kernelspec name {kernel_name} cannot be found!")
            raise NoSuchKernel(kernel_name)

        return self._get_kernel_spec_by_name(kernel_name, resource_dir)

    def get_all_specs(self):
        """Returns a dict mapping kernel names to kernelspecs.

        Returns a dict of the form::

            {
              'kernel_name': {
                'resource_dir': '/path/to/kernel_name',
                'spec': {"the spec itself": ...}
              },
              ...
            }
        """
        d = self.find_kernel_specs()
        res = {}
        for kname, resource_dir in d.items():
            try:
                if self.__class__ is KernelSpecManager:
                    spec = self._get_kernel_spec_by_name(kname, resource_dir)
                else:
                    # avoid calling private methods in subclasses,
                    # which may have overridden find_kernel_specs
                    # and get_kernel_spec, but not the newer get_all_specs
                    spec = self.get_kernel_spec(kname)

                res[kname] = {
                    "resource_dir": resource_dir,
                    "spec": spec.to_dict()
                }
            except NoSuchKernel:
                pass  # The appropriate warning has already been logged
            except Exception:
                self.log.warning("Error loading kernelspec %r",
                                 kname,
                                 exc_info=True)
        return res

    def remove_kernel_spec(self, name):
        """Remove a kernel spec directory by name.

        Returns the path that was deleted.
        """
        save_native = self.ensure_native_kernel
        try:
            self.ensure_native_kernel = False
            specs = self.find_kernel_specs()
        finally:
            self.ensure_native_kernel = save_native
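        # a missing name raises KeyError here, surfacing the error to the caller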
        spec_dir = specs[name]
        self.log.debug("Removing %s", spec_dir)
        if os.path.islink(spec_dir):
            os.remove(spec_dir)
        else:
            shutil.rmtree(spec_dir)
        return spec_dir

    def _get_destination_dir(self, kernel_name, user=False, prefix=None):
        if user:
            return os.path.join(self.user_kernel_dir, kernel_name)
        elif prefix:
            return os.path.join(os.path.abspath(prefix), "share", "jupyter",
                                "kernels", kernel_name)
        else:
            return os.path.join(SYSTEM_JUPYTER_PATH[0], "kernels", kernel_name)

    def install_kernel_spec(self,
                            source_dir,
                            kernel_name=None,
                            user=False,
                            replace=None,
                            prefix=None):
        """Install a kernel spec by copying its directory.

        If ``kernel_name`` is not given, the basename of ``source_dir`` will
        be used.

        If ``user`` is False, it will attempt to install into the systemwide
        kernel registry. If the process does not have appropriate permissions,
        an :exc:`OSError` will be raised.

        If ``prefix`` is given, the kernelspec will be installed to
        PREFIX/share/jupyter/kernels/KERNEL_NAME. This can be sys.prefix
        for installation inside virtual or conda envs.
        """
        source_dir = source_dir.rstrip("/\\")
        if not kernel_name:
            kernel_name = os.path.basename(source_dir)
        kernel_name = kernel_name.lower()
        if not _is_valid_kernel_name(kernel_name):
            raise ValueError("Invalid kernel name %r.  %s" %
                             (kernel_name, _kernel_name_description))

        if user and prefix:
            raise ValueError(
                "Can't specify both user and prefix. Please choose one or the other."
            )

        if replace is not None:
            warnings.warn(
                "replace is ignored. Installing a kernelspec always replaces an existing "
                "installation",
                DeprecationWarning,
                stacklevel=2,
            )

        destination = self._get_destination_dir(kernel_name,
                                                user=user,
                                                prefix=prefix)
        self.log.debug("Installing kernelspec in %s", destination)

        kernel_dir = os.path.dirname(destination)
        if kernel_dir not in self.kernel_dirs:
            self.log.warning(
                "Installing to %s, which is not in %s. The kernelspec may not be found.",
                kernel_dir,
                self.kernel_dirs,
            )

        if os.path.isdir(destination):
            self.log.info("Removing existing kernelspec in %s", destination)
            shutil.rmtree(destination)

        shutil.copytree(source_dir, destination)
        self.log.info("Installed kernelspec %s in %s", kernel_name,
                      destination)
        return destination

    def install_native_kernel_spec(self, user=False):
        """DEPRECATED: Use ipykernel.kernelspec.install"""
        warnings.warn(
            "install_native_kernel_spec is deprecated."
            " Use ipykernel.kernelspec.install.",
            DeprecationWarning,
            stacklevel=2,
        )
        from ipykernel.kernelspec import install

        install(self, user=user)
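
A short usage sketch for the class above; KernelSpecManager is public API in
jupyter_client, though the kernel names and paths shown here are illustrative.

from jupyter_client.kernelspec import KernelSpecManager

ksm = KernelSpecManager()
# map of kernel name -> resource directory, e.g. {'python3': '.../kernels/python3'}
print(ksm.find_kernel_specs())
spec = ksm.get_kernel_spec('python3')  # raises NoSuchKernel if not installed
print(spec.argv, spec.display_name)
# install a kernelspec directory under a prefix (e.g. a virtualenv):
# ksm.install_kernel_spec('/path/to/my_kernel', kernel_name='my_kernel',
#                         prefix=sys.prefix)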
Example #13
class EnterpriseGatewayConfigMixin(Configurable):
    # Server IP / PORT binding
    port_env = 'EG_PORT'
    port_default_value = 8888
    port = Integer(port_default_value,
                   config=True,
                   help='Port on which to listen (EG_PORT env var)')

    @default('port')
    def port_default(self):
        return int(
            os.getenv(self.port_env,
                      os.getenv('KG_PORT', self.port_default_value)))

    port_retries_env = 'EG_PORT_RETRIES'
    port_retries_default_value = 50
    port_retries = Integer(
        port_retries_default_value,
        config=True,
        help="""Number of ports to try if the specified port is not available
                           (EG_PORT_RETRIES env var)""")

    @default('port_retries')
    def port_retries_default(self):
        return int(
            os.getenv(
                self.port_retries_env,
                os.getenv('KG_PORT_RETRIES', self.port_retries_default_value)))

    ip_env = 'EG_IP'
    ip_default_value = '127.0.0.1'
    ip = Unicode(ip_default_value,
                 config=True,
                 help='IP address on which to listen (EG_IP env var)')

    @default('ip')
    def ip_default(self):
        return os.getenv(self.ip_env, os.getenv('KG_IP',
                                                self.ip_default_value))

    # Base URL
    base_url_env = 'EG_BASE_URL'
    base_url_default_value = '/'
    base_url = Unicode(
        base_url_default_value,
        config=True,
        help=
        'The base path for mounting all API resources (EG_BASE_URL env var)')

    @default('base_url')
    def base_url_default(self):
        return os.getenv(self.base_url_env,
                         os.getenv('KG_BASE_URL', self.base_url_default_value))

    # Token authorization
    auth_token_env = 'EG_AUTH_TOKEN'
    auth_token = Unicode(
        config=True,
        help=
        'Authorization token required for all requests (EG_AUTH_TOKEN env var)'
    )

    @default('auth_token')
    def _auth_token_default(self):
        return os.getenv(self.auth_token_env, os.getenv('KG_AUTH_TOKEN', ''))

    # Begin CORS headers
    allow_credentials_env = 'EG_ALLOW_CREDENTIALS'
    allow_credentials = Unicode(
        config=True,
        help=
        'Sets the Access-Control-Allow-Credentials header. (EG_ALLOW_CREDENTIALS env var)'
    )

    @default('allow_credentials')
    def allow_credentials_default(self):
        return os.getenv(self.allow_credentials_env,
                         os.getenv('KG_ALLOW_CREDENTIALS', ''))

    allow_headers_env = 'EG_ALLOW_HEADERS'
    allow_headers = Unicode(
        config=True,
        help=
        'Sets the Access-Control-Allow-Headers header. (EG_ALLOW_HEADERS env var)'
    )

    @default('allow_headers')
    def allow_headers_default(self):
        return os.getenv(self.allow_headers_env,
                         os.getenv('KG_ALLOW_HEADERS', ''))

    allow_methods_env = 'EG_ALLOW_METHODS'
    allow_methods = Unicode(
        config=True,
        help=
        'Sets the Access-Control-Allow-Methods header. (EG_ALLOW_METHODS env var)'
    )

    @default('allow_methods')
    def allow_methods_default(self):
        return os.getenv(self.allow_methods_env,
                         os.getenv('KG_ALLOW_METHODS', ''))

    allow_origin_env = 'EG_ALLOW_ORIGIN'
    allow_origin = Unicode(
        config=True,
        help=
        'Sets the Access-Control-Allow-Origin header. (EG_ALLOW_ORIGIN env var)'
    )

    @default('allow_origin')
    def allow_origin_default(self):
        return os.getenv(self.allow_origin_env,
                         os.getenv('KG_ALLOW_ORIGIN', ''))

    expose_headers_env = 'EG_EXPOSE_HEADERS'
    expose_headers = Unicode(
        config=True,
        help=
        'Sets the Access-Control-Expose-Headers header. (EG_EXPOSE_HEADERS env var)'
    )

    @default('expose_headers')
    def expose_headers_default(self):
        return os.getenv(self.expose_headers_env,
                         os.getenv('KG_EXPOSE_HEADERS', ''))

    trust_xheaders_env = 'EG_TRUST_XHEADERS'
    trust_xheaders = CBool(
        False,
        config=True,
        help="""Use x-* header values for overriding the remote-ip, useful when
                           application is behing a proxy. (EG_TRUST_XHEADERS env var)"""
    )

    @default('trust_xheaders')
    def trust_xheaders_default(self):
        return strtobool(
            os.getenv(self.trust_xheaders_env,
                      os.getenv('KG_TRUST_XHEADERS', 'False')))

    certfile_env = 'EG_CERTFILE'
    certfile = Unicode(
        None,
        config=True,
        allow_none=True,
        help=
        'The full path to an SSL/TLS certificate file. (EG_CERTFILE env var)')

    @default('certfile')
    def certfile_default(self):
        return os.getenv(self.certfile_env, os.getenv('KG_CERTFILE'))

    keyfile_env = 'EG_KEYFILE'
    keyfile = Unicode(
        None,
        config=True,
        allow_none=True,
        help=
        'The full path to a private key file for usage with SSL/TLS. (EG_KEYFILE env var)'
    )

    @default('keyfile')
    def keyfile_default(self):
        return os.getenv(self.keyfile_env, os.getenv('KG_KEYFILE'))

    client_ca_env = 'EG_CLIENT_CA'
    client_ca = Unicode(
        None,
        config=True,
        allow_none=True,
        help="""The full path to a certificate authority certificate for SSL/TLS
                        client authentication. (EG_CLIENT_CA env var)""")

    @default('client_ca')
    def client_ca_default(self):
        return os.getenv(self.client_ca_env, os.getenv('KG_CLIENT_CA'))

    ssl_version_env = 'EG_SSL_VERSION'
    ssl_version_default_value = ssl.PROTOCOL_TLSv1_2
    ssl_version = Integer(
        None,
        config=True,
        allow_none=True,
        help="""Sets the SSL version to use for the web socket
                          connection. (EG_SSL_VERSION env var)""")

    @default('ssl_version')
    def ssl_version_default(self):
        ssl_from_env = os.getenv(self.ssl_version_env,
                                 os.getenv('KG_SSL_VERSION'))
        return ssl_from_env if ssl_from_env is None else int(ssl_from_env)

    max_age_env = 'EG_MAX_AGE'
    max_age = Unicode(
        config=True,
        help='Sets the Access-Control-Max-Age header. (EG_MAX_AGE env var)')

    @default('max_age')
    def max_age_default(self):
        return os.getenv(self.max_age_env, os.getenv('KG_MAX_AGE', ''))

    # End CORS headers

    max_kernels_env = 'EG_MAX_KERNELS'
    max_kernels = Integer(
        None,
        config=True,
        allow_none=True,
        help=
        """Limits the number of kernel instances allowed to run by this gateway.
                          Unbounded by default. (EG_MAX_KERNELS env var)""")

    @default('max_kernels')
    def max_kernels_default(self):
        val = os.getenv(self.max_kernels_env, os.getenv('KG_MAX_KERNELS'))
        return val if val is None else int(val)

    default_kernel_name_env = 'EG_DEFAULT_KERNEL_NAME'
    default_kernel_name = Unicode(
        config=True,
        help=
        'Default kernel name when spawning a kernel (EG_DEFAULT_KERNEL_NAME env var)'
    )

    @default('default_kernel_name')
    def default_kernel_name_default(self):
        # defaults to Jupyter's default kernel name on empty string
        return os.getenv(self.default_kernel_name_env,
                         os.getenv('KG_DEFAULT_KERNEL_NAME', ''))

    list_kernels_env = 'EG_LIST_KERNELS'
    list_kernels = Bool(
        config=True,
        help=
        """Permits listing of the running kernels using API endpoints /api/kernels
                        and /api/sessions. (EG_LIST_KERNELS env var) Note: Jupyter Notebook
                        allows this by default but Jupyter Enterprise Gateway does not."""
    )

    @default('list_kernels')
    def list_kernels_default(self):
        return os.getenv(self.list_kernels_env,
                         os.getenv('KG_LIST_KERNELS',
                                   'False')).lower() == 'true'

    env_whitelist_env = 'EG_ENV_WHITELIST'
    env_whitelist = List(
        config=True,
        help="""Environment variables allowed to be set when a client requests a
                         new kernel. Use '*' to allow all environment variables sent in the request.
                         (EG_ENV_WHITELIST env var)""")

    @default('env_whitelist')
    def env_whitelist_default(self):
        value = os.getenv(self.env_whitelist_env,
                          os.getenv('KG_ENV_WHITELIST', ''))
        # ''.split(',') would yield [''], so drop empty entries
        return [v for v in value.split(',') if v]

    env_process_whitelist_env = 'EG_ENV_PROCESS_WHITELIST'
    env_process_whitelist = List(
        config=True,
        help="""Environment variables allowed to be inherited
                                 from the spawning process by the kernel. (EG_ENV_PROCESS_WHITELIST env var)"""
    )

    @default('env_process_whitelist')
    def env_process_whitelist_default(self):
        value = os.getenv(self.env_process_whitelist_env,
                          os.getenv('KG_ENV_PROCESS_WHITELIST', ''))
        # drop empty entries produced by splitting the empty string
        return [v for v in value.split(',') if v]

    kernel_headers_env = 'EG_KERNEL_HEADERS'
    kernel_headers = List(
        config=True,
        help="""Request headers to make available to kernel launch framework.
                          (EG_KERNEL_HEADERS env var)""")

    @default('kernel_headers')
    def kernel_headers_default(self):
        default_headers = os.getenv(self.kernel_headers_env)
        return default_headers.split(',') if default_headers else []

    # Remote hosts
    remote_hosts_env = 'EG_REMOTE_HOSTS'
    remote_hosts_default_value = 'localhost'
    remote_hosts = List(
        default_value=[remote_hosts_default_value],
        config=True,
        help=
        """Bracketed comma-separated list of hosts on which DistributedProcessProxy
                        kernels will be launched e.g., ['host1','host2']. (EG_REMOTE_HOSTS env var
                        - non-bracketed, just comma-separated)""")

    @default('remote_hosts')
    def remote_hosts_default(self):
        return os.getenv(self.remote_hosts_env,
                         self.remote_hosts_default_value).split(',')

    # Yarn endpoint
    yarn_endpoint_env = 'EG_YARN_ENDPOINT'
    yarn_endpoint = Unicode(
        None,
        config=True,
        allow_none=True,
        help=
        """The http url specifying the YARN Resource Manager. Note: If this value is NOT set,
                            the YARN library will use the files within the local HADOOP_CONFIG_DIR to determine the
                            active resource manager. (EG_YARN_ENDPOINT env var)"""
    )

    @default('yarn_endpoint')
    def yarn_endpoint_default(self):
        return os.getenv(self.yarn_endpoint_env)

    # Alt Yarn endpoint
    alt_yarn_endpoint_env = 'EG_ALT_YARN_ENDPOINT'
    alt_yarn_endpoint = Unicode(
        None,
        config=True,
        allow_none=True,
        help=
        """The http url specifying the alternate YARN Resource Manager.  This value should
                                be set when YARN Resource Managers are configured for high availability.  Note: If both
                                YARN endpoints are NOT set, the YARN library will use the files within the local
                                HADOOP_CONFIG_DIR to determine the active resource manager.
                                (EG_ALT_YARN_ENDPOINT env var)""")

    @default('alt_yarn_endpoint')
    def alt_yarn_endpoint_default(self):
        return os.getenv(self.alt_yarn_endpoint_env)

    yarn_endpoint_security_enabled_env = 'EG_YARN_ENDPOINT_SECURITY_ENABLED'
    yarn_endpoint_security_enabled_default_value = False
    yarn_endpoint_security_enabled = Bool(
        yarn_endpoint_security_enabled_default_value,
        config=True,
        help="""Is YARN Kerberos/SPNEGO Security enabled (True/False).
                                          (EG_YARN_ENDPOINT_SECURITY_ENABLED env var)"""
    )

    @default('yarn_endpoint_security_enabled')
    def yarn_endpoint_security_enabled_default(self):
        # bool() on any non-empty string (even "false") is True, so parse
        # the value explicitly instead
        return bool(
            strtobool(
                os.getenv(self.yarn_endpoint_security_enabled_env,
                          str(self.yarn_endpoint_security_enabled_default_value))))

    # Conductor endpoint
    conductor_endpoint_env = 'EG_CONDUCTOR_ENDPOINT'
    conductor_endpoint_default_value = None
    conductor_endpoint = Unicode(
        conductor_endpoint_default_value,
        allow_none=True,
        config=True,
        help="""The http url for accessing the Conductor REST API.
                                 (EG_CONDUCTOR_ENDPOINT env var)""")

    @default('conductor_endpoint')
    def conductor_endpoint_default(self):
        return os.getenv(self.conductor_endpoint_env,
                         self.conductor_endpoint_default_value)

    _log_formatter_cls = LogFormatter  # traitlet default is LevelFormatter

    @default('log_format')
    def _default_log_format(self):
        """override default log format to include milliseconds"""
        return u"%(color)s[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s]%(end_color)s %(message)s"

    # Impersonation enabled
    impersonation_enabled_env = 'EG_IMPERSONATION_ENABLED'
    impersonation_enabled = Bool(
        False,
        config=True,
        help=
        """Indicates whether impersonation will be performed during kernel launch.
                                 (EG_IMPERSONATION_ENABLED env var)""")

    @default('impersonation_enabled')
    def impersonation_enabled_default(self):
        return os.getenv(self.impersonation_enabled_env,
                         'false').lower() == 'true'

    # Unauthorized users
    unauthorized_users_env = 'EG_UNAUTHORIZED_USERS'
    unauthorized_users_default_value = 'root'
    unauthorized_users = Set(
        default_value={unauthorized_users_default_value},
        config=True,
        help=
        """Comma-separated list of user names (e.g., ['root','admin']) against which
                             KERNEL_USERNAME will be compared.  Any match (case-sensitive) will prevent the
                             kernel's launch and result in an HTTP 403 (Forbidden) error.
                             (EG_UNAUTHORIZED_USERS env var - non-bracketed, just comma-separated)"""
    )

    @default('unauthorized_users')
    def unauthorized_users_default(self):
        return os.getenv(self.unauthorized_users_env,
                         self.unauthorized_users_default_value).split(',')

    # Authorized users
    authorized_users_env = 'EG_AUTHORIZED_USERS'
    authorized_users = Set(
        config=True,
        help=
        """Comma-separated list of user names (e.g., ['bob','alice']) against which
                           KERNEL_USERNAME will be compared.  Any match (case-sensitive) will allow the kernel's
                           launch, otherwise an HTTP 403 (Forbidden) error will be raised.  The set of unauthorized
                           users takes precedence. This option should be used carefully as it can dramatically limit
                           who can launch kernels.  (EG_AUTHORIZED_USERS env var - non-bracketed,
                           just comma-separated)""")

    @default('authorized_users')
    def authorized_users_default(self):
        au_env = os.getenv(self.authorized_users_env)
        return au_env.split(',') if au_env is not None else []

    # Port range
    port_range_env = 'EG_PORT_RANGE'
    port_range_default_value = "0..0"
    port_range = Unicode(
        port_range_default_value,
        config=True,
        help=
        """Specifies the lower and upper port numbers from which ports are created.
                         The bounded values are separated by '..' (e.g., 33245..34245 specifies a range of 1000 ports
                         to be randomly selected). A range of zero (e.g., 33245..33245 or 0..0) disables port-range
                         enforcement.  (EG_PORT_RANGE env var)""")

    @default('port_range')
    def port_range_default(self):
        return os.getenv(self.port_range_env, self.port_range_default_value)

    # Max Kernels per User
    max_kernels_per_user_env = 'EG_MAX_KERNELS_PER_USER'
    max_kernels_per_user_default_value = -1
    max_kernels_per_user = Integer(
        max_kernels_per_user_default_value,
        config=True,
        help="""Specifies the maximum number of kernels a user can have active
                                   simultaneously.  A value of -1 disables enforcement.
                                   (EG_MAX_KERNELS_PER_USER env var)""")

    @default('max_kernels_per_user')
    def max_kernels_per_user_default(self):
        return int(
            os.getenv(self.max_kernels_per_user_env,
                      self.max_kernels_per_user_default_value))

    ws_ping_interval_env = 'EG_WS_PING_INTERVAL_SECS'
    ws_ping_interval_default_value = 30
    ws_ping_interval = Integer(
        ws_ping_interval_default_value,
        config=True,
        help=
        """Specifies the ping interval (in seconds) that should be used by the zmq port
                                     associated with spawned kernels. Set this variable to 0 to disable the ping
                                     mechanism. (EG_WS_PING_INTERVAL_SECS env var)""")

    @default('ws_ping_interval')
    def ws_ping_interval_default(self):
        return int(
            os.getenv(self.ws_ping_interval_env,
                      self.ws_ping_interval_default_value))

    # Dynamic Update Interval
    dynamic_config_interval_env = 'EG_DYNAMIC_CONFIG_INTERVAL'
    dynamic_config_interval_default_value = 0
    dynamic_config_interval = Integer(
        dynamic_config_interval_default_value,
        min=0,
        config=True,
        help=
        """Specifies the number of seconds configuration files are polled for
                                      changes.  A value of 0 or less disables dynamic config updates.
                                      (EG_DYNAMIC_CONFIG_INTERVAL env var)""")

    @default('dynamic_config_interval')
    def dynamic_config_interval_default(self):
        return int(
            os.getenv(self.dynamic_config_interval_env,
                      self.dynamic_config_interval_default_value))

    @observe('dynamic_config_interval')
    def dynamic_config_interval_changed(self, event):
        prev_val = event['old']
        self.dynamic_config_interval = event['new']
        if self.dynamic_config_interval != prev_val:
            # Values are different.  Stop the current poller.  If new value is > 0, start a poller.
            if self.dynamic_config_poller:
                self.dynamic_config_poller.stop()
                self.dynamic_config_poller = None

            if self.dynamic_config_interval <= 0:
                self.log.warning(
                    "Dynamic configuration updates have been disabled and cannot be re-enabled "
                    "without restarting Enterprise Gateway!")
            # The interval has been changed, but still positive
            elif prev_val > 0 and hasattr(self, "init_dynamic_configs"):
                self.init_dynamic_configs()  # Restart the poller

    dynamic_config_poller = None

    kernel_spec_manager = Instance(
        "jupyter_client.kernelspec.KernelSpecManager", allow_none=True)

    kernel_spec_manager_class = Type(
        default_value="jupyter_client.kernelspec.KernelSpecManager",
        config=True,
        help="""
        The kernel spec manager class to use. Must be a subclass
        of `jupyter_client.kernelspec.KernelSpecManager`.
        """)

    kernel_spec_cache_class = Type(
        default_value="enterprise_gateway.services.kernelspecs.KernelSpecCache",
        config=True,
        help="""
        The kernel spec cache class to use. Must be a subclass
        of `enterprise_gateway.services.kernelspecs.KernelSpecCache`.
        """)

    kernel_manager_class = Type(
        klass=
        "enterprise_gateway.services.kernels.remotemanager.RemoteMappingKernelManager",
        default_value=
        "enterprise_gateway.services.kernels.remotemanager.RemoteMappingKernelManager",
        config=True,
        help="""
        The kernel manager class to use. Must be a subclass
        of `enterprise_gateway.services.kernels.RemoteMappingKernelManager`.
        """)

    kernel_session_manager_class = Type(
        klass=
        "enterprise_gateway.services.sessions.kernelsessionmanager.KernelSessionManager",
        default_value=
        "enterprise_gateway.services.sessions.kernelsessionmanager.FileKernelSessionManager",
        config=True,
        help="""
        The kernel session manager class to use. Must be a subclass
        of `enterprise_gateway.services.sessions.KernelSessionManager`.
        """)
Example #14
class SageTerminalApp(TerminalIPythonApp):
    name = u'Sage'
    crash_handler_class = SageCrashHandler

    test_shell = Bool(False,
                      help='Whether the shell is a test shell').tag(config=True)
    shell_class = Type(InteractiveShell,
                       help='Type of the shell').tag(config=True)

    def load_config_file(self, *args, **kwds):
        r"""
        Merges a config file with the default sage config.

        .. note::

            This code is based on :meth:`Application.update_config`.

        TESTS:

        Test that :trac:`15972` has been fixed::

            sage: from sage.misc.temporary_file import tmp_dir
            sage: from sage.repl.interpreter import SageTerminalApp
            sage: d = tmp_dir()
            sage: from IPython.paths import get_ipython_dir
            sage: IPYTHONDIR = get_ipython_dir()
            sage: os.environ['IPYTHONDIR'] = d
            sage: SageTerminalApp().load_config_file()
            sage: os.environ['IPYTHONDIR'] = IPYTHONDIR
        """
        super(SageTerminalApp, self).load_config_file(*args, **kwds)

        newconfig = copy.deepcopy(DEFAULT_SAGE_CONFIG)

        # merge in the config loaded from file
        newconfig.merge(self.config)

        self.config = newconfig

    def init_shell(self):
        r"""
        Initialize the :class:`SageInteractiveShell` instance.

        .. note::

            This code is based on
            :meth:`TerminalIPythonApp.init_shell`.

        EXAMPLES::

            sage: from sage.repl.interpreter import SageTerminalApp, DEFAULT_SAGE_CONFIG
            sage: app = SageTerminalApp.instance()
            sage: app.shell
            <sage.repl.interpreter.SageTestShell object at 0x...>
        """
        # Shell initialization
        self.shell = self.shell_class.instance(parent=self,
                                               config=self.config,
                                               display_banner=False,
                                               profile_dir=self.profile_dir,
                                               ipython_dir=self.ipython_dir)
        self.shell.configurables.append(self)
        self.shell.has_sage_extensions = SAGE_EXTENSION in self.extensions

        # Load the %lprun extension if available
        try:
            import line_profiler
        except ImportError:
            pass
        else:
            self.extensions.append('line_profiler')

        if self.shell.has_sage_extensions:
            self.extensions.remove(SAGE_EXTENSION)

            # load sage extension here to get a crash if
            # something is wrong with the sage library
            self.shell.extension_manager.load_extension(SAGE_EXTENSION)
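
A minimal sketch of the merge order load_config_file relies on: the config
loaded from file is merged into a deep copy of the defaults, so loaded values
win. The trait names below are illustrative.

import copy

from traitlets.config import Config

defaults = Config()
defaults.InteractiveShell.colors = 'LightBG'
loaded = Config()
loaded.InteractiveShell.colors = 'Linux'

merged = copy.deepcopy(defaults)
merged.merge(loaded)  # values from `loaded` override the defaults
assert merged.InteractiveShell.colors == 'Linux'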
Example #15
class AsyncKernelManager(KernelManager):
    """Manages kernels in an asynchronous manner """

    client_class = DottedObjectName('jupyter_client.asynchronous.AsyncKernelClient')
    client_factory = Type(klass='jupyter_client.asynchronous.AsyncKernelClient')

    async def _launch_kernel(self, kernel_cmd, **kw):
        """actually launch the kernel

        override in a subclass to launch kernel subprocesses differently
        """
        res = launch_kernel(kernel_cmd, **kw)
        return res

    async def start_kernel(self, **kw):
        """Starts a kernel in a separate process in an asynchronous manner.

        If random ports (port=0) are being used, this method must be called
        before the channels are created.

        Parameters
        ----------
        `**kw` : optional
             keyword arguments that are passed down to build the kernel_cmd
             and launching the kernel (e.g. Popen kwargs).
        """
        kernel_cmd, kw = self.pre_start_kernel(**kw)

        # launch the kernel subprocess
        self.log.debug("Starting kernel (async): %s", kernel_cmd)
        self.kernel = await self._launch_kernel(kernel_cmd, **kw)
        self.post_start_kernel(**kw)

    async def finish_shutdown(self, waittime=None, pollinterval=0.1):
        """Wait for kernel shutdown, then kill process if it doesn't shutdown.

        This does not send shutdown requests - use :meth:`request_shutdown`
        first.
        """
        if waittime is None:
            waittime = max(self.shutdown_wait_time, 0)
        try:
            await asyncio.wait_for(self._async_wait(pollinterval=pollinterval), timeout=waittime)
        except asyncio.TimeoutError:
            self.log.debug("Kernel is taking too long to finish, killing")
            await self._kill_kernel()
        else:
            # Process is no longer alive, wait and clear
            if self.kernel is not None:
                self.kernel.wait()
                self.kernel = None

    async def shutdown_kernel(self, now=False, restart=False):
        """Attempts to stop the kernel process cleanly.

        This attempts to shut down the kernel cleanly by:

        1. Sending it a shutdown message over the shell channel.
        2. If that fails, the kernel is shutdown forcibly by sending it
           a signal.

        Parameters
        ----------
        now : bool
            Should the kernel be forcibly killed *now*? This skips the
            first, nice shutdown attempt.
        restart : bool
            Will this kernel be restarted after it is shut down? When this
            is True, connection files will not be cleaned up.
        """
        self.shutting_down = True  # Used by restarter to prevent race condition
        # Stop monitoring for restarting while we shutdown.
        self.stop_restarter()

        if now:
            await self._kill_kernel()
        else:
            self.request_shutdown(restart=restart)
            # Don't send any additional kernel kill messages immediately, to give
            # the kernel a chance to properly execute shutdown actions. Wait for at
            # most 1s, checking every 0.1s.
            await self.finish_shutdown()

        # See comment in KernelManager.shutdown_kernel().
        overrides_cleanup = type(self).cleanup is not AsyncKernelManager.cleanup
        overrides_cleanup_resources = type(self).cleanup_resources is not AsyncKernelManager.cleanup_resources

        if overrides_cleanup and not overrides_cleanup_resources:
            self.cleanup(connection_file=not restart)
        else:
            self.cleanup_resources(restart=restart)

    async def restart_kernel(self, now=False, newports=False, **kw):
        """Restarts a kernel with the arguments that were used to launch it.

        Parameters
        ----------
        now : bool, optional
            If True, the kernel is forcefully restarted *immediately*, without
            having a chance to do any cleanup action.  Otherwise the kernel is
            given 1s to clean up before a forceful restart is issued.

            In all cases the kernel is restarted, the only difference is whether
            it is given a chance to perform a clean shutdown or not.

        newports : bool, optional
            If the old kernel was launched with random ports, this flag decides
            whether the same ports and connection file will be used again.
            If False, the same ports and connection file are used. This is
            the default. If True, new random port numbers are chosen and a
            new connection file is written. It is still possible that the newly
            chosen random port numbers happen to be the same as the old ones.

        `**kw` : optional
            Any options specified here will overwrite those used to launch the
            kernel.
        """
        if self._launch_args is None:
            raise RuntimeError("Cannot restart the kernel. "
                               "No previous call to 'start_kernel'.")
        else:
            # Stop currently running kernel.
            await self.shutdown_kernel(now=now, restart=True)

            if newports:
                self.cleanup_random_ports()

            # Start new kernel.
            self._launch_args.update(kw)
            await self.start_kernel(**self._launch_args)
        return None

    async def _kill_kernel(self):
        """Kill the running kernel.

        This is a private method, callers should use shutdown_kernel(now=True).
        """
        if self.has_kernel:
            # Signal the kernel to terminate (sends SIGKILL on Unix and calls
            # TerminateProcess() on Win32).
            try:
                if hasattr(signal, 'SIGKILL'):
                    await self.signal_kernel(signal.SIGKILL)
                else:
                    self.kernel.kill()
            except OSError as e:
                # In Windows, we will get an Access Denied error if the process
                # has already terminated. Ignore it.
                if sys.platform == 'win32':
                    if e.winerror != 5:
                        raise
                # On Unix, we may get an ESRCH error if the process has already
                # terminated. Ignore it.
                else:
                    from errno import ESRCH
                    if e.errno != ESRCH:
                        raise

            # Wait until the kernel terminates.
            try:
                await asyncio.wait_for(self._async_wait(), timeout=5.0)
            except asyncio.TimeoutError:
                # Wait timed out, just log warning but continue - not much more we can do.
                self.log.warning("Wait for final termination of kernel timed out - continuing...")
            else:
                # Process is no longer alive, wait and clear
                if self.kernel is not None:
                    self.kernel.wait()
            self.kernel = None

    async def interrupt_kernel(self):
        """Interrupts the kernel by sending it a signal.

        Unlike ``signal_kernel``, this operation is well supported on all
        platforms.
        """
        if self.has_kernel:
            interrupt_mode = self.kernel_spec.interrupt_mode
            if interrupt_mode == 'signal':
                if sys.platform == 'win32':
                    from .win_interrupt import send_interrupt
                    send_interrupt(self.kernel.win32_interrupt_event)
                else:
                    await self.signal_kernel(signal.SIGINT)

            elif interrupt_mode == 'message':
                msg = self.session.msg("interrupt_request", content={})
                self._connect_control_socket()
                self.session.send(self._control_socket, msg)
        else:
            raise RuntimeError("Cannot interrupt kernel. No kernel is running!")

    async def signal_kernel(self, signum):
        """Sends a signal to the process group of the kernel (this
        usually includes the kernel and any subprocesses spawned by
        the kernel).

        Note that since only SIGTERM is supported on Windows, this function is
        only useful on Unix systems.
        """
        if self.has_kernel:
            if hasattr(os, "getpgid") and hasattr(os, "killpg"):
                try:
                    pgid = os.getpgid(self.kernel.pid)
                    os.killpg(pgid, signum)
                    return
                except OSError:
                    pass
            self.kernel.send_signal(signum)
        else:
            raise RuntimeError("Cannot signal kernel. No kernel is running!")

    async def is_alive(self):
        """Is the kernel process still running?"""
        if self.has_kernel:
            if self.kernel.poll() is None:
                return True
            else:
                return False
        else:
            # we don't have a kernel
            return False

    async def _async_wait(self, pollinterval=0.1):
        # Use busy loop at 100ms intervals, polling until the process is
        # not alive.  If we find the process is no longer alive, complete
        # its cleanup via the blocking wait().  Callers are responsible for
        # issuing calls to wait() using a timeout (see _kill_kernel()).
        while await self.is_alive():
            await asyncio.sleep(pollinterval)
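
A rough usage sketch of the asynchronous API (assumes jupyter_client with this class available and an installed 'python3' kernel spec; error handling omitted):

import asyncio

async def run_once():
    km = AsyncKernelManager(kernel_name='python3')
    await km.start_kernel()
    try:
        print('alive:', await km.is_alive())  # polls the subprocess
    finally:
        # polite shutdown first, falling back to a kill after the wait time
        await km.shutdown_kernel(now=False)

asyncio.run(run_once())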
Example #16
class NotebookApp(JupyterApp):

    name = 'jupyter-notebook'
    version = __version__
    description = """
        The Jupyter HTML Notebook.
        
        This launches a Tornado based HTML Notebook Server that serves up an
        HTML5/Javascript Notebook client.
    """
    examples = _examples

    classes = [
        KernelManager,
        Session,
        MappingKernelManager,
        ContentsManager,
        FileContentsManager,
        NotebookNotary,
        KernelSpecManager,
    ]
    flags = Dict(flags)
    aliases = Dict(aliases)

    subcommands = dict(list=(NbserverListApp,
                             NbserverListApp.description.splitlines()[0]), )

    _log_formatter_cls = LogFormatter

    def _log_level_default(self):
        return logging.INFO

    def _log_datefmt_default(self):
        """Exclude date from default date format"""
        return "%H:%M:%S"

    def _log_format_default(self):
        """override default log format to include time"""
        return u"%(color)s[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s]%(end_color)s %(message)s"

    ignore_minified_js = Bool(
        False,
        config=True,
        help=
        'Deprecated: Whether to use the minified JS file or not; mainly used during development to avoid JS recompilation',
    )

    # file to be opened in the notebook server
    file_to_run = Unicode('', config=True)

    # Network related information

    allow_origin = Unicode('',
                           config=True,
                           help="""Set the Access-Control-Allow-Origin header
        
        Use '*' to allow any origin to access your server.
        
        Takes precedence over allow_origin_pat.
        """)

    allow_origin_pat = Unicode(
        '',
        config=True,
        help=
        """Use a regular expression for the Access-Control-Allow-Origin header
        
        Requests from an origin matching the expression will get replies with:
        
            Access-Control-Allow-Origin: origin
        
        where `origin` is the origin of the request.
        
        Ignored if allow_origin is set.
        """)

    allow_credentials = Bool(
        False,
        config=True,
        help="Set the Access-Control-Allow-Credentials: true header")

    allow_root = Bool(
        False,
        config=True,
        help="Whether to allow the user to run the notebook as root.")

    default_url = Unicode('/tree',
                          config=True,
                          help="The default URL to redirect to from `/`")

    ip = Unicode('localhost',
                 config=True,
                 help="The IP address the notebook server will listen on.")

    def _ip_default(self):
        """Return localhost if available, 127.0.0.1 otherwise.
        
        On some (horribly broken) systems, localhost cannot be bound.
        """
        s = socket.socket()
        try:
            s.bind(('localhost', 0))
        except socket.error as e:
            self.log.warning(
                "Cannot bind to localhost, using 127.0.0.1 as default ip\n%s",
                e)
            return '127.0.0.1'
        else:
            s.close()
            return 'localhost'

    def _ip_changed(self, name, old, new):
        if new == u'*': self.ip = u''

    port = Integer(8888,
                   config=True,
                   help="The port the notebook server will listen on.")
    port_retries = Integer(
        50,
        config=True,
        help=
        "The number of additional ports to try if the specified port is not available."
    )

    certfile = Unicode(
        u'',
        config=True,
        help="""The full path to an SSL/TLS certificate file.""")

    keyfile = Unicode(
        u'',
        config=True,
        help="""The full path to a private key file for usage with SSL/TLS.""")

    client_ca = Unicode(
        u'',
        config=True,
        help=
        """The full path to a certificate authority certifificate for SSL/TLS client authentication."""
    )

    cookie_secret_file = Unicode(
        config=True, help="""The file where the cookie secret is stored.""")

    def _cookie_secret_file_default(self):
        return os.path.join(self.runtime_dir, 'notebook_cookie_secret')

    cookie_secret = Bytes(b'',
                          config=True,
                          help="""The random bytes used to secure cookies.
        By default this is a new random number every time you start the Notebook.
        Set it to a value in a config file to enable logins to persist across server sessions.
        
        Note: Cookie secrets should be kept private, do not share config files with
        cookie_secret stored in plaintext (you can read the value from a file).
        """)

    def _cookie_secret_default(self):
        if os.path.exists(self.cookie_secret_file):
            with io.open(self.cookie_secret_file, 'rb') as f:
                return f.read()
        else:
            secret = encodebytes(os.urandom(1024))
            self._write_cookie_secret_file(secret)
            return secret

    def _write_cookie_secret_file(self, secret):
        """write my secret to my secret_file"""
        self.log.info("Writing notebook server cookie secret to %s",
                      self.cookie_secret_file)
        with io.open(self.cookie_secret_file, 'wb') as f:
            f.write(secret)
        try:
            os.chmod(self.cookie_secret_file, 0o600)
        except OSError:
            self.log.warning("Could not set permissions on %s",
                             self.cookie_secret_file)

    password = Unicode(u'',
                       config=True,
                       help="""Hashed password to use for web authentication.

                      To generate, type in a python/IPython shell:

                        from notebook.auth import passwd; passwd()

                      The string should be of the form type:salt:hashed-password.
                      """)

    open_browser = Bool(True,
                        config=True,
                        help="""Whether to open in a browser after starting.
                        The specific browser used is platform dependent and
                        determined by the python standard library `webbrowser`
                        module, unless it is overridden using the --browser
                        (NotebookApp.browser) configuration option.
                        """)

    browser = Unicode(u'',
                      config=True,
                      help="""Specify what command to use to invoke a web
                      browser when opening the notebook. If not specified, the
                      default browser will be determined by the `webbrowser`
                      standard library module, which allows setting of the
                      BROWSER environment variable to override it.
                      """)

    webapp_settings = Dict(config=True,
                           help="DEPRECATED, use tornado_settings")

    def _webapp_settings_changed(self, name, old, new):
        self.log.warning(
            "\n    webapp_settings is deprecated, use tornado_settings.\n")
        self.tornado_settings = new

    tornado_settings = Dict(
        config=True,
        help="Supply overrides for the tornado.web.Application that the "
        "Jupyter notebook uses.")

    ssl_options = Dict(config=True,
                       help="""Supply SSL options for the tornado HTTPServer.
            See the tornado docs for details.""")

    jinja_environment_options = Dict(
        config=True,
        help="Supply extra arguments that will be passed to Jinja environment."
    )

    jinja_template_vars = Dict(
        config=True,
        help="Extra variables to supply to jinja templates when rendering.",
    )

    enable_mathjax = Bool(
        True,
        config=True,
        help="""Whether to enable MathJax for typesetting math/TeX

        MathJax is the javascript library Jupyter uses to render math/LaTeX. It is
        very large, so you may want to disable it if you have a slow internet
        connection, or for offline use of the notebook.

        When disabled, equations etc. will appear as their untransformed TeX source.
        """)

    def _enable_mathjax_changed(self, name, old, new):
        """set mathjax url to empty if mathjax is disabled"""
        if not new:
            self.mathjax_url = u''

    base_url = Unicode('/',
                       config=True,
                       help='''The base URL for the notebook server.

                               Leading and trailing slashes can be omitted,
                               and will automatically be added.
                               ''')

    def _base_url_changed(self, name, old, new):
        if not new.startswith('/'):
            self.base_url = '/' + new
        elif not new.endswith('/'):
            self.base_url = new + '/'

    base_project_url = Unicode('/',
                               config=True,
                               help="""DEPRECATED use base_url""")

    def _base_project_url_changed(self, name, old, new):
        self.log.warning("base_project_url is deprecated, use base_url")
        self.base_url = new

    extra_static_paths = List(
        Unicode(),
        config=True,
        help="""Extra paths to search for serving static files.
        
        This allows adding javascript/css to be available from the notebook server machine,
        or overriding individual files shipped with the notebook.""")

    @property
    def static_file_path(self):
        """return extra paths + the default location"""
        return self.extra_static_paths + [DEFAULT_STATIC_FILES_PATH]

    static_custom_path = List(Unicode(),
                              help="""Path to search for custom.js, css""")

    def _static_custom_path_default(self):
        return [
            os.path.join(d, 'custom')
            for d in (self.config_dir, DEFAULT_STATIC_FILES_PATH)
        ]

    extra_template_paths = List(
        Unicode(),
        config=True,
        help="""Extra paths to search for serving jinja templates.

        Can be used to override templates from notebook.templates.""")

    @property
    def template_file_path(self):
        """return extra paths + the default locations"""
        return self.extra_template_paths + DEFAULT_TEMPLATE_PATH_LIST

    extra_nbextensions_path = List(
        Unicode(),
        config=True,
        help="""extra paths to look for Javascript notebook extensions""")

    @property
    def nbextensions_path(self):
        """The path to look for Javascript notebook extensions"""
        path = self.extra_nbextensions_path + jupyter_path('nbextensions')
        # FIXME: remove IPython nbextensions path after a migration period
        try:
            from IPython.paths import get_ipython_dir
        except ImportError:
            pass
        else:
            path.append(os.path.join(get_ipython_dir(), 'nbextensions'))
        return path

    websocket_url = Unicode("",
                            config=True,
                            help="""The base URL for websockets,
        if it differs from the HTTP server (hint: it almost certainly doesn't).
        
        Should be in the form of an HTTP origin: ws[s]://hostname[:port]
        """)
    mathjax_url = Unicode("", config=True, help="""The url for MathJax.js.""")

    def _mathjax_url_default(self):
        if not self.enable_mathjax:
            return u''
        static_url_prefix = self.tornado_settings.get("static_url_prefix",
                                                      "static")
        return url_path_join(static_url_prefix, 'components', 'MathJax',
                             'MathJax.js')

    def _mathjax_url_changed(self, name, old, new):
        if new and not self.enable_mathjax:
            # enable_mathjax=False overrides mathjax_url
            self.mathjax_url = u''
        else:
            self.log.info("Using MathJax: %s", new)

    contents_manager_class = Type(default_value=FileContentsManager,
                                  klass=ContentsManager,
                                  config=True,
                                  help='The notebook manager class to use.')
    kernel_manager_class = Type(default_value=MappingKernelManager,
                                config=True,
                                help='The kernel manager class to use.')
    session_manager_class = Type(default_value=SessionManager,
                                 config=True,
                                 help='The session manager class to use.')

    config_manager_class = Type(default_value=ConfigManager,
                                config=True,
                                help='The config manager class to use')

    kernel_spec_manager = Instance(KernelSpecManager, allow_none=True)

    kernel_spec_manager_class = Type(default_value=KernelSpecManager,
                                     config=True,
                                     help="""
        The kernel spec manager class to use. Should be a subclass
        of `jupyter_client.kernelspec.KernelSpecManager`.

        The API of KernelSpecManager is provisional and might change
        without warning between this version of Jupyter and the next stable one.
        """)

    login_handler_class = Type(
        default_value=LoginHandler,
        klass=web.RequestHandler,
        config=True,
        help='The login handler class to use.',
    )

    logout_handler_class = Type(
        default_value=LogoutHandler,
        klass=web.RequestHandler,
        config=True,
        help='The logout handler class to use.',
    )

    trust_xheaders = Bool(
        False,
        config=True,
        help=
        ("Whether to trust or not X-Scheme/X-Forwarded-Proto and X-Real-Ip/X-Forwarded-For headers"
         "sent by the upstream reverse proxy. Necessary if the proxy handles SSL"
         ))

    info_file = Unicode()

    def _info_file_default(self):
        info_file = "nbserver-%s.json" % os.getpid()
        return os.path.join(self.runtime_dir, info_file)

    pylab = Unicode('disabled',
                    config=True,
                    help="""
        DISABLED: use %pylab or %matplotlib in the notebook to enable matplotlib.
        """)

    def _pylab_changed(self, name, old, new):
        """when --pylab is specified, display a warning and exit"""
        if new != 'warn':
            backend = ' %s' % new
        else:
            backend = ''
        self.log.error(
            "Support for specifying --pylab on the command line has been removed."
        )
        self.log.error(
            "Please use `%pylab{0}` or `%matplotlib{0}` in the notebook itself."
            .format(backend))
        self.exit(1)

    notebook_dir = Unicode(
        config=True, help="The directory to use for notebooks and kernels.")

    def _notebook_dir_default(self):
        if self.file_to_run:
            return os.path.dirname(os.path.abspath(self.file_to_run))
        else:
            return py3compat.getcwd()

    def _notebook_dir_validate(self, value, trait):
        # Strip any trailing slashes
        value = value.rstrip(os.sep)

        if not os.path.isabs(value):
            # If we receive a non-absolute path, make it absolute.
            value = os.path.abspath(value)
        if not os.path.isdir(value):
            raise TraitError("No such notebook dir: %r" % value)
        return value

    def _notebook_dir_changed(self, name, old, new):
        """Do a bit of validation of the notebook dir."""
        # setting App.notebook_dir implies setting notebook and kernel dirs as well
        self.config.FileContentsManager.root_dir = new
        self.config.MappingKernelManager.root_dir = new

    server_extensions = List(
        Unicode(),
        config=True,
        help=(
            "Python modules to load as notebook server extensions. "
            "This is an experimental API, and may change in future releases."))

    reraise_server_extension_failures = Bool(
        False,
        config=True,
        help="Reraise exceptions encountered loading server extensions?",
    )

    iopub_msg_rate_limit = Float(0,
                                 config=True,
                                 help="""(msg/sec)
        Maximum rate at which messages can be sent on iopub before they are
        limited.""")

    iopub_data_rate_limit = Float(0,
                                  config=True,
                                  help="""(bytes/sec)
        Maximum rate at which messages can be sent on iopub before they are
        limited.""")

    rate_limit_window = Float(1.0,
                              config=True,
                              help="""(sec) Time window used to 
        check the message and data rate limits.""")

    def parse_command_line(self, argv=None):
        super(NotebookApp, self).parse_command_line(argv)

        if self.extra_args:
            arg0 = self.extra_args[0]
            f = os.path.abspath(arg0)
            self.argv.remove(arg0)
            if not os.path.exists(f):
                self.log.critical("No such file or directory: %s", f)
                self.exit(1)

            # Use config here, to ensure that it takes higher priority than
            # anything that comes from the config dirs.
            c = Config()
            if os.path.isdir(f):
                c.NotebookApp.notebook_dir = f
            elif os.path.isfile(f):
                c.NotebookApp.file_to_run = f
            self.update_config(c)

    def init_configurables(self):
        self.kernel_spec_manager = self.kernel_spec_manager_class(
            parent=self, )
        self.kernel_manager = self.kernel_manager_class(
            parent=self,
            log=self.log,
            connection_dir=self.runtime_dir,
            kernel_spec_manager=self.kernel_spec_manager,
        )
        self.contents_manager = self.contents_manager_class(
            parent=self,
            log=self.log,
        )
        self.session_manager = self.session_manager_class(
            parent=self,
            log=self.log,
            kernel_manager=self.kernel_manager,
            contents_manager=self.contents_manager,
        )
        self.config_manager = self.config_manager_class(
            parent=self,
            log=self.log,
            config_dir=os.path.join(self.config_dir, 'nbconfig'),
        )

    def init_logging(self):
        # This prevents double log messages because tornado uses a root logger that
        # self.log is a child of. The logging module dispatches log messages to a logger
        # and all of its ancestors until propagate is set to False.
        self.log.propagate = False

        for log in app_log, access_log, gen_log:
            # consistent log output name (NotebookApp instead of tornado.access, etc.)
            log.name = self.log.name
        # hook up tornado 3's loggers to our app handlers
        logger = logging.getLogger('tornado')
        logger.propagate = True
        logger.parent = self.log
        logger.setLevel(self.log.level)

    def init_webapp(self):
        """initialize tornado webapp and httpserver"""
        self.tornado_settings['allow_origin'] = self.allow_origin
        if self.allow_origin_pat:
            self.tornado_settings['allow_origin_pat'] = re.compile(
                self.allow_origin_pat)
        self.tornado_settings['allow_credentials'] = self.allow_credentials
        # ensure default_url starts with base_url
        if not self.default_url.startswith(self.base_url):
            self.default_url = url_path_join(self.base_url, self.default_url)

        self.web_app = NotebookWebApplication(
            self, self.kernel_manager, self.contents_manager,
            self.session_manager, self.kernel_spec_manager,
            self.config_manager, self.log, self.base_url, self.default_url,
            self.tornado_settings, self.jinja_environment_options)
        ssl_options = self.ssl_options
        if self.certfile:
            ssl_options['certfile'] = self.certfile
        if self.keyfile:
            ssl_options['keyfile'] = self.keyfile
        if self.client_ca:
            ssl_options['ca_certs'] = self.client_ca
        if not ssl_options:
            # None indicates no SSL config
            ssl_options = None
        else:
            # SSL may be missing, so only import it if it's to be used
            import ssl
            # Disable SSLv3, since its use is discouraged.
            ssl_options['ssl_version'] = ssl.PROTOCOL_TLSv1
            if ssl_options.get('ca_certs', False):
                ssl_options['cert_reqs'] = ssl.CERT_REQUIRED

        self.login_handler_class.validate_security(self,
                                                   ssl_options=ssl_options)
        self.http_server = httpserver.HTTPServer(self.web_app,
                                                 ssl_options=ssl_options,
                                                 xheaders=self.trust_xheaders)

        success = None
        for port in random_ports(self.port, self.port_retries + 1):
            try:
                self.http_server.listen(port, self.ip)
            except socket.error as e:
                if e.errno == errno.EADDRINUSE:
                    self.log.info(
                        'The port %i is already in use, trying another port.' %
                        port)
                    continue
                elif e.errno in (errno.EACCES,
                                 getattr(errno, 'WSAEACCES', errno.EACCES)):
                    self.log.warning("Permission to listen on port %i denied" %
                                     port)
                    continue
                else:
                    raise
            else:
                self.port = port
                success = True
                break
        if not success:
            self.log.critical(
                'ERROR: the notebook server could not be started because '
                'no available port could be found.')
            self.exit(1)

    @property
    def display_url(self):
        ip = self.ip if self.ip else '[all ip addresses on your system]'
        return self._url(ip)

    @property
    def connection_url(self):
        ip = self.ip if self.ip else 'localhost'
        return self._url(ip)

    def _url(self, ip):
        proto = 'https' if self.certfile else 'http'
        return "%s://%s:%i%s" % (proto, ip, self.port, self.base_url)

    def init_terminals(self):
        try:
            from .terminal import initialize
            initialize(self.web_app, self.notebook_dir, self.connection_url)
            self.web_app.settings['terminals_available'] = True
        except ImportError as e:
            log = self.log.debug if sys.platform == 'win32' else self.log.warning
            log("Terminals not available (error was %s)", e)

    def init_signal(self):
        if not sys.platform.startswith('win') and sys.stdin.isatty():
            signal.signal(signal.SIGINT, self._handle_sigint)
        signal.signal(signal.SIGTERM, self._signal_stop)
        if hasattr(signal, 'SIGUSR1'):
            # Windows doesn't support SIGUSR1
            signal.signal(signal.SIGUSR1, self._signal_info)
        if hasattr(signal, 'SIGINFO'):
            # only on BSD-based systems
            signal.signal(signal.SIGINFO, self._signal_info)

    def _handle_sigint(self, sig, frame):
        """SIGINT handler spawns confirmation dialog"""
        # register more forceful signal handler for ^C^C case
        signal.signal(signal.SIGINT, self._signal_stop)
        # request confirmation dialog in bg thread, to avoid
        # blocking the App
        thread = threading.Thread(target=self._confirm_exit)
        thread.daemon = True
        thread.start()

    def _restore_sigint_handler(self):
        """callback for restoring original SIGINT handler"""
        signal.signal(signal.SIGINT, self._handle_sigint)

    def _confirm_exit(self):
        """confirm shutdown on ^C
        
        A second ^C, or answering 'y' within 5s will cause shutdown,
        otherwise original SIGINT handler will be restored.
        
        This doesn't work on Windows.
        """
        info = self.log.info
        info('interrupted')
        print(self.notebook_info())
        sys.stdout.write("Shutdown this notebook server (y/[n])? ")
        sys.stdout.flush()
        r, w, x = select.select([sys.stdin], [], [], 5)
        if r:
            line = sys.stdin.readline()
            if line.lower().startswith('y') and 'n' not in line.lower():
                self.log.critical("Shutdown confirmed")
                ioloop.IOLoop.current().stop()
                return
        else:
            print("No answer for 5s:", end=' ')
        print("resuming operation...")
        # no answer, or answer is no:
        # set it back to original SIGINT handler
        # use IOLoop.add_callback because signal.signal must be called
        # from main thread
        ioloop.IOLoop.current().add_callback(self._restore_sigint_handler)

    def _signal_stop(self, sig, frame):
        self.log.critical("received signal %s, stopping", sig)
        ioloop.IOLoop.current().stop()

    def _signal_info(self, sig, frame):
        print(self.notebook_info())

    def init_components(self):
        """Check the components submodule, and warn if it's unclean"""
        # TODO: this should still check, but now we use bower, not git submodule
        pass

    def init_server_extensions(self):
        """Load any extensions specified by config.

        Import the module, then call the load_jupyter_server_extension function,
        if one exists.
        
        The extension API is experimental, and may change in future releases.
        """
        for modulename in self.server_extensions:
            try:
                mod = importlib.import_module(modulename)
                func = getattr(mod, 'load_jupyter_server_extension', None)
                if func is not None:
                    func(self)
            except Exception:
                if self.reraise_server_extension_failures:
                    raise
                self.log.warning("Error loading server extension %s",
                                 modulename,
                                 exc_info=True)

    @catch_config_error
    def initialize(self, argv=None):
        super(NotebookApp, self).initialize(argv)
        self.init_logging()
        if self._dispatching:
            return
        self.init_configurables()
        self.init_components()
        self.init_webapp()
        self.init_terminals()
        self.init_signal()
        self.init_server_extensions()

    def cleanup_kernels(self):
        """Shutdown all kernels.
        
        The kernels will shutdown themselves when this process no longer exists,
        but explicit shutdown allows the KernelManagers to cleanup the connection files.
        """
        self.log.info('Shutting down kernels')
        self.kernel_manager.shutdown_all()

    def notebook_info(self):
        "Return the current working directory and the server url information"
        info = self.contents_manager.info_string() + "\n"
        info += "%d active kernels \n" % len(self.kernel_manager._kernels)
        return info + "The Jupyter Notebook is running at: %s" % self.display_url

    def server_info(self):
        """Return a JSONable dict of information about this server."""
        return {
            'url': self.connection_url,
            'hostname': self.ip if self.ip else 'localhost',
            'port': self.port,
            'secure': bool(self.certfile),
            'base_url': self.base_url,
            'notebook_dir': os.path.abspath(self.notebook_dir),
            'pid': os.getpid()
        }

    def write_server_info_file(self):
        """Write the result of server_info() to the JSON file info_file."""
        with open(self.info_file, 'w') as f:
            json.dump(self.server_info(), f, indent=2)

    def remove_server_info_file(self):
        """Remove the nbserver-<pid>.json file created for this server.
        
        Ignores the error raised when the file has already been removed.
        """
        try:
            os.unlink(self.info_file)
        except OSError as e:
            if e.errno != errno.ENOENT:
                raise

    def start(self):
        """ Start the Notebook server app, after initialization
        
        This method takes no arguments so all configuration and initialization
        must be done prior to calling this method."""

        if not self.allow_root:
            # check if we are running as root, and abort if it's not allowed
            try:
                uid = os.geteuid()
            except AttributeError:
                uid = -1  # anything nonzero here, since we can't check UID assume non-root
            if uid == 0:
                self.log.critical(
                    "Running as root is not recommended. Use --allow-root to bypass."
                )
                self.exit(1)

        super(NotebookApp, self).start()

        info = self.log.info
        for line in self.notebook_info().split("\n"):
            info(line)
        info(
            "Use Control-C to stop this server and shut down all kernels (twice to skip confirmation)."
        )

        self.write_server_info_file()

        if self.open_browser or self.file_to_run:
            try:
                browser = webbrowser.get(self.browser or None)
            except webbrowser.Error as e:
                self.log.warning('No web browser found: %s.' % e)
                browser = None

            if self.file_to_run:
                if not os.path.exists(self.file_to_run):
                    self.log.critical("%s does not exist" % self.file_to_run)
                    self.exit(1)

                relpath = os.path.relpath(self.file_to_run, self.notebook_dir)
                uri = url_escape(
                    url_path_join('notebooks', *relpath.split(os.sep)))
            else:
                uri = self.default_url
            if browser:
                b = lambda: browser.open(
                    url_path_join(self.connection_url, uri), new=2)
                threading.Thread(target=b).start()

        self.io_loop = ioloop.IOLoop.current()
        if sys.platform.startswith('win'):
            # add no-op to wake every 5s
            # to handle signals that may be ignored by the inner loop
            pc = ioloop.PeriodicCallback(lambda: None, 5000)
            pc.start()
        try:
            self.io_loop.start()
        except KeyboardInterrupt:
            info("Interrupted...")
        finally:
            self.cleanup_kernels()
            self.remove_server_info_file()

    def stop(self):
        def _stop():
            self.http_server.stop()
            self.io_loop.stop()

        self.io_loop.add_callback(_stop)
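
Most of the traits above are meant to be set from a config file rather than in code. A hedged jupyter_notebook_config.py sketch using only traits defined above (values are illustrative):

# jupyter_notebook_config.py
c = get_config()  # injected by the Jupyter config loader
c.NotebookApp.ip = '127.0.0.1'
c.NotebookApp.port = 9999
c.NotebookApp.port_retries = 0
c.NotebookApp.open_browser = False
c.NotebookApp.allow_origin = 'https://notebooks.example.org'  # CORS header
c.NotebookApp.trust_xheaders = True  # only behind a trusted reverse proxy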
Example #17
class KernelManager(ConnectionFileMixin):
    """Manages a single kernel in a subprocess on this host.

    This version starts kernels with Popen.
    """

    # The PyZMQ Context to use for communication with the kernel.
    context = Instance(zmq.Context)

    def _context_default(self):
        return zmq.Context.instance()

    # the class to create with our `client` method
    client_class = DottedObjectName(
        'jupyter_client.blocking.BlockingKernelClient')
    client_factory = Type(klass='jupyter_client.KernelClient')

    def _client_factory_default(self):
        return import_item(self.client_class)

    def _client_class_changed(self, name, old, new):
        self.client_factory = import_item(str(new))

    # The kernel process with which the KernelManager is communicating.
    # generally a Popen instance
    kernel = Any()

    kernel_spec_manager = Instance(kernelspec.KernelSpecManager)

    def _kernel_spec_manager_default(self):
        return kernelspec.KernelSpecManager(data_dir=self.data_dir)

    def _kernel_spec_manager_changed(self):
        self._kernel_spec = None

    kernel_name = Unicode(kernelspec.NATIVE_KERNEL_NAME)

    def _kernel_name_changed(self, name, old, new):
        self._kernel_spec = None
        if new == 'python':
            self.kernel_name = kernelspec.NATIVE_KERNEL_NAME

    _kernel_spec = None

    @property
    def kernel_spec(self):
        if self._kernel_spec is None:
            self._kernel_spec = self.kernel_spec_manager.get_kernel_spec(
                self.kernel_name)
        return self._kernel_spec

    kernel_cmd = List(Unicode(),
                      config=True,
                      help="""DEPRECATED: Use kernel_name instead.

        The Popen Command to launch the kernel.
        Override this if you have a custom kernel.
        If kernel_cmd is specified in a configuration file,
        Jupyter does not pass any arguments to the kernel,
        because it cannot make any assumptions about the
        arguments that the kernel understands. In particular,
        this means that the kernel does not receive the
        option --debug if it is given on the Jupyter command line.
        """)

    def _kernel_cmd_changed(self, name, old, new):
        warnings.warn("Setting kernel_cmd is deprecated, use kernel_spec to "
                      "start different kernels.")

    @property
    def ipykernel(self):
        return self.kernel_name in {'python', 'python2', 'python3'}

    # Protected traits
    _launch_args = Any()
    _control_socket = Any()

    _restarter = Any()

    autorestart = Bool(True,
                       config=True,
                       help="""Should we autorestart the kernel if it dies.""")

    def __del__(self):
        self._close_control_socket()
        self.cleanup_connection_file()

    #--------------------------------------------------------------------------
    # Kernel restarter
    #--------------------------------------------------------------------------

    def start_restarter(self):
        pass

    def stop_restarter(self):
        pass

    def add_restart_callback(self, callback, event='restart'):
        """register a callback to be called when a kernel is restarted"""
        if self._restarter is None:
            return
        self._restarter.add_callback(callback, event)

    def remove_restart_callback(self, callback, event='restart'):
        """unregister a callback to be called when a kernel is restarted"""
        if self._restarter is None:
            return
        self._restarter.remove_callback(callback, event)

    #--------------------------------------------------------------------------
    # create a Client connected to our Kernel
    #--------------------------------------------------------------------------

    def client(self, **kwargs):
        """Create a client configured to connect to our kernel"""
        kw = {}
        kw.update(self.get_connection_info())
        kw.update(
            dict(
                connection_file=self.connection_file,
                session=self.session,
                parent=self,
            ))

        # add kwargs last, for manual overrides
        kw.update(kwargs)
        return self.client_factory(**kw)

    #--------------------------------------------------------------------------
    # Kernel management
    #--------------------------------------------------------------------------

    def format_kernel_cmd(self, extra_arguments=None):
        """replace templated args (e.g. {connection_file})"""
        extra_arguments = extra_arguments or []
        if self.kernel_cmd:
            cmd = self.kernel_cmd + extra_arguments
        else:
            cmd = self.kernel_spec.argv + extra_arguments

        ns = dict(
            connection_file=self.connection_file,
            prefix=sys.prefix,
        )
        ns.update(self._launch_args)

        pat = re.compile(r'\{([A-Za-z0-9_]+)\}')

        def from_ns(match):
            """Get the key out of ns if it's there, otherwise no change."""
            return ns.get(match.group(1), match.group())

        return [pat.sub(from_ns, arg) for arg in cmd]

    def _launch_kernel(self, kernel_cmd, **kw):
        """actually launch the kernel

        override in a subclass to launch kernel subprocesses differently
        """
        return launch_kernel(kernel_cmd, **kw)

    # Control socket used for polite kernel shutdown

    def _connect_control_socket(self):
        if self._control_socket is None:
            self._control_socket = self.connect_control()
            self._control_socket.linger = 100

    def _close_control_socket(self):
        if self._control_socket is None:
            return
        self._control_socket.close()
        self._control_socket = None

    def start_kernel(self, **kw):
        """Starts a kernel on this host in a separate process.

        If random ports (port=0) are being used, this method must be called
        before the channels are created.

        Parameters
        ----------
        `**kw` : optional
             keyword arguments that are passed down to build the kernel_cmd
             and launching the kernel (e.g. Popen kwargs).
        """
        if self.transport == 'tcp' and not is_local_ip(self.ip):
            raise RuntimeError(
                "Can only launch a kernel on a local interface. "
                "Make sure that the '*_address' attributes are "
                "configured properly. "
                "Currently valid addresses are: %s" % local_ips())

        # write connection file / get default ports
        self.write_connection_file()

        # save kwargs for use in restart
        self._launch_args = kw.copy()
        # build the Popen cmd
        extra_arguments = kw.pop('extra_arguments', [])
        kernel_cmd = self.format_kernel_cmd(extra_arguments=extra_arguments)
        env = os.environ.copy()
        # Don't allow PYTHONEXECUTABLE to be passed to kernel process.
        # If set, it can bork all the things.
        env.pop('PYTHONEXECUTABLE', None)
        if not self.kernel_cmd:
            # If kernel_cmd has been set manually, don't refer to a kernel spec
            # Environment variables from kernel spec are added to os.environ
            env.update(self.kernel_spec.env or {})

        # launch the kernel subprocess
        self.log.debug("Starting kernel: %s", kernel_cmd)
        self.kernel = self._launch_kernel(kernel_cmd, env=env, **kw)
        self.start_restarter()
        self._connect_control_socket()

    def request_shutdown(self, restart=False):
        """Send a shutdown request via control channel
        """
        content = dict(restart=restart)
        msg = self.session.msg("shutdown_request", content=content)
        self.session.send(self._control_socket, msg)

    def finish_shutdown(self, waittime=1, pollinterval=0.1):
        """Wait for kernel shutdown, then kill process if it doesn't shutdown.

        This does not send shutdown requests - use :meth:`request_shutdown`
        first.
        """
        for i in range(int(waittime / pollinterval)):
            if self.is_alive():
                time.sleep(pollinterval)
            else:
                break
        else:
            # OK, we've waited long enough.
            if self.has_kernel:
                self._kill_kernel()

    def cleanup(self, connection_file=True):
        """Clean up resources when the kernel is shut down"""
        if connection_file:
            self.cleanup_connection_file()

        self.cleanup_ipc_files()
        self._close_control_socket()

    def shutdown_kernel(self, now=False, restart=False):
        """Attempts to the stop the kernel process cleanly.

        This attempts to shutdown the kernels cleanly by:

        1. Sending it a shutdown message over the shell channel.
        2. If that fails, the kernel is shutdown forcibly by sending it
           a signal.

        Parameters
        ----------
        now : bool
            Should the kernel be forcibly killed *now*? This skips the
            first, nice shutdown attempt.
        restart : bool
            Will this kernel be restarted after it is shut down? When this
            is True, connection files will not be cleaned up.
        """
        # Stop monitoring for restarting while we shutdown.
        self.stop_restarter()

        if now:
            self._kill_kernel()
        else:
            self.request_shutdown(restart=restart)
            # Don't send any additional kernel kill messages immediately, to give
            # the kernel a chance to properly execute shutdown actions. Wait for at
            # most 1s, checking every 0.1s.
            self.finish_shutdown()

        self.cleanup(connection_file=not restart)

    def restart_kernel(self, now=False, **kw):
        """Restarts a kernel with the arguments that were used to launch it.

        If the old kernel was launched with random ports, the same ports will be
        used for the new kernel. The same connection file is used again.

        Parameters
        ----------
        now : bool, optional
            If True, the kernel is forcefully restarted *immediately*, without
            having a chance to do any cleanup action.  Otherwise the kernel is
            given 1s to clean up before a forceful restart is issued.

            In all cases the kernel is restarted, the only difference is whether
            it is given a chance to perform a clean shutdown or not.

        `**kw` : optional
            Any options specified here will overwrite those used to launch the
            kernel.
        """
        if self._launch_args is None:
            raise RuntimeError("Cannot restart the kernel. "
                               "No previous call to 'start_kernel'.")
        else:
            # Stop currently running kernel.
            self.shutdown_kernel(now=now, restart=True)

            # Start new kernel.
            self._launch_args.update(kw)
            self.start_kernel(**self._launch_args)

    @property
    def has_kernel(self):
        """Has a kernel been started that we are managing."""
        return self.kernel is not None

    def _kill_kernel(self):
        """Kill the running kernel.

        This is a private method, callers should use shutdown_kernel(now=True).
        """
        if self.has_kernel:

            # Signal the kernel to terminate (sends SIGKILL on Unix and calls
            # TerminateProcess() on Win32).
            try:
                self.kernel.kill()
            except OSError as e:
                # In Windows, we will get an Access Denied error if the process
                # has already terminated. Ignore it.
                if sys.platform == 'win32':
                    if e.winerror != 5:
                        raise
                # On Unix, we may get an ESRCH error if the process has already
                # terminated. Ignore it.
                else:
                    from errno import ESRCH
                    if e.errno != ESRCH:
                        raise

            # Block until the kernel terminates.
            self.kernel.wait()
            self.kernel = None
        else:
            raise RuntimeError("Cannot kill kernel. No kernel is running!")

    def interrupt_kernel(self):
        """Interrupts the kernel by sending it a signal.

        Unlike ``signal_kernel``, this operation is well supported on all
        platforms.
        """
        if self.has_kernel:
            if sys.platform == 'win32':
                from .win_interrupt import send_interrupt
                send_interrupt(self.kernel.win32_interrupt_event)
            else:
                self.signal_kernel(signal.SIGINT)
        else:
            raise RuntimeError(
                "Cannot interrupt kernel. No kernel is running!")

    def signal_kernel(self, signum):
        """Sends a signal to the process group of the kernel (this
        usually includes the kernel and any subprocesses spawned by
        the kernel).

        Note that since only SIGTERM is supported on Windows, this function is
        only useful on Unix systems.
        """
        if self.has_kernel:
            if hasattr(os, "getpgid") and hasattr(os, "killpg"):
                try:
                    pgid = os.getpgid(self.kernel.pid)
                    os.killpg(pgid, signum)
                    return
                except OSError:
                    pass
            self.kernel.send_signal(signum)
        else:
            raise RuntimeError("Cannot signal kernel. No kernel is running!")

    def is_alive(self):
        """Is the kernel process still running?"""
        if self.has_kernel:
            if self.kernel.poll() is None:
                return True
            else:
                return False
        else:
            # we don't have a kernel
            return False
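
A minimal blocking round trip with this manager (assumes an installed 'python3' kernel spec; error handling omitted):

km = KernelManager(kernel_name='python3')
km.start_kernel()
kc = km.client()                 # BlockingKernelClient via client_factory
kc.start_channels()
kc.wait_for_ready(timeout=60)    # block until the kernel answers kernel_info
msg_id = kc.execute('1 + 1')
reply = kc.get_shell_msg(timeout=60)
print(reply['content']['status'])  # 'ok' on success
kc.stop_channels()
km.shutdown_kernel(now=False)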
Example #18
class JupyterConsoleApp(ConnectionFileMixin):
    name = 'jupyter-console-mixin'

    description = """
        The Jupyter Console Mixin.
        
        This class contains the common portions of console clients (QtConsole,
        ZMQ-based terminal console, etc.).  It is not a full console, in that
        launched terminal subprocesses will not be able to accept input.

        The console using this mixin supports various extra features beyond
        the single-process terminal IPython shell, such as connecting to an
        existing kernel, via:

            jupyter console <appname> --existing

        as well as tunneling via SSH.
    """

    classes = classes
    flags = Dict(flags)
    aliases = Dict(aliases)
    kernel_manager_class = Type(default_value=KernelManager,
                                config=True,
                                help='The kernel manager class to use.')
    kernel_client_class = BlockingKernelClient

    kernel_argv = List(Unicode())

    # connection info:

    sshserver = Unicode(
        '',
        config=True,
        help="""The SSH server to use to connect to the kernel.""")
    sshkey = Unicode(
        '',
        config=True,
        help="""Path to the ssh key to use for logging in to the ssh server."""
    )

    def _connection_file_default(self):
        return 'kernel-%i.json' % os.getpid()

    existing = CUnicode('',
                        config=True,
                        help="""Connect to an already running kernel""")

    kernel_name = Unicode('python',
                          config=True,
                          help="""The name of the default kernel to start.""")

    confirm_exit = CBool(
        True,
        config=True,
        help="""
        Set to display confirmation dialog on exit. You can always use 'exit' or 'quit'
        to force a direct exit without any confirmation.""",
    )

    def build_kernel_argv(self, argv=None):
        """build argv to be passed to kernel subprocess
        
        Override in subclasses if any args should be passed to the kernel
        """
        self.kernel_argv = self.extra_args

    def init_connection_file(self):
        """find the connection file, and load the info if found.
        
        The current working directory and the current profile's security
        directory will be searched for the file if it is not given by
        absolute path.
        
        When attempting to connect to an existing kernel and the `--existing`
        argument does not match an existing file, it will be interpreted as a
        fileglob, and the matching file in the current profile's security dir
        with the latest access time will be used.
        
        After this method is called, self.connection_file contains the *full path*
        to the connection file, never just its name.
        """
        if self.existing:
            try:
                cf = find_connection_file(self.existing,
                                          ['.', self.runtime_dir])
            except Exception:
                self.log.critical(
                    "Could not find existing kernel connection file %s",
                    self.existing)
                self.exit(1)
            self.log.debug("Connecting to existing kernel: %s" % cf)
            self.connection_file = cf
        else:
            # not existing, check if we are going to write the file
            # and ensure that self.connection_file is a full path, not just the shortname
            try:
                cf = find_connection_file(self.connection_file,
                                          [self.runtime_dir])
            except Exception:
                # file might not exist
                if self.connection_file == os.path.basename(
                        self.connection_file):
                    # just shortname, put it in security dir
                    cf = os.path.join(self.runtime_dir, self.connection_file)
                else:
                    cf = self.connection_file
                self.connection_file = cf
        try:
            self.connection_file = _filefind(self.connection_file,
                                             [".", self.runtime_dir])
        except IOError:
            self.log.debug("Connection File not found: %s",
                           self.connection_file)
            return

        # should load_connection_file only be used for existing?
        # as it is now, this allows reusing ports if an existing
        # file is requested
        try:
            self.load_connection_file()
        except Exception:
            self.log.error("Failed to load connection file: %r",
                           self.connection_file,
                           exc_info=True)
            self.exit(1)

    def init_ssh(self):
        """set up ssh tunnels, if needed."""
        if not self.existing or (not self.sshserver and not self.sshkey):
            return
        self.load_connection_file()

        transport = self.transport
        ip = self.ip

        if transport != 'tcp':
            self.log.error("Can only use ssh tunnels with TCP sockets, not %s",
                           transport)
            sys.exit(-1)

        if self.sshkey and not self.sshserver:
            # specifying just the key implies that we are connecting directly
            self.sshserver = ip
            ip = localhost()

        # build connection dict for tunnels:
        info = dict(ip=ip,
                    shell_port=self.shell_port,
                    iopub_port=self.iopub_port,
                    stdin_port=self.stdin_port,
                    hb_port=self.hb_port,
                    control_port=self.control_port)

        self.log.info("Forwarding connections to %s via %s" %
                      (ip, self.sshserver))

        # tunnels return a new set of ports, which will be on localhost:
        self.ip = localhost()
        try:
            newports = tunnel_to_kernel(info, self.sshserver, self.sshkey)
        except BaseException:
            # even catch KeyboardInterrupt
            self.log.error("Could not setup tunnels", exc_info=True)
            self.exit(1)

        (self.shell_port, self.iopub_port, self.stdin_port,
         self.hb_port, self.control_port) = newports

        cf = self.connection_file
        root, ext = os.path.splitext(cf)
        self.connection_file = root + '-ssh' + ext
        self.write_connection_file()  # write the new connection file
        self.log.info("To connect another client via this tunnel, use:")
        self.log.info("--existing %s" % os.path.basename(self.connection_file))

    def _new_connection_file(self):
        cf = ''
        while not cf:
            # we don't need a 128b id to distinguish kernels, use more readable
            # 48b node segment (12 hex chars).  Users running more than 32k simultaneous
            # kernels can subclass.
            ident = str(uuid.uuid4()).split('-')[-1]
            cf = os.path.join(self.runtime_dir, 'kernel-%s.json' % ident)
            # only keep if it's actually new.  Protect against unlikely collision
            # in 48b random search space
            cf = cf if not os.path.exists(cf) else ''
        return cf
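
    # A hedged illustration of the naming scheme above: the last segment of a
    # uuid4 is 12 hex characters (48 bits), e.g.
    #     str(uuid.uuid4()) -> '....-....-....-....-9f3c2a1b7d42'
    # which yields a connection file such as 'kernel-9f3c2a1b7d42.json'
    # in the runtime dir.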

    def init_kernel_manager(self):
        if self.existing:
            self.kernel_manager = None
            return
        # Don't let Qt or ZMQ swallow KeyboardInterrupts.
        signal.signal(signal.SIGINT, signal.SIG_DFL)

        # Create a KernelManager and start a kernel.
        try:
            self.kernel_manager = self.kernel_manager_class(
                ip=self.ip,
                session=self.session,
                transport=self.transport,
                shell_port=self.shell_port,
                iopub_port=self.iopub_port,
                stdin_port=self.stdin_port,
                hb_port=self.hb_port,
                control_port=self.control_port,
                connection_file=self.connection_file,
                kernel_name=self.kernel_name,
                parent=self,
                data_dir=self.data_dir,
            )
        except NoSuchKernel:
            self.log.critical("Could not find kernel %s", self.kernel_name)
            self.exit(1)

        self.kernel_manager.client_factory = self.kernel_client_class
        kwargs = {}
        kwargs['extra_arguments'] = self.kernel_argv
        self.kernel_manager.start_kernel(**kwargs)
        atexit.register(self.kernel_manager.cleanup_ipc_files)

        if self.sshserver:
            # ssh, write new connection file
            self.kernel_manager.write_connection_file()

        # in case KM defaults / ssh writing changes things:
        km = self.kernel_manager
        self.shell_port = km.shell_port
        self.iopub_port = km.iopub_port
        self.stdin_port = km.stdin_port
        self.hb_port = km.hb_port
        self.control_port = km.control_port
        self.connection_file = km.connection_file

        atexit.register(self.kernel_manager.cleanup_connection_file)

    def init_kernel_client(self):
        if self.kernel_manager is not None:
            self.kernel_client = self.kernel_manager.client()
        else:
            self.kernel_client = self.kernel_client_class(
                session=self.session,
                ip=self.ip,
                transport=self.transport,
                shell_port=self.shell_port,
                iopub_port=self.iopub_port,
                stdin_port=self.stdin_port,
                hb_port=self.hb_port,
                control_port=self.control_port,
                connection_file=self.connection_file,
                parent=self,
            )

        self.kernel_client.start_channels()

    def initialize(self, argv=None):
        """
        Classes which mix this class in should call:
               JupyterConsoleApp.initialize(self,argv)
        """
        if self._dispatching:
            return
        self.init_connection_file()
        self.init_ssh()
        self.init_kernel_manager()
        self.init_kernel_client()
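
A hedged sketch of how a concrete application might consume this mixin,
following the initialize() docstring above; MyConsoleApp and its start()
body are hypothetical:

from jupyter_core.application import JupyterApp

class MyConsoleApp(JupyterApp, JupyterConsoleApp):
    name = 'my-console'

    def initialize(self, argv=None):
        super().initialize(argv)                  # JupyterApp setup
        JupyterConsoleApp.initialize(self, argv)  # connection file, SSH, kernel

    def start(self):
        # hypothetical one-shot client: run a single statement on the kernel
        self.kernel_client.execute("print('hello from the kernel')")

# MyConsoleApp.launch_instance()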
Example #19
class WrapSpawner(Spawner):

    # Grab this from constructor args in case some Spawner ever wants it
    config = Any()

    child_class = Type(
        KubeSpawner,
        Spawner,
        config=True,
        help="""The class to wrap for spawning single-user servers.
                Should be a subclass of Spawner.
                """)

    child_config = Dict(
        default_value={},
        config=True,
        help="Dictionary of config values to apply to wrapped spawner class.")

    child_state = Dict(default_value={})

    child_spawner = Instance(Spawner, allow_none=True)

    def construct_child(self):
        if self.child_spawner is None:
            self.child_spawner = self.child_class(
                user=self.user,
                db=self.db,
                hub=self.hub,
                authenticator=self.authenticator,
                config=self.config,
                **self.child_config)
            # initial state will always be wrong since it will see *our* state
            self.child_spawner.clear_state()
            if self.child_state:
                self.child_spawner.load_state(self.child_state)
            self.child_spawner.api_token = self.api_token
        return self.child_spawner

    def load_child_class(self, state):
        # Subclasses must arrange for correct child_class setting from load_state
        pass

    def load_state(self, state):
        super().load_state(state)
        self.load_child_class(state)
        self.child_config.update(state.get('child_conf', {}))
        self.child_state = state.get('child_state', {})
        self.construct_child()

    def get_state(self):
        state = super().get_state()
        state['child_conf'] = self.child_config
        if self.child_spawner:
            state['child_state'] = self.child_state = self.child_spawner.get_state()
        return state

    def clear_state(self):
        super().clear_state()
        if self.child_spawner:
            self.child_spawner.clear_state()
        self.child_state = {}
        self.child_spawner = None

    # proxy functions for start/poll/stop
    # pass back the child's Future, or create a dummy if needed

    def start(self):
        if not self.child_spawner:
            self.construct_child()
        return self.child_spawner.start()

    def stop(self, now=False):
        if self.child_spawner:
            return self.child_spawner.stop(now)
        else:
            return _yield_val()

    def poll(self):
        if self.child_spawner:
            return self.child_spawner.poll()
        else:
            return _yield_val(1)
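
A minimal jupyterhub_config.py sketch for the wrapper above, assuming it is
importable as wrapspawner.WrapSpawner; the wrapped class and the config
values passed through are illustrative:

c.JupyterHub.spawner_class = 'wrapspawner.WrapSpawner'
# the spawner class to wrap, and kwargs forwarded to its constructor
c.WrapSpawner.child_class = 'jupyterhub.spawner.LocalProcessSpawner'
c.WrapSpawner.child_config = {'debug': True}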
Example #20
class JupyterHub(Application):
    """An Application for starting a Multi-User Jupyter Notebook server."""
    name = 'jupyterhub'
    version = jupyterhub.__version__

    description = """Start a multi-user Jupyter Notebook server

    Spawns a configurable-http-proxy and multi-user Hub,
    which authenticates users and spawns single-user Notebook servers
    on behalf of users.
    """

    examples = """

    generate default config file:

        jupyterhub --generate-config -f /etc/jupyterhub/jupyterhub.py

    spawn the server on 10.0.1.2:443 with https:

        jupyterhub --ip 10.0.1.2 --port 443 --ssl-key my_ssl.key --ssl-cert my_ssl.cert
    """

    aliases = Dict(aliases)
    flags = Dict(flags)

    subcommands = {
        'token': (NewToken, "Generate an API token for a user"),
        'upgrade-db':
        (UpgradeDB,
         "Upgrade your JupyterHub state database to the current version."),
    }

    classes = List([
        Spawner,
        LocalProcessSpawner,
        Authenticator,
        PAMAuthenticator,
    ])

    load_groups = Dict(
        List(Unicode()),
        help="""Dict of 'group': ['usernames'] to load at startup.
        
        This strictly *adds* groups and users to groups.
        
        Loading one set of groups, then starting JupyterHub again with a different
        set will not remove users or groups from previous launches.
        That must be done through the API.
        """).tag(config=True)

    config_file = Unicode(
        'jupyterhub_config.py',
        help="The config file to load",
    ).tag(config=True)
    generate_config = Bool(
        False,
        help="Generate default config file",
    ).tag(config=True)
    answer_yes = Bool(
        False,
        help="Answer yes to any questions (e.g. confirm overwrite)").tag(
            config=True)
    pid_file = Unicode('',
                       help="""File to write PID
        Useful for daemonizing jupyterhub.
        """).tag(config=True)
    cookie_max_age_days = Float(
        14,
        help="""Number of days for a login cookie to be valid.
        Default is two weeks.
        """).tag(config=True)
    last_activity_interval = Integer(
        300,
        help="Interval (in seconds) at which to update last-activity timestamps."
    ).tag(config=True)
    proxy_check_interval = Integer(
        30,
        help="Interval (in seconds) at which to check if the proxy is running."
    ).tag(config=True)

    data_files_path = Unicode(
        DATA_FILES_PATH,
        help=
        "The location of jupyterhub data files (e.g. /usr/local/share/jupyter/hub)"
    ).tag(config=True)

    template_paths = List(
        help="Paths to search for jinja templates.", ).tag(config=True)

    @default('template_paths')
    def _template_paths_default(self):
        return [os.path.join(self.data_files_path, 'templates')]

    confirm_no_ssl = Bool(
        False,
        help="""Confirm that JupyterHub should be run without SSL.
        This is **NOT RECOMMENDED** unless SSL termination is being handled by another layer.
        """).tag(config=True)
    ssl_key = Unicode(
        '',
        help="""Path to SSL key file for the public facing interface of the proxy

        Use with ssl_cert
        """).tag(config=True)
    ssl_cert = Unicode(
        '',
        help=
        """Path to SSL certificate file for the public facing interface of the proxy

        Use with ssl_key
        """).tag(config=True)
    ip = Unicode(
        '',
        help="The public facing ip of the whole application (the proxy)").tag(
            config=True)

    subdomain_host = Unicode(
        '',
        help="""Run single-user servers on subdomains of this host.

        This should be the full https://hub.domain.tld[:port]

        Provides additional cross-site protections for javascript served by single-user servers.

        Requires <username>.hub.domain.tld to resolve to the same host as hub.domain.tld.

        In general, this is most easily achieved with wildcard DNS.

        When using SSL (i.e. always) this also requires a wildcard SSL certificate.
        """).tag(config=True)

    def _subdomain_host_changed(self, name, old, new):
        if new and '://' not in new:
            # host should include '://'
            # if not specified, assume https: You have to be really explicit about HTTP!
            self.subdomain_host = 'https://' + new

    port = Integer(8000,
                   help="The public facing port of the proxy").tag(config=True)
    base_url = URLPrefix(
        '/', help="The base URL of the entire application").tag(config=True)
    logo_file = Unicode(
        '',
        help=
        "Specify path to a logo image to override the Jupyter logo in the banner."
    ).tag(config=True)

    @default('logo_file')
    def _logo_file_default(self):
        return os.path.join(self.data_files_path, 'static', 'images',
                            'jupyter.png')

    jinja_environment_options = Dict(
        help="Supply extra arguments that will be passed to Jinja environment."
    ).tag(config=True)

    proxy_cmd = Command('configurable-http-proxy',
                        help="""The command to start the http proxy.

        Only override if configurable-http-proxy is not on your PATH
        """).tag(config=True)
    debug_proxy = Bool(
        False,
        help="show debug output in configurable-http-proxy").tag(config=True)
    proxy_auth_token = Unicode(help="""The Proxy Auth token.

        Loaded from the CONFIGPROXY_AUTH_TOKEN env variable by default.
        """).tag(config=True)

    @default('proxy_auth_token')
    def _proxy_auth_token_default(self):
        token = os.environ.get('CONFIGPROXY_AUTH_TOKEN', None)
        if not token:
            self.log.warning('\n'.join([
                "",
                "Generating CONFIGPROXY_AUTH_TOKEN. Restarting the Hub will require restarting the proxy.",
                "Set CONFIGPROXY_AUTH_TOKEN env or JupyterHub.proxy_auth_token config to avoid this message.",
                "",
            ]))
            token = orm.new_token()
        return token

    proxy_api_ip = Unicode(
        '127.0.0.1', help="The ip for the proxy API handlers").tag(config=True)
    proxy_api_port = Integer(help="The port for the proxy API handlers").tag(
        config=True)

    @default('proxy_api_port')
    def _proxy_api_port_default(self):
        return self.port + 1

    hub_port = Integer(8081, help="The port for this process").tag(config=True)
    hub_ip = Unicode('127.0.0.1',
                     help="The ip for this process").tag(config=True)
    hub_prefix = URLPrefix(
        '/hub/', help="The prefix for the hub server.  Always /base_url/hub/")

    @default('hub_prefix')
    def _hub_prefix_default(self):
        return url_path_join(self.base_url, '/hub/')

    @observe('base_url')
    def _update_hub_prefix(self, change):
        """add base URL to hub prefix"""
        base_url = change['new']
        self.hub_prefix = self._hub_prefix_default()

    cookie_secret = Bytes(help="""The cookie secret to use to encrypt cookies.

        Loaded from the JPY_COOKIE_SECRET env variable by default.
        """).tag(
        config=True,
        env='JPY_COOKIE_SECRET',
    )

    cookie_secret_file = Unicode(
        'jupyterhub_cookie_secret',
        help="""File in which to store the cookie secret.""").tag(config=True)

    api_tokens = Dict(
        Unicode(),
        help="""Dict of token:username to be loaded into the database.

        Allows ahead-of-time generation of API tokens for use by services.
        """).tag(config=True)

    authenticator_class = Type(PAMAuthenticator,
                               Authenticator,
                               help="""Class for authenticating users.

        This should be a class with the following form:

        - constructor takes one kwarg: `config`, the IPython config object.

        and an `authenticate` method that:

        - is a tornado.gen.coroutine
        - returns username on success, None on failure
        - takes two arguments: (handler, data),
          where `handler` is the calling web.RequestHandler,
          and `data` is the POST form data from the login page.
        """).tag(config=True)

    authenticator = Instance(Authenticator)

    @default('authenticator')
    def _authenticator_default(self):
        return self.authenticator_class(parent=self, db=self.db)

    # class for spawning single-user servers
    spawner_class = Type(
        LocalProcessSpawner,
        Spawner,
        help="""The class to use for spawning single-user servers.

        Should be a subclass of Spawner.
        """).tag(config=True)

    db_url = Unicode(
        'sqlite:///jupyterhub.sqlite',
        help="url for the database. e.g. `sqlite:///jupyterhub.sqlite`").tag(
            config=True)

    @observe('db_url')
    def _db_url_changed(self, change):
        new = change['new']
        if '://' not in new:
            # assume sqlite, if given as a plain filename
            self.db_url = 'sqlite:///%s' % new
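
    # Illustrative effect of the observer above:
    #     c.JupyterHub.db_url = 'jupyterhub.sqlite'
    # is rewritten to 'sqlite:///jupyterhub.sqlite'.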

    db_kwargs = Dict(
        help="""Include any kwargs to pass to the database connection.
        See sqlalchemy.create_engine for details.
        """).tag(config=True)

    reset_db = Bool(False,
                    help="Purge and reset the database.").tag(config=True)
    debug_db = Bool(
        False,
        help="log all database transactions. This has A LOT of output").tag(
            config=True)
    session_factory = Any()

    users = Instance(UserDict)

    @default('users')
    def _users_default(self):
        assert self.tornado_settings
        return UserDict(db_factory=lambda: self.db,
                        settings=self.tornado_settings)

    admin_access = Bool(
        False,
        help="""Grant admin users permission to access single-user servers.

        Users should be properly informed if this is enabled.
        """).tag(config=True)
    admin_users = Set(
        help="""DEPRECATED, use Authenticator.admin_users instead.""").tag(
            config=True)

    tornado_settings = Dict(
        help="Extra settings overrides to pass to the tornado application."
    ).tag(config=True)

    cleanup_servers = Bool(
        True,
        help="""Whether to shutdown single-user servers when the Hub shuts down.

        Disable if you want to be able to teardown the Hub while leaving the single-user servers running.

        If both this and cleanup_proxy are False, sending SIGINT to the Hub will
        only shutdown the Hub, leaving everything else running.

        The Hub should be able to resume from database state.
        """).tag(config=True)

    cleanup_proxy = Bool(
        True,
        help="""Whether to shutdown the proxy when the Hub shuts down.

        Disable if you want to be able to teardown the Hub while leaving the proxy running.

        Only valid if the proxy was started by the Hub process.

        If both this and cleanup_servers are False, sending SIGINT to the Hub will
        only shutdown the Hub, leaving everything else running.

        The Hub should be able to resume from database state.
        """).tag(config=True)

    statsd_host = Unicode(help="Host to send statsd metrics to").tag(
        config=True)

    statsd_port = Integer(
        8125, help="Port on which to send statsd metrics about the hub").tag(
            config=True)

    statsd_prefix = Unicode(
        'jupyterhub',
        help="Prefix to use for all metrics sent by jupyterhub to statsd").tag(
            config=True)

    handlers = List()

    _log_formatter_cls = CoroutineLogFormatter
    http_server = None
    proxy_process = None
    io_loop = None

    @default('log_level')
    def _log_level_default(self):
        return logging.INFO

    @default('log_datefmt')
    def _log_datefmt_default(self):
        """Include the date in the default date format"""
        return "%Y-%m-%d %H:%M:%S"

    @default('log_format')
    def _log_format_default(self):
        """override default log format to include time"""
        return "%(color)s[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s %(module)s:%(lineno)d]%(end_color)s %(message)s"

    extra_log_file = Unicode(help="""Send JupyterHub's logs to this file.

        This will *only* include the logs of the Hub itself,
        not the logs of the proxy or any single-user servers.
        """).tag(config=True)
    extra_log_handlers = List(
        Instance(logging.Handler),
        help="Extra log handlers to set on JupyterHub logger",
    ).tag(config=True)

    statsd = Any(
        allow_none=False,
        help=
        "The statsd client, if any. A mock will be used if we aren't using statsd"
    )

    @default('statsd')
    def _statsd(self):
        if self.statsd_host:
            import statsd
            client = statsd.StatsClient(self.statsd_host, self.statsd_port,
                                        self.statsd_prefix)
            return client
        else:
            # return an empty mock object!
            return EmptyClass()

    def init_logging(self):
        # This prevents double log messages because tornado uses a root logger that
        # self.log is a child of. The logging module dispatches log messages to a log
        # and all of its ancestors until propagate is set to False.
        self.log.propagate = False

        if self.extra_log_file:
            self.extra_log_handlers.append(
                logging.FileHandler(self.extra_log_file))

        _formatter = self._log_formatter_cls(
            fmt=self.log_format,
            datefmt=self.log_datefmt,
        )
        for handler in self.extra_log_handlers:
            if handler.formatter is None:
                handler.setFormatter(_formatter)
            self.log.addHandler(handler)

        # hook up tornado 3's loggers to our app handlers
        for log in (app_log, access_log, gen_log):
            # ensure all log statements identify the application they come from
            log.name = self.log.name
        logger = logging.getLogger('tornado')
        logger.propagate = True
        logger.parent = self.log
        logger.setLevel(self.log.level)

    def init_ports(self):
        if self.hub_port == self.port:
            raise TraitError(
                "The hub and proxy cannot both listen on port %i" % self.port)
        if self.hub_port == self.proxy_api_port:
            raise TraitError(
                "The hub and proxy API cannot both listen on port %i" %
                self.hub_port)
        if self.proxy_api_port == self.port:
            raise TraitError(
                "The proxy's public and API ports cannot both be %i" %
                self.port)

    @staticmethod
    def add_url_prefix(prefix, handlers):
        """add a url prefix to handlers"""
        for i, tup in enumerate(handlers):
            lis = list(tup)
            lis[0] = url_path_join(prefix, tup[0])
            handlers[i] = tuple(lis)
        return handlers
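
    # A hedged illustration (the handler class is a placeholder):
    #     add_url_prefix('/hub/', [(r'/login', LoginHandler)])
    #     -> [('/hub/login', LoginHandler)]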

    def init_handlers(self):
        h = []
        # load handlers from the authenticator
        h.extend(self.authenticator.get_handlers(self))
        # set default handlers
        h.extend(handlers.default_handlers)
        h.extend(apihandlers.default_handlers)

        h.append((r'/logo', LogoHandler, {'path': self.logo_file}))
        self.handlers = self.add_url_prefix(self.hub_prefix, h)
        # some extra handlers, outside hub_prefix
        self.handlers.extend([
            (r"%s" % self.hub_prefix.rstrip('/'), web.RedirectHandler, {
                "url": self.hub_prefix,
                "permanent": False,
            }),
            (r"(?!%s).*" % self.hub_prefix, handlers.PrefixRedirectHandler),
            (r'(.*)', handlers.Template404),
        ])

    def _check_db_path(self, path):
        """More informative log messages for failed filesystem access"""
        path = os.path.abspath(path)
        parent, fname = os.path.split(path)
        user = getuser()
        if not os.path.isdir(parent):
            self.log.error("Directory %s does not exist", parent)
        if os.path.exists(parent) and not os.access(parent, os.W_OK):
            self.log.error("%s cannot create files in %s", user, parent)
        if os.path.exists(path) and not os.access(path, os.W_OK):
            self.log.error("%s cannot edit %s", user, path)

    def init_secrets(self):
        trait_name = 'cookie_secret'
        trait = self.traits()[trait_name]
        env_name = trait.metadata.get('env')
        secret_file = os.path.abspath(
            os.path.expanduser(self.cookie_secret_file))
        secret = self.cookie_secret
        secret_from = 'config'
        # load priority: 1. config, 2. env, 3. file
        secret_env = os.environ.get(env_name)
        if not secret and secret_env:
            secret_from = 'env'
            self.log.info("Loading %s from env[%s]", trait_name, env_name)
            secret = binascii.a2b_hex(secret_env)
        if not secret and os.path.exists(secret_file):
            secret_from = 'file'
            self.log.info("Loading %s from %s", trait_name, secret_file)
            try:
                perm = os.stat(secret_file).st_mode
                if perm & 0o07:
                    raise ValueError(
                        "cookie_secret_file can be read or written by anybody")
                with open(secret_file) as f:
                    b64_secret = f.read()
                secret = binascii.a2b_base64(b64_secret)
            except Exception as e:
                self.log.error(
                    "Refusing to run JupyterHub with invalid cookie_secret_file %s. "
                    "Error was: %s", secret_file, e)
                self.exit(1)
        if not secret:
            secret_from = 'new'
            self.log.debug("Generating new %s", trait_name)
            secret = os.urandom(SECRET_BYTES)

        if secret_file and secret_from == 'new':
            # if we generated a new secret, store it in the secret_file
            self.log.info("Writing %s to %s", trait_name, secret_file)
            b64_secret = binascii.b2a_base64(secret).decode('ascii')
            with open(secret_file, 'w') as f:
                f.write(b64_secret)
            try:
                os.chmod(secret_file, 0o600)
            except OSError:
                self.log.warning("Failed to set permissions on %s",
                                 secret_file)
        # store the loaded trait value
        self.cookie_secret = secret
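
    # A hedged illustration of the load priority implemented above
    # (config > env > file > newly generated), using the env name from the
    # trait metadata:
    #     export JPY_COOKIE_SECRET=`openssl rand -hex 32`
    # With the variable set, no jupyterhub_cookie_secret file is written.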

    # thread-local storage of db objects
    _local = Instance(threading.local, ())

    @property
    def db(self):
        if not hasattr(self._local, 'db'):
            self._local.db = scoped_session(self.session_factory)()
        return self._local.db

    @property
    def hub(self):
        if not getattr(self._local, 'hub', None):
            q = self.db.query(orm.Hub)
            assert q.count() <= 1
            self._local.hub = q.first()
            if self.subdomain_host and self._local.hub:
                self._local.hub.host = self.subdomain_host
        return self._local.hub

    @hub.setter
    def hub(self, hub):
        self._local.hub = hub
        if hub and self.subdomain_host:
            hub.host = self.subdomain_host

    @property
    def proxy(self):
        if not getattr(self._local, 'proxy', None):
            q = self.db.query(orm.Proxy)
            assert q.count() <= 1
            p = self._local.proxy = q.first()
            if p:
                p.auth_token = self.proxy_auth_token
        return self._local.proxy

    @proxy.setter
    def proxy(self, proxy):
        self._local.proxy = proxy

    def init_db(self):
        """Create the database connection"""
        self.log.debug("Connecting to db: %s", self.db_url)
        try:
            self.session_factory = orm.new_session_factory(self.db_url,
                                                           reset=self.reset_db,
                                                           echo=self.debug_db,
                                                           **self.db_kwargs)
            # trigger constructing thread local db property
            _ = self.db
        except OperationalError as e:
            self.log.error("Failed to connect to db: %s", self.db_url)
            self.log.debug("Database error was:", exc_info=True)
            if self.db_url.startswith('sqlite:///'):
                self._check_db_path(self.db_url.split(':///', 1)[1])
            self.log.critical('\n'.join([
                "If you recently upgraded JupyterHub, try running",
                "    jupyterhub upgrade-db",
                "to upgrade your JupyterHub database schema",
            ]))
            self.exit(1)

    def init_hub(self):
        """Load the Hub config into the database"""
        self.hub = self.db.query(orm.Hub).first()
        if self.hub is None:
            self.hub = orm.Hub(server=orm.Server(
                ip=self.hub_ip,
                port=self.hub_port,
                base_url=self.hub_prefix,
                cookie_name='jupyter-hub-token',
            ))
            self.db.add(self.hub)
        else:
            server = self.hub.server
            server.ip = self.hub_ip
            server.port = self.hub_port
            server.base_url = self.hub_prefix
        if self.subdomain_host and not urlparse(self.subdomain_host).hostname:
            raise ValueError(
                "Must specify a valid subdomain_host when using subdomains."
                " This should be the public domain[:port] of the Hub.")

        self.db.commit()

    @gen.coroutine
    def init_users(self):
        """Load users into and from the database"""
        db = self.db

        if self.admin_users and not self.authenticator.admin_users:
            self.log.warning("\nJupyterHub.admin_users is deprecated."
                             "\nUse Authenticator.admin_users instead.")
            self.authenticator.admin_users = self.admin_users
        admin_users = [
            self.authenticator.normalize_username(name)
            for name in self.authenticator.admin_users
        ]
        self.authenticator.admin_users = set(
            admin_users)  # force normalization
        for username in admin_users:
            if not self.authenticator.validate_username(username):
                raise ValueError("username %r is not valid" % username)

        if not admin_users:
            self.log.warning(
                "No admin users, admin interface will be unavailable.")
            self.log.warning(
                "Add any administrative users to `c.Authenticator.admin_users` in config."
            )

        new_users = []

        for name in admin_users:
            # ensure anyone specified as admin in config is admin in db
            user = orm.User.find(db, name)
            if user is None:
                user = orm.User(name=name, admin=True)
                new_users.append(user)
                db.add(user)
            else:
                user.admin = True

        # the admin_users config variable will never be used after this point.
        # only the database values will be referenced.

        whitelist = [
            self.authenticator.normalize_username(name)
            for name in self.authenticator.whitelist
        ]
        self.authenticator.whitelist = set(whitelist)  # force normalization
        for username in whitelist:
            if not self.authenticator.validate_username(username):
                raise ValueError("username %r is not valid" % username)

        if not whitelist:
            self.log.info(
                "Not using whitelist. Any authenticated user will be allowed.")

        # add whitelisted users to the db
        for name in whitelist:
            user = orm.User.find(db, name)
            if user is None:
                user = orm.User(name=name)
                new_users.append(user)
                db.add(user)

        db.commit()

        # Notify authenticator of all users.
        # This ensures Auth whitelist is up-to-date with the database.
        # This lets whitelist be used to set up initial list,
        # but changes to the whitelist can occur in the database,
        # and persist across sessions.
        for user in db.query(orm.User):
            yield gen.maybe_future(self.authenticator.add_user(user))
        db.commit()  # can add_user touch the db?

        # The whitelist set and the users in the db are now the same.
        # From this point on, any user changes should be done simultaneously
        # to the whitelist set and user db, unless the whitelist is empty (all users allowed).

    def init_groups(self):
        """Load predefined groups into the database"""
        db = self.db
        for name, usernames in self.load_groups.items():
            group = orm.Group.find(db, name)
            if group is None:
                group = orm.Group(name=name)
                db.add(group)
            for username in usernames:
                username = self.authenticator.normalize_username(username)
                if not self.authenticator.check_whitelist(username):
                    raise ValueError("Username %r is not in whitelist" %
                                     username)
                user = orm.User.find(db, name=username)
                if user is None:
                    if not self.authenticator.validate_username(username):
                        raise ValueError("Group username %r is not valid" %
                                         username)
                    user = orm.User(name=username)
                    db.add(user)
                group.users.append(user)
        db.commit()

    def init_api_tokens(self):
        """Load predefined API tokens (for services) into database"""
        db = self.db
        for token, username in self.api_tokens.items():
            username = self.authenticator.normalize_username(username)
            if not self.authenticator.check_whitelist(username):
                raise ValueError("Token username %r is not in whitelist" %
                                 username)
            if not self.authenticator.validate_username(username):
                raise ValueError("Token username %r is not valid" % username)
            orm_token = orm.APIToken.find(db, token)
            if orm_token is None:
                user = orm.User.find(db, username)
                user_created = False
                if user is None:
                    user_created = True
                    self.log.debug("Adding user %r to database", username)
                    user = orm.User(name=username)
                    db.add(user)
                    db.commit()
                self.log.info("Adding API token for %s", username)
                try:
                    user.new_api_token(token)
                except Exception:
                    if user_created:
                        # don't allow bad tokens to create users
                        db.delete(user)
                        db.commit()
                        raise
            else:
                self.log.debug("Not duplicating token %s", orm_token)
        db.commit()

    @gen.coroutine
    def init_spawners(self):
        db = self.db

        user_summaries = ['']

        def _user_summary(user):
            parts = ['{0: >8}'.format(user.name)]
            if user.admin:
                parts.append('admin')
            if user.server:
                parts.append('running at %s' % user.server)
            return ' '.join(parts)

        @gen.coroutine
        def user_stopped(user):
            status = yield user.spawner.poll()
            self.log.warning(
                "User %s server stopped with exit code: %s",
                user.name,
                status,
            )
            yield self.proxy.delete_user(user)
            yield user.stop()

        for orm_user in db.query(orm.User):
            self.users[orm_user.id] = user = User(orm_user,
                                                  self.tornado_settings)
            if not user.state:
                # without spawner state, server isn't valid
                user.server = None
                user_summaries.append(_user_summary(user))
                continue
            self.log.debug("Loading state for %s from db", user.name)
            spawner = user.spawner
            status = yield spawner.poll()
            if status is None:
                self.log.info("%s still running", user.name)
                spawner.add_poll_callback(user_stopped, user)
                spawner.start_polling()
            else:
                # user not running. This is expected if server is None,
                # but indicates the user's server died while the Hub wasn't running
                # if user.server is defined.
                log = self.log.warning if user.server else self.log.debug
                log("%s not running.", user.name)
                user.server = None

            user_summaries.append(_user_summary(user))

        self.log.debug("Loaded users: %s", '\n'.join(user_summaries))
        db.commit()

    def init_proxy(self):
        """Load the Proxy config into the database"""
        self.proxy = self.db.query(orm.Proxy).first()
        if self.proxy is None:
            self.proxy = orm.Proxy(
                public_server=orm.Server(),
                api_server=orm.Server(),
            )
            self.db.add(self.proxy)
            self.db.commit()
        self.proxy.auth_token = self.proxy_auth_token  # not persisted
        self.proxy.log = self.log
        self.proxy.public_server.ip = self.ip
        self.proxy.public_server.port = self.port
        self.proxy.public_server.base_url = self.base_url
        self.proxy.api_server.ip = self.proxy_api_ip
        self.proxy.api_server.port = self.proxy_api_port
        self.proxy.api_server.base_url = '/api/routes/'
        self.db.commit()

    @gen.coroutine
    def start_proxy(self):
        """Actually start the configurable-http-proxy"""
        # check for proxy
        if self.proxy.public_server.is_up() or self.proxy.api_server.is_up():
            # check for *authenticated* access to the proxy (auth token can change)
            try:
                yield self.proxy.get_routes()
            except (HTTPError, OSError, socket.error) as e:
                if isinstance(e, HTTPError) and e.code == 403:
                    msg = "Did CONFIGPROXY_AUTH_TOKEN change?"
                else:
                    msg = "Is something else using %s?" % self.proxy.public_server.bind_url
                self.log.error(
                    "Proxy appears to be running at %s, but I can't access it (%s)\n%s",
                    self.proxy.public_server.bind_url, e, msg)
                self.exit(1)
                return
            else:
                self.log.info("Proxy already running at: %s",
                              self.proxy.public_server.bind_url)
            self.proxy_process = None
            return

        env = os.environ.copy()
        env['CONFIGPROXY_AUTH_TOKEN'] = self.proxy.auth_token
        cmd = self.proxy_cmd + [
            '--ip',
            self.proxy.public_server.ip,
            '--port',
            str(self.proxy.public_server.port),
            '--api-ip',
            self.proxy.api_server.ip,
            '--api-port',
            str(self.proxy.api_server.port),
            '--default-target',
            self.hub.server.host,
            '--error-target',
            url_path_join(self.hub.server.url, 'error'),
        ]
        if self.subdomain_host:
            cmd.append('--host-routing')
        if self.debug_proxy:
            cmd.extend(['--log-level', 'debug'])
        if self.ssl_key:
            cmd.extend(['--ssl-key', self.ssl_key])
        if self.ssl_cert:
            cmd.extend(['--ssl-cert', self.ssl_cert])
        if self.statsd_host:
            cmd.extend([
                '--statsd-host', self.statsd_host, '--statsd-port',
                str(self.statsd_port), '--statsd-prefix',
                self.statsd_prefix + '.chp'
            ])
        # Require SSL to be used or `--no-ssl` to confirm no SSL on
        if ' --ssl' not in ' '.join(cmd):
            if self.confirm_no_ssl:
                self.log.warning(
                    "Running JupyterHub without SSL."
                    " There better be SSL termination happening somewhere else..."
                )
            else:
                self.log.error(
                    "Refusing to run JuptyterHub without SSL."
                    " If you are terminating SSL in another layer,"
                    " pass --no-ssl to tell JupyterHub to allow the proxy to listen on HTTP."
                )
                self.exit(1)
        self.log.info("Starting proxy @ %s", self.proxy.public_server.bind_url)
        self.log.debug("Proxy cmd: %s", cmd)
        try:
            self.proxy_process = Popen(cmd, env=env)
        except FileNotFoundError as e:
            self.log.error(
                "Failed to find proxy %r\n"
                "The proxy can be installed with `npm install -g configurable-http-proxy`"
                % self.proxy_cmd)
            self.exit(1)

        def _check():
            status = self.proxy_process.poll()
            if status is not None:
                e = RuntimeError("Proxy failed to start with exit code %i" %
                                 status)
                # py2-compatible `raise e from None`
                e.__cause__ = None
                raise e

        for server in (self.proxy.public_server, self.proxy.api_server):
            for i in range(10):
                _check()
                try:
                    yield server.wait_up(1)
                except TimeoutError:
                    continue
                else:
                    break
            yield server.wait_up(1)
        self.log.debug("Proxy started and appears to be up")

    @gen.coroutine
    def check_proxy(self):
        if self.proxy_process is None or self.proxy_process.poll() is None:
            return
        self.log.error("Proxy stopped with exit code %r",
                       self.proxy_process.poll())
        yield self.start_proxy()
        self.log.info("Setting up routes on new proxy")
        yield self.proxy.add_all_users(self.users)
        self.log.info("New proxy back up, and good to go")

    def init_tornado_settings(self):
        """Set up the tornado settings dict."""
        base_url = self.hub.server.base_url
        jinja_options = dict(autoescape=True, )
        jinja_options.update(self.jinja_environment_options)
        jinja_env = Environment(loader=FileSystemLoader(self.template_paths),
                                **jinja_options)

        login_url = self.authenticator.login_url(base_url)
        logout_url = self.authenticator.logout_url(base_url)

        # if running from git, disable caching of require.js
        # otherwise cache based on server start time
        parent = os.path.dirname(os.path.dirname(jupyterhub.__file__))
        if os.path.isdir(os.path.join(parent, '.git')):
            version_hash = ''
        else:
            version_hash = datetime.now().strftime("%Y%m%d%H%M%S")

        subdomain_host = self.subdomain_host
        domain = urlparse(subdomain_host).hostname
        settings = dict(
            log_function=log_request,
            config=self.config,
            log=self.log,
            db=self.db,
            proxy=self.proxy,
            hub=self.hub,
            admin_users=self.authenticator.admin_users,
            admin_access=self.admin_access,
            authenticator=self.authenticator,
            spawner_class=self.spawner_class,
            base_url=self.base_url,
            cookie_secret=self.cookie_secret,
            cookie_max_age_days=self.cookie_max_age_days,
            login_url=login_url,
            logout_url=logout_url,
            static_path=os.path.join(self.data_files_path, 'static'),
            static_url_prefix=url_path_join(self.hub.server.base_url,
                                            'static/'),
            static_handler_class=CacheControlStaticFilesHandler,
            template_path=self.template_paths,
            jinja2_env=jinja_env,
            version_hash=version_hash,
            subdomain_host=subdomain_host,
            domain=domain,
            statsd=self.statsd,
        )
        # allow configured settings to have priority
        settings.update(self.tornado_settings)
        self.tornado_settings = settings
        # constructing users requires access to tornado_settings
        self.tornado_settings['users'] = self.users

    def init_tornado_application(self):
        """Instantiate the tornado Application object"""
        self.tornado_application = web.Application(self.handlers,
                                                   **self.tornado_settings)

    def write_pid_file(self):
        pid = os.getpid()
        if self.pid_file:
            self.log.debug("Writing PID %i to %s", pid, self.pid_file)
            with open(self.pid_file, 'w') as f:
                f.write('%i' % pid)

    @gen.coroutine
    @catch_config_error
    def initialize(self, *args, **kwargs):
        super().initialize(*args, **kwargs)
        if self.generate_config or self.subapp:
            return
        self.load_config_file(self.config_file)
        self.init_logging()
        if 'JupyterHubApp' in self.config:
            self.log.warning(
                "Use JupyterHub in config, not JupyterHubApp. Outdated config:\n%s",
                '\n'.join('JupyterHubApp.{key} = {value!r}'.format(key=key,
                                                                   value=value)
                          for key, value in self.config.JupyterHubApp.items()))
            cfg = self.config.copy()
            cfg.JupyterHub.merge(cfg.JupyterHubApp)
            self.update_config(cfg)
        self.write_pid_file()
        self.init_ports()
        self.init_secrets()
        self.init_db()
        self.init_hub()
        self.init_proxy()
        yield self.init_users()
        self.init_groups()
        self.init_api_tokens()
        self.init_tornado_settings()
        yield self.init_spawners()
        self.init_handlers()
        self.init_tornado_application()

    @gen.coroutine
    def cleanup(self):
        """Shutdown our various subprocesses and cleanup runtime files."""

        futures = []
        if self.cleanup_servers:
            self.log.info("Cleaning up single-user servers...")
            # request (async) process termination
            for uid, user in self.users.items():
                if user.spawner is not None:
                    futures.append(user.stop())
        else:
            self.log.info("Leaving single-user servers running")

        # clean up proxy while SUS are shutting down
        if self.cleanup_proxy:
            if self.proxy_process:
                self.log.info("Cleaning up proxy[%i]...",
                              self.proxy_process.pid)
                if self.proxy_process.poll() is None:
                    try:
                        self.proxy_process.terminate()
                    except Exception as e:
                        self.log.error("Failed to terminate proxy process: %s",
                                       e)
            else:
                self.log.info("I didn't start the proxy, I can't clean it up")
        else:
            self.log.info("Leaving proxy running")

        # wait for the stop requests to finish:
        for f in futures:
            try:
                yield f
            except Exception as e:
                self.log.error("Failed to stop user: %s", e)

        self.db.commit()

        if self.pid_file and os.path.exists(self.pid_file):
            self.log.info("Cleaning up PID file %s", self.pid_file)
            os.remove(self.pid_file)

        # finally stop the loop once we are all cleaned up
        self.log.info("...done")

    def write_config_file(self):
        """Write our default config to a .py config file"""
        if os.path.exists(self.config_file) and not self.answer_yes:
            answer = ''

            def ask():
                prompt = "Overwrite %s with default config? [y/N]" % self.config_file
                try:
                    return input(prompt).lower() or 'n'
                except KeyboardInterrupt:
                    print('')  # empty line
                    return 'n'

            answer = ask()
            while not answer.startswith(('y', 'n')):
                print("Please answer 'yes' or 'no'")
                answer = ask()
            if answer.startswith('n'):
                return

        config_text = self.generate_config_file()
        if isinstance(config_text, bytes):
            config_text = config_text.decode('utf8')
        print("Writing default config to: %s" % self.config_file)
        with open(self.config_file, mode='w') as f:
            f.write(config_text)

    @gen.coroutine
    def update_last_activity(self):
        """Update User.last_activity timestamps from the proxy"""
        routes = yield self.proxy.get_routes()
        users_count = 0
        active_users_count = 0
        for prefix, route in routes.items():
            if 'user' not in route:
                # not a user route, ignore it
                continue
            user = orm.User.find(self.db, route['user'])
            if user is None:
                self.log.warning("Found no user for route: %s", route)
                continue
            try:
                dt = datetime.strptime(route['last_activity'], ISO8601_ms)
            except Exception:
                dt = datetime.strptime(route['last_activity'], ISO8601_s)
            user.last_activity = max(user.last_activity, dt)
            # FIXME: make this a configurable duration. 30 minutes for now!
            if (datetime.now() - user.last_activity).total_seconds() < 30 * 60:
                active_users_count += 1
            users_count += 1
        self.statsd.gauge('users.running', users_count)
        self.statsd.gauge('users.active', active_users_count)

        self.db.commit()
        yield self.proxy.check_routes(self.users, routes)

    @gen.coroutine
    def start(self):
        """Start the whole thing"""
        self.io_loop = loop = IOLoop.current()

        if self.subapp:
            self.subapp.start()
            loop.stop()
            return

        if self.generate_config:
            self.write_config_file()
            loop.stop()
            return

        # start the webserver
        self.http_server = tornado.httpserver.HTTPServer(
            self.tornado_application, xheaders=True)
        try:
            self.http_server.listen(self.hub_port, address=self.hub_ip)
        except Exception:
            self.log.error("Failed to bind hub to %s",
                           self.hub.server.bind_url)
            raise
        else:
            self.log.info("Hub API listening on %s", self.hub.server.bind_url)

        # start the proxy
        try:
            yield self.start_proxy()
        except Exception as e:
            self.log.critical("Failed to start proxy", exc_info=True)
            self.exit(1)
            return

        loop.add_callback(self.proxy.add_all_users, self.users)

        if self.proxy_process:
            # only check / restart the proxy if we started it in the first place.
            # this means a restarted Hub cannot restart a Proxy that its
            # predecessor started.
            pc = PeriodicCallback(self.check_proxy,
                                  1e3 * self.proxy_check_interval)
            pc.start()

        if self.last_activity_interval:
            pc = PeriodicCallback(self.update_last_activity,
                                  1e3 * self.last_activity_interval)
            pc.start()

        self.log.info("JupyterHub is now running at %s",
                      self.proxy.public_server.url)
        # register cleanup on both TERM and INT
        atexit.register(self.atexit)
        self.init_signal()

    def init_signal(self):
        signal.signal(signal.SIGTERM, self.sigterm)

    def sigterm(self, signum, frame):
        self.log.critical("Received SIGTERM, shutting down")
        self.io_loop.stop()
        self.atexit()

    _atexit_ran = False

    def atexit(self):
        """atexit callback"""
        if self._atexit_ran:
            return
        self._atexit_ran = True
        # run the cleanup step (in a new loop, because the interrupted one is unclean)
        IOLoop.clear_current()
        loop = IOLoop()
        loop.make_current()
        loop.run_sync(self.cleanup)

    def stop(self):
        if not self.io_loop:
            return
        if self.http_server:
            if self.io_loop._running:
                self.io_loop.add_callback(self.http_server.stop)
            else:
                self.http_server.stop()
        self.io_loop.add_callback(self.io_loop.stop)

    @gen.coroutine
    def launch_instance_async(self, argv=None):
        try:
            yield self.initialize(argv)
            yield self.start()
        except Exception as e:
            self.log.exception("")
            self.exit(1)

    @classmethod
    def launch_instance(cls, argv=None):
        self = cls.instance()
        loop = IOLoop.current()
        loop.add_callback(self.launch_instance_async, argv)
        try:
            loop.start()
        except KeyboardInterrupt:
            print("\nInterrupted")
Example #21
class IPKernelApp(BaseIPythonApplication, InteractiveShellApp,
        ConnectionFileMixin):
    name = 'ipython-kernel'
    aliases = Dict(kernel_aliases)
    flags = Dict(kernel_flags)
    classes = [IPythonKernel, ZMQInteractiveShell, ProfileDir, Session]
    # the kernel class, as an importstring
    kernel_class = Type('ipykernel.ipkernel.IPythonKernel',
                        klass='ipykernel.kernelbase.Kernel',
                        help="""The Kernel subclass to be used.

        This should allow easy re-use of the IPKernelApp entry point
        to configure and launch kernels other than IPython's own.
        """).tag(config=True)
    kernel = Any()
    poller = Any() # don't restrict this even though current pollers are all Threads
    heartbeat = Instance(Heartbeat, allow_none=True)

    context = Any()
    shell_socket = Any()
    control_socket = Any()
    debugpy_socket = Any()
    debug_shell_socket = Any()
    stdin_socket = Any()
    iopub_socket = Any()
    iopub_thread = Any()
    control_thread = Any()

    ports = Dict()

    subcommands = {
        'install': (
            'ipykernel.kernelspec.InstallIPythonKernelSpecApp',
            'Install the IPython kernel'
        ),
    }

    # connection info:
    connection_dir = Unicode()

    @default('connection_dir')
    def _default_connection_dir(self):
        return jupyter_runtime_dir()

    @property
    def abs_connection_file(self):
        if os.path.basename(self.connection_file) == self.connection_file:
            return os.path.join(self.connection_dir, self.connection_file)
        else:
            return self.connection_file

    # streams, etc.
    no_stdout = Bool(False, help="redirect stdout to the null device").tag(config=True)
    no_stderr = Bool(False, help="redirect stderr to the null device").tag(config=True)
    trio_loop = Bool(False, help="Set main event loop.").tag(config=True)
    quiet = Bool(True, help="Only send stdout/stderr to output stream").tag(config=True)
    outstream_class = DottedObjectName('ipykernel.iostream.OutStream',
        help="The importstring for the OutStream factory").tag(config=True)
    displayhook_class = DottedObjectName('ipykernel.displayhook.ZMQDisplayHook',
        help="The importstring for the DisplayHook factory").tag(config=True)

    # polling
    parent_handle = Integer(int(os.environ.get('JPY_PARENT_PID') or 0),
        help="""kill this process if its parent dies.  On Windows, the argument
        specifies the HANDLE of the parent process, otherwise it is simply boolean.
        """).tag(config=True)
    interrupt = Integer(int(os.environ.get('JPY_INTERRUPT_EVENT') or 0),
        help="""ONLY USED ON WINDOWS
        Interrupt this process when the parent is signaled.
        """).tag(config=True)

    def init_crash_handler(self):
        sys.excepthook = self.excepthook

    def excepthook(self, etype, evalue, tb):
        # write uncaught traceback to 'real' stderr, not zmq-forwarder
        traceback.print_exception(etype, evalue, tb, file=sys.__stderr__)

    def init_poller(self):
        if sys.platform == 'win32':
            if self.interrupt or self.parent_handle:
                self.poller = ParentPollerWindows(self.interrupt, self.parent_handle)
        elif self.parent_handle and self.parent_handle != 1:
            # PID 1 (init) is special and will never go away,
            # only be reassigned.
            # Parent polling doesn't work if ppid == 1 to start with.
            self.poller = ParentPollerUnix()

    def _try_bind_socket(self, s, port):
        iface = '%s://%s' % (self.transport, self.ip)
        if self.transport == 'tcp':
            if port <= 0:
                port = s.bind_to_random_port(iface)
            else:
                s.bind("tcp://%s:%i" % (self.ip, port))
        elif self.transport == 'ipc':
            if port <= 0:
                port = 1
                path = "%s-%i" % (self.ip, port)
                while os.path.exists(path):
                    port = port + 1
                    path = "%s-%i" % (self.ip, port)
            else:
                path = "%s-%i" % (self.ip, port)
            s.bind("ipc://%s" % path)
        return port

    def _bind_socket(self, s, port):
        try:
            win_in_use = errno.WSAEADDRINUSE
        except AttributeError:
            win_in_use = None

        # If an explicit port was requested, bind exactly once; otherwise retry
        # random ports up to 100 times so conflicts cannot loop forever in bad
        # setups.
        max_attempts = 1 if port else 100
        for attempt in range(max_attempts):
            try:
                return self._try_bind_socket(s, port)
            except zmq.ZMQError as ze:
                # Raise if we have any error not related to socket binding
                if ze.errno != errno.EADDRINUSE and ze.errno != win_in_use:
                    raise
                if attempt == max_attempts - 1:
                    raise

    def write_connection_file(self):
        """write connection info to JSON file"""
        cf = self.abs_connection_file
        self.log.debug("Writing connection file: %s", cf)
        write_connection_file(cf, ip=self.ip, key=self.session.key, transport=self.transport,
        shell_port=self.shell_port, stdin_port=self.stdin_port, hb_port=self.hb_port,
        iopub_port=self.iopub_port, control_port=self.control_port)

    def cleanup_connection_file(self):
        cf = self.abs_connection_file
        self.log.debug("Cleaning up connection file: %s", cf)
        try:
            os.remove(cf)
        except (IOError, OSError):
            pass

        self.cleanup_ipc_files()

    def init_connection_file(self):
        if not self.connection_file:
            self.connection_file = "kernel-%s.json"%os.getpid()
        try:
            self.connection_file = filefind(self.connection_file, ['.', self.connection_dir])
        except IOError:
            self.log.debug("Connection file not found: %s", self.connection_file)
            # This means I own it, and I'll create it in this directory:
            ensure_dir_exists(os.path.dirname(self.abs_connection_file), 0o700)
            # Also, I will clean it up:
            atexit.register(self.cleanup_connection_file)
            return
        try:
            self.load_connection_file()
        except Exception:
            self.log.error("Failed to load connection file: %r", self.connection_file, exc_info=True)
            self.exit(1)

    def init_sockets(self):
        # Create a context, a session, and the kernel sockets.
        self.log.info("Starting the kernel at pid: %i", os.getpid())
        assert self.context is None, "init_sockets cannot be called twice!"
        self.context = context = zmq.Context()
        atexit.register(self.close)

        self.shell_socket = context.socket(zmq.ROUTER)
        self.shell_socket.linger = 1000
        self.shell_port = self._bind_socket(self.shell_socket, self.shell_port)
        self.log.debug("shell ROUTER Channel on port: %i" % self.shell_port)

        self.stdin_socket = context.socket(zmq.ROUTER)
        self.stdin_socket.linger = 1000
        self.stdin_port = self._bind_socket(self.stdin_socket, self.stdin_port)
        self.log.debug("stdin ROUTER Channel on port: %i" % self.stdin_port)

        if hasattr(zmq, 'ROUTER_HANDOVER'):
            # set router-handover to workaround zeromq reconnect problems
            # in certain rare circumstances
            # see ipython/ipykernel#270 and zeromq/libzmq#2892
            self.shell_socket.router_handover = \
                self.stdin_socket.router_handover = 1

        self.init_control(context)
        self.init_iopub(context)

    def init_control(self, context):
        self.control_socket = context.socket(zmq.ROUTER)
        self.control_socket.linger = 1000
        self.control_port = self._bind_socket(self.control_socket, self.control_port)
        self.log.debug("control ROUTER Channel on port: %i" % self.control_port)

        self.debugpy_socket = context.socket(zmq.STREAM)
        self.debugpy_socket.linger = 1000

        self.debug_shell_socket = context.socket(zmq.DEALER)
        self.debug_shell_socket.linger = 1000
        if self.shell_socket.getsockopt(zmq.LAST_ENDPOINT):
            self.debug_shell_socket.connect(self.shell_socket.getsockopt(zmq.LAST_ENDPOINT))

        if hasattr(zmq, 'ROUTER_HANDOVER'):
            # set router-handover to workaround zeromq reconnect problems
            # in certain rare circumstances
            # see ipython/ipykernel#270 and zeromq/libzmq#2892
            self.control_socket.router_handover = 1

        self.control_thread = ControlThread(daemon=True)

    def init_iopub(self, context):
        self.iopub_socket = context.socket(zmq.PUB)
        self.iopub_socket.linger = 1000
        self.iopub_port = self._bind_socket(self.iopub_socket, self.iopub_port)
        self.log.debug("iopub PUB Channel on port: %i" % self.iopub_port)
        self.configure_tornado_logger()
        self.iopub_thread = IOPubThread(self.iopub_socket, pipe=True)
        self.iopub_thread.start()
        # backward-compat: wrap iopub socket API in background thread
        self.iopub_socket = self.iopub_thread.background_socket

    def init_heartbeat(self):
        """start the heart beating"""
        # heartbeat doesn't share context, because it mustn't be blocked
        # by the GIL, which is accessed by libzmq when freeing zero-copy messages
        hb_ctx = zmq.Context()
        self.heartbeat = Heartbeat(hb_ctx, (self.transport, self.ip, self.hb_port))
        self.hb_port = self.heartbeat.port
        self.log.debug("Heartbeat REP Channel on port: %i" % self.hb_port)
        self.heartbeat.start()

    def close(self):
        """Close zmq sockets in an orderly fashion"""
        # un-capture IO before we start closing channels
        self.reset_io()
        self.log.info("Cleaning up sockets")
        if self.heartbeat:
            self.log.debug("Closing heartbeat channel")
            self.heartbeat.context.term()
        if self.iopub_thread:
            self.log.debug("Closing iopub channel")
            self.iopub_thread.stop()
            self.iopub_thread.close()

        if self.debugpy_socket and not self.debugpy_socket.closed:
            self.debugpy_socket.close()
        if self.debug_shell_socket and not self.debug_shell_socket.closed:
            self.debug_shell_socket.close()

        for channel in ('shell', 'control', 'stdin'):
            self.log.debug("Closing %s channel", channel)
            socket = getattr(self, channel + "_socket", None)
            if socket and not socket.closed:
                socket.close()
        self.log.debug("Terminating zmq context")
        self.context.term()
        self.log.debug("Terminated zmq context")

    def log_connection_info(self):
        """display connection info, and store ports"""
        basename = os.path.basename(self.connection_file)
        if basename == self.connection_file or \
            os.path.dirname(self.connection_file) == self.connection_dir:
            # use shortname
            tail = basename
        else:
            tail = self.connection_file
        lines = [
            "To connect another client to this kernel, use:",
            "    --existing %s" % tail,
        ]
        # log connection info
        # info-level, so often not shown.
        # frontends should use the %connect_info magic
        # to see the connection info
        for line in lines:
            self.log.info(line)
        # also raw print to the terminal if no parent_handle (`ipython kernel`)
        # unless log-level is CRITICAL (--quiet)
        if not self.parent_handle and self.log_level < logging.CRITICAL:
            print(_ctrl_c_message, file=sys.__stdout__)
            for line in lines:
                print(line, file=sys.__stdout__)

        self.ports = dict(shell=self.shell_port, iopub=self.iopub_port,
                                stdin=self.stdin_port, hb=self.hb_port,
                                control=self.control_port)

    def init_blackhole(self):
        """redirects stdout/stderr to devnull if necessary"""
        if self.no_stdout or self.no_stderr:
            blackhole = open(os.devnull, 'w')
            if self.no_stdout:
                sys.stdout = sys.__stdout__ = blackhole
            if self.no_stderr:
                sys.stderr = sys.__stderr__ = blackhole

    def init_io(self):
        """Redirect input streams and set a display hook."""
        if self.outstream_class:
            outstream_factory = import_item(str(self.outstream_class))
            if sys.stdout is not None:
                sys.stdout.flush()

            e_stdout = None if self.quiet else sys.__stdout__
            e_stderr = None if self.quiet else sys.__stderr__

            sys.stdout = outstream_factory(self.session, self.iopub_thread,
                                           'stdout',
                                           echo=e_stdout)
            if sys.stderr is not None:
                sys.stderr.flush()
            sys.stderr = outstream_factory(
                self.session, self.iopub_thread, "stderr", echo=e_stderr
            )
            if hasattr(sys.stderr, "_original_stdstream_copy"):

                for handler in self.log.handlers:
                    if isinstance(handler, StreamHandler) and (
                        handler.stream.buffer.fileno() == 2
                    ):
                        self.log.debug(
                            "Seeing logger to stderr, rerouting to raw filedescriptor."
                        )

                        handler.stream = TextIOWrapper(
                            FileIO(sys.stderr._original_stdstream_copy, "w")
                        )
        if self.displayhook_class:
            displayhook_factory = import_item(str(self.displayhook_class))
            self.displayhook = displayhook_factory(self.session, self.iopub_socket)
            sys.displayhook = self.displayhook

        self.patch_io()

    def reset_io(self):
        """restore original io

        restores state after init_io
        """
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__
        sys.displayhook = sys.__displayhook__

    def patch_io(self):
        """Patch important libraries that can't handle sys.stdout forwarding"""
        try:
            import faulthandler
        except ImportError:
            pass
        else:
            # Warning: this is a monkeypatch of `faulthandler.enable`, watch for possible
            # updates to the upstream API and update accordingly (up-to-date as of Python 3.5):
            # https://docs.python.org/3/library/faulthandler.html#faulthandler.enable

            # change default file to __stderr__ from forwarded stderr
            faulthandler_enable = faulthandler.enable
            def enable(file=sys.__stderr__, all_threads=True, **kwargs):
                return faulthandler_enable(file=file, all_threads=all_threads, **kwargs)

            faulthandler.enable = enable

            if hasattr(faulthandler, 'register'):
                faulthandler_register = faulthandler.register
                def register(signum, file=sys.__stderr__, all_threads=True, chain=False, **kwargs):
                    return faulthandler_register(signum, file=file, all_threads=all_threads,
                                                 chain=chain, **kwargs)
                faulthandler.register = register
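            # (Illustrative note, not in the original code.) After this patch,
            # a bare call such as:
            #     import faulthandler
            #     faulthandler.enable()
            # writes crash tracebacks to the real sys.__stderr__ instead of the
            # zmq-forwarded stderr, without callers having to pass file=...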

    def init_signal(self):
        signal.signal(signal.SIGINT, signal.SIG_IGN)

    def init_kernel(self):
        """Create the Kernel object itself"""
        shell_stream = ZMQStream(self.shell_socket)
        control_stream = ZMQStream(self.control_socket, self.control_thread.io_loop)
        debugpy_stream = ZMQStream(self.debugpy_socket, self.control_thread.io_loop)
        self.control_thread.start()
        kernel_factory = self.kernel_class.instance

        kernel = kernel_factory(parent=self, session=self.session,
                                control_stream=control_stream,
                                debugpy_stream=debugpy_stream,
                                debug_shell_socket=self.debug_shell_socket,
                                shell_stream=shell_stream,
                                control_thread=self.control_thread,
                                iopub_thread=self.iopub_thread,
                                iopub_socket=self.iopub_socket,
                                stdin_socket=self.stdin_socket,
                                log=self.log,
                                profile_dir=self.profile_dir,
                                user_ns=self.user_ns,
        )
        kernel.record_ports({
            name + '_port': port for name, port in self.ports.items()
        })
        self.kernel = kernel

        # Allow the displayhook to get the execution count
        self.displayhook.get_execution_count = lambda: kernel.execution_count

    def init_gui_pylab(self):
        """Enable GUI event loop integration, taking pylab into account."""

        # Register inline backend as default
        # this is higher priority than matplotlibrc,
        # but lower priority than anything else (mpl.use() for instance).
        # This only affects matplotlib >= 1.5
        if not os.environ.get('MPLBACKEND'):
            os.environ['MPLBACKEND'] = 'module://ipykernel.pylab.backend_inline'

        # Provide a wrapper for :meth:`InteractiveShellApp.init_gui_pylab`
        # to ensure that any exception is printed straight to stderr.
        # Normally _showtraceback associates the reply with an execution,
        # which means frontends will never draw it, as this exception
        # is not associated with any execute request.

        shell = self.shell
        _showtraceback = shell._showtraceback
        try:
            # replace error-sending traceback with stderr
            def print_tb(etype, evalue, stb):
                print ("GUI event loop or pylab initialization failed",
                       file=sys.stderr)
                print (shell.InteractiveTB.stb2text(stb), file=sys.stderr)
            shell._showtraceback = print_tb
            InteractiveShellApp.init_gui_pylab(self)
        finally:
            shell._showtraceback = _showtraceback

    def init_shell(self):
        self.shell = getattr(self.kernel, 'shell', None)
        if self.shell:
            self.shell.configurables.append(self)

    def configure_tornado_logger(self):
        """ Configure the tornado logging.Logger.

        Must set up the tornado logger or else tornado will call
        basicConfig for the root logger which makes the root logger
        go to the real sys.stderr instead of the capture streams.
        This function mimics the setup of logging.basicConfig.
        """
        logger = logging.getLogger('tornado')
        handler = logging.StreamHandler()
        formatter = logging.Formatter(logging.BASIC_FORMAT)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    def _init_asyncio_patch(self):
        """set default asyncio policy to be compatible with tornado

        Tornado 6 (at least) is not compatible with the default
        asyncio implementation on Windows

        Pick the older SelectorEventLoopPolicy on Windows
        if the known-incompatible default policy is in use.

        do this as early as possible to make it a low priority and overrideable

        ref: https://github.com/tornadoweb/tornado/issues/2608

        FIXME: if/when tornado supports the defaults in asyncio,
               remove and bump tornado requirement for py38
        """
        if sys.platform.startswith("win") and sys.version_info >= (3, 8) and tornado.version_info < (6, 1):
            import asyncio
            try:
                from asyncio import (
                    WindowsProactorEventLoopPolicy,
                    WindowsSelectorEventLoopPolicy,
                )
            except ImportError:
                pass
                # not affected
            else:
                if type(asyncio.get_event_loop_policy()) is WindowsProactorEventLoopPolicy:
                    # WindowsProactorEventLoopPolicy is not compatible with tornado 6
                    # fallback to the pre-3.8 default of Selector
                    asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())

    def init_pdb(self):
        """Replace pdb with IPython's version that is interruptible.

        With the non-interruptible version, stopping pdb() locks up the kernel in a
        non-recoverable state.
        """
        import pdb
        from IPython.core import debugger
        if hasattr(debugger, "InterruptiblePdb"):
            # Only available in newer IPython releases:
            debugger.Pdb = debugger.InterruptiblePdb
            pdb.Pdb = debugger.Pdb
            pdb.set_trace = debugger.set_trace

    @catch_config_error
    def initialize(self, argv=None):
        self._init_asyncio_patch()
        super(IPKernelApp, self).initialize(argv)
        if self.subapp is not None:
            return

        self.init_pdb()
        self.init_blackhole()
        self.init_connection_file()
        self.init_poller()
        self.init_sockets()
        self.init_heartbeat()
        # writing/displaying connection info must be *after* init_sockets/heartbeat
        self.write_connection_file()
        # Log connection info after writing connection file, so that the connection
        # file is definitely available at the time someone reads the log.
        self.log_connection_info()
        self.init_io()
        try:
            self.init_signal()
        except:
            # Catch exception when initializing signal fails, eg when running the
            # kernel on a separate thread
            if self.log_level < logging.CRITICAL:
                self.log.error("Unable to initialize signal:", exc_info=True)
        self.init_kernel()
        # shell init steps
        self.init_path()
        self.init_shell()
        if self.shell:
            self.init_gui_pylab()
            self.init_extensions()
            self.init_code()
        # flush stdout/stderr, so that anything written to these streams during
        # initialization does not get associated with the first execution request
        sys.stdout.flush()
        sys.stderr.flush()

    def start(self):
        if self.subapp is not None:
            return self.subapp.start()
        if self.poller is not None:
            self.poller.start()
        self.kernel.start()
        self.io_loop = ioloop.IOLoop.current()
        if self.trio_loop:
            from ipykernel.trio_runner import TrioRunner
            tr = TrioRunner()
            tr.initialize(self.kernel, self.io_loop)
            try:
                tr.run()
            except KeyboardInterrupt:
                pass
        else:
            try:
                self.io_loop.start()
            except KeyboardInterrupt:
                pass
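# --- Hedged usage sketch (not part of the original source) ---
# Like any traitlets Application, the kernel app is normally launched through
# the inherited `launch_instance` classmethod, which creates the singleton and
# then runs initialize(argv) followed by start():
if __name__ == '__main__':
    IPKernelApp.launch_instance()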
Example #22
class DockerProfilesSpawner(ProfilesSpawner):

    """DockerProfilesSpawner - leverages ProfilesSpawner to dynamically create DockerSpawner
        profiles dynamically by looking for docker images that end with "jupyterhub". Due to the
        profiles being dynamic the "profiles" config item from the ProfilesSpawner is renamed as
        "default_profiles". Please note that the "docker" and DockerSpawner packages are required
        for this spawner to work.
    """

    default_profiles = List(
        trait = Tuple( Unicode(), Unicode(), Type(Spawner), Dict() ),
        default_value = [],
        config = True,
        help = """List of profiles to offer in addition to docker images for selection. Signature is:
            List(Tuple( Unicode, Unicode, Type(Spawner), Dict )) corresponding to
            profile display name, unique key, Spawner class, dictionary of spawner config options.

            The first three values will be exposed in the input_template as {display}, {key}, and {type}"""
        )

    docker_spawner_args = Dict(
        default_value = {},
        config = True,
        help = "Args to pass to DockerSpawner."
    )

    jupyterhub_docker_tag_re = re.compile('^.*jupyterhub$')

    def _nvidia_args(self):
        try:
            resp = urllib.request.urlopen('http://localhost:3476/v1.0/docker/cli/json')
            body = resp.read().decode('utf-8')
            args = json.loads(body)
            return dict(
                read_only_volumes={vol.split(':')[0]: vol.split(':')[1] for vol in args['Volumes']},
                extra_create_kwargs={"volume_driver": args['VolumeDriver']},
                extra_host_config={"devices": args['Devices']},
            )
        except urllib.error.URLError:
            return {}


    def _docker_profile(self, nvidia_args, image):
        spawner_args = dict(container_image=image, network_name=self.user.name)
        spawner_args.update(self.docker_spawner_args)
        spawner_args.update(nvidia_args)
        nvidia_enabled = "w/GPU" if len(nvidia_args) > 0 else "no GPU"
        return ("Docker: (%s): %s"%(nvidia_enabled, image), "docker-%s"%(image), "dockerspawner.SystemUserSpawner", spawner_args)

    def _jupyterhub_docker_tags(self):
        try:
            include_jh_tags = lambda tag: self.jupyterhub_docker_tag_re.match(tag)
            return filter(include_jh_tags, [tag for image in docker.from_env().images.list() for tag in image.tags])
        except NameError:
            raise Exception('The docker package is not installed and is a dependency for DockerProfilesSpawner')

    def _docker_profiles(self):
        return [self._docker_profile(self._nvidia_args(), tag) for tag in self._jupyterhub_docker_tags()]

    @property
    def profiles(self):
        return self.default_profiles + self._docker_profiles()

    @property
    def options_form(self):
        temp_keys = [ dict(display=p[0], key=p[1], type=p[2], first='') for p in self.profiles]
        temp_keys[0]['first'] = self.first_template
        text = ''.join([ self.input_template.format(**tk) for tk in temp_keys ])
        return self.form_template.format(input_template=text)
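# --- Hedged configuration sketch (not part of the original source) ---
# In jupyterhub_config.py one might combine a static fallback profile with the
# dynamically discovered docker images. The tuple shape follows the
# `default_profiles` trait above: (display name, unique key, Spawner class,
# dict of spawner config options). The module path and the DockerSpawner
# option name below are illustrative assumptions:
#
#     c.JupyterHub.spawner_class = 'wrapspawner.DockerProfilesSpawner'
#     c.DockerProfilesSpawner.default_profiles = [
#         ('Local server', 'local', 'jupyterhub.spawner.LocalProcessSpawner', {}),
#     ]
#     c.DockerProfilesSpawner.docker_spawner_args = {'remove_containers': True}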
Example #23
class IPKernelApp(BaseIPythonApplication, InteractiveShellApp,
                  ConnectionFileMixin):
    name = 'ipython-kernel'
    aliases = Dict(kernel_aliases)
    flags = Dict(kernel_flags)
    classes = [IPythonKernel, ZMQInteractiveShell, ProfileDir, Session]
    # the kernel class, as an importstring
    kernel_class = Type('ipykernel.ipkernel.IPythonKernel',
                        config=True,
                        klass='ipykernel.kernelbase.Kernel',
                        help="""The Kernel subclass to be used.

    This should allow easy re-use of the IPKernelApp entry point
    to configure and launch kernels other than IPython's own.
    """)
    kernel = Any()
    poller = Any()  # don't restrict this even though current pollers are all Threads
    heartbeat = Instance(Heartbeat, allow_none=True)
    ports = Dict()

    subcommands = {
        'install': ('ipykernel.kernelspec.InstallIPythonKernelSpecApp',
                    'Install the IPython kernel'),
    }

    # connection info:
    connection_dir = Unicode()

    def _connection_dir_default(self):
        return jupyter_runtime_dir()

    @property
    def abs_connection_file(self):
        if os.path.basename(self.connection_file) == self.connection_file:
            return os.path.join(self.connection_dir, self.connection_file)
        else:
            return self.connection_file

    # streams, etc.
    no_stdout = Bool(False,
                     config=True,
                     help="redirect stdout to the null device")
    no_stderr = Bool(False,
                     config=True,
                     help="redirect stderr to the null device")
    outstream_class = DottedObjectName(
        'ipykernel.iostream.OutStream',
        config=True,
        help="The importstring for the OutStream factory")
    displayhook_class = DottedObjectName(
        'ipykernel.displayhook.ZMQDisplayHook',
        config=True,
        help="The importstring for the DisplayHook factory")

    # polling
    parent_handle = Integer(
        int(os.environ.get('JPY_PARENT_PID') or 0),
        config=True,
        help="""kill this process if its parent dies.  On Windows, the argument
        specifies the HANDLE of the parent process, otherwise it is simply boolean.
        """)
    interrupt = Integer(int(os.environ.get('JPY_INTERRUPT_EVENT') or 0),
                        config=True,
                        help="""ONLY USED ON WINDOWS
        Interrupt this process when the parent is signaled.
        """)

    def init_crash_handler(self):
        sys.excepthook = self.excepthook

    def excepthook(self, etype, evalue, tb):
        # write uncaught traceback to 'real' stderr, not zmq-forwarder
        traceback.print_exception(etype, evalue, tb, file=sys.__stderr__)

    def init_poller(self):
        if sys.platform == 'win32':
            if self.interrupt or self.parent_handle:
                self.poller = ParentPollerWindows(self.interrupt,
                                                  self.parent_handle)
        elif self.parent_handle:
            self.poller = ParentPollerUnix()

    def _bind_socket(self, s, port):
        iface = '%s://%s' % (self.transport, self.ip)
        if self.transport == 'tcp':
            if port <= 0:
                port = s.bind_to_random_port(iface)
            else:
                s.bind("tcp://%s:%i" % (self.ip, port))
        elif self.transport == 'ipc':
            if port <= 0:
                port = 1
                path = "%s-%i" % (self.ip, port)
                while os.path.exists(path):
                    port = port + 1
                    path = "%s-%i" % (self.ip, port)
            else:
                path = "%s-%i" % (self.ip, port)
            s.bind("ipc://%s" % path)
        return port

    def write_connection_file(self):
        """write connection info to JSON file"""
        cf = self.abs_connection_file
        self.log.debug("Writing connection file: %s", cf)
        write_connection_file(cf,
                              ip=self.ip,
                              key=self.session.key,
                              transport=self.transport,
                              shell_port=self.shell_port,
                              stdin_port=self.stdin_port,
                              hb_port=self.hb_port,
                              iopub_port=self.iopub_port,
                              control_port=self.control_port)

    def cleanup_connection_file(self):
        cf = self.abs_connection_file
        self.log.debug("Cleaning up connection file: %s", cf)
        try:
            os.remove(cf)
        except (IOError, OSError):
            pass

        self.cleanup_ipc_files()

    def init_connection_file(self):
        if not self.connection_file:
            self.connection_file = "kernel-%s.json" % os.getpid()
        try:
            self.connection_file = filefind(self.connection_file,
                                            ['.', self.connection_dir])
        except IOError:
            self.log.debug("Connection file not found: %s",
                           self.connection_file)
            # This means I own it, and I'll create it in this directory:
            ensure_dir_exists(os.path.dirname(self.abs_connection_file), 0o700)
            # Also, I will clean it up:
            atexit.register(self.cleanup_connection_file)
            return
        try:
            self.load_connection_file()
        except Exception:
            self.log.error("Failed to load connection file: %r",
                           self.connection_file,
                           exc_info=True)
            self.exit(1)

    def init_sockets(self):
        # Create a context, a session, and the kernel sockets.
        self.log.info("Starting the kernel at pid: %i", os.getpid())
        context = zmq.Context.instance()
        # Uncomment this to try closing the context.
        # atexit.register(context.term)

        self.shell_socket = context.socket(zmq.ROUTER)
        self.shell_socket.linger = 1000
        self.shell_port = self._bind_socket(self.shell_socket, self.shell_port)
        self.log.debug("shell ROUTER Channel on port: %i" % self.shell_port)

        self.stdin_socket = context.socket(zmq.ROUTER)
        self.stdin_socket.linger = 1000
        self.stdin_port = self._bind_socket(self.stdin_socket, self.stdin_port)
        self.log.debug("stdin ROUTER Channel on port: %i" % self.stdin_port)

        self.control_socket = context.socket(zmq.ROUTER)
        self.control_socket.linger = 1000
        self.control_port = self._bind_socket(self.control_socket,
                                              self.control_port)
        self.log.debug("control ROUTER Channel on port: %i" %
                       self.control_port)

        self.init_iopub(context)

    def init_iopub(self, context):
        self.iopub_socket = context.socket(zmq.PUB)
        self.iopub_socket.linger = 1000
        self.iopub_port = self._bind_socket(self.iopub_socket, self.iopub_port)
        self.log.debug("iopub PUB Channel on port: %i" % self.iopub_port)
        self.iopub_thread = IOPubThread(self.iopub_socket, pipe=True)
        self.iopub_thread.start()
        # backward-compat: wrap iopub socket API in background thread
        self.iopub_socket = self.iopub_thread.background_socket

    def init_heartbeat(self):
        """start the heart beating"""
        # heartbeat doesn't share context, because it mustn't be blocked
        # by the GIL, which is accessed by libzmq when freeing zero-copy messages
        hb_ctx = zmq.Context()
        self.heartbeat = Heartbeat(hb_ctx,
                                   (self.transport, self.ip, self.hb_port))
        self.hb_port = self.heartbeat.port
        self.log.debug("Heartbeat REP Channel on port: %i" % self.hb_port)
        self.heartbeat.start()

    def log_connection_info(self):
        """display connection info, and store ports"""
        basename = os.path.basename(self.connection_file)
        if basename == self.connection_file or \
            os.path.dirname(self.connection_file) == self.connection_dir:
            # use shortname
            tail = basename
        else:
            tail = self.connection_file
        lines = [
            "To connect another client to this kernel, use:",
            "    --existing %s" % tail,
        ]
        # log connection info
        # info-level, so often not shown.
        # frontends should use the %connect_info magic
        # to see the connection info
        for line in lines:
            self.log.info(line)
        # also raw print to the terminal if no parent_handle (`ipython kernel`)
        if not self.parent_handle:
            io.rprint(_ctrl_c_message)
            for line in lines:
                io.rprint(line)

        self.ports = dict(shell=self.shell_port,
                          iopub=self.iopub_port,
                          stdin=self.stdin_port,
                          hb=self.hb_port,
                          control=self.control_port)

    def init_blackhole(self):
        """redirects stdout/stderr to devnull if necessary"""
        if self.no_stdout or self.no_stderr:
            blackhole = open(os.devnull, 'w')
            if self.no_stdout:
                sys.stdout = sys.__stdout__ = blackhole
            if self.no_stderr:
                sys.stderr = sys.__stderr__ = blackhole

    def init_io(self):
        """Redirect input streams and set a display hook."""
        if self.outstream_class:
            outstream_factory = import_item(str(self.outstream_class))
            sys.stdout = outstream_factory(self.session, self.iopub_thread,
                                           u'stdout')
            sys.stderr = outstream_factory(self.session, self.iopub_thread,
                                           u'stderr')
        if self.displayhook_class:
            displayhook_factory = import_item(str(self.displayhook_class))
            sys.displayhook = displayhook_factory(self.session,
                                                  self.iopub_socket)

        self.patch_io()

    def patch_io(self):
        """Patch important libraries that can't handle sys.stdout forwarding"""
        try:
            import faulthandler
        except ImportError:
            pass
        else:
            # Warning: this is a monkeypatch of `faulthandler.enable`, watch for possible
            # updates to the upstream API and update accordingly (up-to-date as of Python 3.5):
            # https://docs.python.org/3/library/faulthandler.html#faulthandler.enable

            # change default file to __stderr__ from forwarded stderr
            faulthandler_enable = faulthandler.enable

            def enable(file=sys.__stderr__, all_threads=True, **kwargs):
                return faulthandler_enable(file=file,
                                           all_threads=all_threads,
                                           **kwargs)

            faulthandler.enable = enable

            if hasattr(faulthandler, 'register'):
                faulthandler_register = faulthandler.register

                def register(signum,
                             file=sys.__stderr__,
                             all_threads=True,
                             chain=False,
                             **kwargs):
                    return faulthandler_register(signum,
                                                 file=file,
                                                 all_threads=all_threads,
                                                 chain=chain,
                                                 **kwargs)

                faulthandler.register = register

    def init_signal(self):
        signal.signal(signal.SIGINT, signal.SIG_IGN)

    def init_kernel(self):
        """Create the Kernel object itself"""
        shell_stream = ZMQStream(self.shell_socket)
        control_stream = ZMQStream(self.control_socket)

        kernel_factory = self.kernel_class.instance

        kernel = kernel_factory(
            parent=self,
            session=self.session,
            shell_streams=[shell_stream, control_stream],
            iopub_thread=self.iopub_thread,
            iopub_socket=self.iopub_socket,
            stdin_socket=self.stdin_socket,
            log=self.log,
            profile_dir=self.profile_dir,
            user_ns=self.user_ns,
        )
        kernel.record_ports(self.ports)
        self.kernel = kernel

    def init_gui_pylab(self):
        """Enable GUI event loop integration, taking pylab into account."""

        # Provide a wrapper for :meth:`InteractiveShellApp.init_gui_pylab`
        # to ensure that any exception is printed straight to stderr.
        # Normally _showtraceback associates the reply with an execution,
        # which means frontends will never draw it, as this exception
        # is not associated with any execute request.

        shell = self.shell
        _showtraceback = shell._showtraceback
        try:
            # replace error-sending traceback with stderr
            def print_tb(etype, evalue, stb):
                print("GUI event loop or pylab initialization failed",
                      file=io.stderr)
                print(shell.InteractiveTB.stb2text(stb), file=io.stderr)

            shell._showtraceback = print_tb
            InteractiveShellApp.init_gui_pylab(self)
        finally:
            shell._showtraceback = _showtraceback

    def init_shell(self):
        self.shell = getattr(self.kernel, 'shell', None)
        if self.shell:
            self.shell.configurables.append(self)

    def init_extensions(self):
        super(IPKernelApp, self).init_extensions()
        # BEGIN HARDCODED WIDGETS HACK
        # Ensure ipywidgets extension is loaded if available
        extension_man = self.shell.extension_manager
        if 'ipywidgets' not in extension_man.loaded:
            try:
                extension_man.load_extension('ipywidgets')
            except ImportError:
                self.log.debug(
                    'ipywidgets package not installed. Widgets will not be available.'
                )
        # END HARDCODED WIDGETS HACK

    @catch_config_error
    def initialize(self, argv=None):
        super(IPKernelApp, self).initialize(argv)
        if self.subapp is not None:
            return
        self.init_blackhole()
        self.init_connection_file()
        self.init_poller()
        self.init_sockets()
        self.init_heartbeat()
        # writing/displaying connection info must be *after* init_sockets/heartbeat
        self.write_connection_file()
        # Log connection info after writing connection file, so that the connection
        # file is definitely available at the time someone reads the log.
        self.log_connection_info()
        self.init_io()
        self.init_signal()
        self.init_kernel()
        # shell init steps
        self.init_path()
        self.init_shell()
        if self.shell:
            self.init_gui_pylab()
            self.init_extensions()
            self.init_code()
        # flush stdout/stderr, so that anything written to these streams during
        # initialization does not get associated with the first execution request
        sys.stdout.flush()
        sys.stderr.flush()

    def start(self):
        if self.subapp is not None:
            return self.subapp.start()

        if self.poller is not None:
            self.poller.start()
        self.kernel.start()
        try:
            ioloop.IOLoop.instance().start()
        except KeyboardInterrupt:
            pass
Example #24
class ZMQInteractiveShell(InteractiveShell):
    """A subclass of InteractiveShell for ZMQ."""

    displayhook_class = Type(ZMQShellDisplayHook)
    display_pub_class = Type(ZMQDisplayPublisher)
    data_pub_class = Type('ipykernel.datapub.ZMQDataPublisher')
    kernel = Any()
    parent_header = Any()

    @default('banner1')
    def _default_banner1(self):
        return default_banner

    # Override the traitlet in the parent class, because there's no point using
    # readline for the kernel. Can be removed when the readline code is moved
    # to the terminal frontend.
    colors_force = CBool(True)
    readline_use = CBool(False)
    # autoindent has no meaning in a zmqshell, and attempting to enable it
    # will print a warning in the absence of readline.
    autoindent = CBool(False)

    exiter = Instance(ZMQExitAutocall)

    @default('exiter')
    def _default_exiter(self):
        return ZMQExitAutocall(self)

    @observe('exit_now')
    def _update_exit_now(self, change):
        """stop eventloop when exit_now fires"""
        if change['new']:
            loop = ioloop.IOLoop.instance()
            loop.add_timeout(time.time() + 0.1, loop.stop)

    keepkernel_on_exit = None

    # Over ZeroMQ, GUI control isn't done with PyOS_InputHook as there is no
    # interactive input being read; we provide event loop support in ipkernel
    @staticmethod
    def enable_gui(gui):
        from .eventloops import enable_gui as real_enable_gui
        try:
            real_enable_gui(gui)
        except ValueError as e:
            raise UsageError("%s" % e)

    def init_environment(self):
        """Configure the user's environment."""
        env = os.environ
        # These two ensure 'ls' produces nice coloring on BSD-derived systems
        env['TERM'] = 'xterm-color'
        env['CLICOLOR'] = '1'
        # Since normal pagers don't work at all (over pexpect we don't have
        # single-key control of the subprocess), try to disable paging in
        # subprocesses as much as possible.
        env['PAGER'] = 'cat'
        env['GIT_PAGER'] = 'cat'

    def init_hooks(self):
        super(ZMQInteractiveShell, self).init_hooks()
        self.set_hook('show_in_pager', page.as_hook(payloadpage.page), 99)

    def init_data_pub(self):
        """Delay datapub init until request, for deprecation warnings"""
        pass

    @property
    def data_pub(self):
        if not hasattr(self, '_data_pub'):
            warnings.warn(
                "InteractiveShell.data_pub is deprecated outside IPython parallel.",
                DeprecationWarning,
                stacklevel=2)

            self._data_pub = self.data_pub_class(parent=self)
            self._data_pub.session = self.display_pub.session
            self._data_pub.pub_socket = self.display_pub.pub_socket
        return self._data_pub

    @data_pub.setter
    def data_pub(self, pub):
        self._data_pub = pub

    def ask_exit(self):
        """Engage the exit actions."""
        self.exit_now = (not self.keepkernel_on_exit)
        payload = dict(
            source='ask_exit',
            keepkernel=self.keepkernel_on_exit,
        )
        self.payload_manager.write_payload(payload)

    def run_cell(self, *args, **kwargs):
        self._last_traceback = None
        return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)

    def _showtraceback(self, etype, evalue, stb):
        # try to preserve ordering of tracebacks and print statements
        sys.stdout.flush()
        sys.stderr.flush()

        exc_content = {
            u'traceback': stb,
            u'ename': unicode_type(etype.__name__),
            u'evalue': py3compat.safe_unicode(evalue),
        }

        dh = self.displayhook
        # Send exception info over pub socket for other clients than the caller
        # to pick up
        topic = None
        if dh.topic:
            topic = dh.topic.replace(b'execute_result', b'error')

        exc_msg = dh.session.send(dh.pub_socket,
                                  u'error',
                                  json_clean(exc_content),
                                  dh.parent_header,
                                  ident=topic)

        # FIXME - Once we rely on Python 3, the traceback is stored on the
        # exception object, so we shouldn't need to store it here.
        self._last_traceback = stb

    def set_next_input(self, text, replace=False):
        """Send the specified text to the frontend to be presented at the next
        input cell."""
        payload = dict(
            source='set_next_input',
            text=text,
            replace=replace,
        )
        self.payload_manager.write_payload(payload)

    def set_parent(self, parent):
        """Set the parent header for associating output with its triggering input"""
        self.parent_header = parent
        self.displayhook.set_parent(parent)
        self.display_pub.set_parent(parent)
        if hasattr(self, '_data_pub'):
            self.data_pub.set_parent(parent)
        try:
            sys.stdout.set_parent(parent)
        except AttributeError:
            pass
        try:
            sys.stderr.set_parent(parent)
        except AttributeError:
            pass

    def get_parent(self):
        return self.parent_header

    def init_magics(self):
        super(ZMQInteractiveShell, self).init_magics()
        self.register_magics(KernelMagics)
        self.magics_manager.register_alias('ed', 'edit')

    def init_virtualenv(self):
        # Overridden not to do virtualenv detection, because it's probably
        # not appropriate in a kernel. To use a kernel in a virtualenv, install
        # it inside the virtualenv.
        # http://ipython.readthedocs.org/en/latest/install/kernel_install.html
        pass
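# --- Hedged illustration (not part of the original source) ---
# The payload mechanism used by ask_exit() and set_next_input() just queues
# dicts onto the payload manager; they are delivered with the execute_reply.
# For example, after
#
#     shell.set_next_input("2 + 2", replace=False)
#
# the reply's payload list contains an entry shaped like:
#
#     {'source': 'set_next_input', 'text': '2 + 2', 'replace': False}
#
# which frontends interpret as "pre-fill the next input cell with this text".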
Example #25
class ProxyInteractiveShell(InteractiveShell):
    display_pub_class = Type(ProxyDisplayPublisher)

    @staticmethod
    def enable_gui(gui, kernel=None):
        pass
Example #26
class ConnectionFileMixin(LoggingConfigurable):
    """Mixin for configurable classes that work with connection files"""

    data_dir = Unicode()
    def _data_dir_default(self):
        return jupyter_data_dir()

    # The addresses for the communication channels
    connection_file = Unicode('', config=True,
        help="""JSON file in which to store connection info [default: kernel-<pid>.json]

        This file will contain the IP, ports, and authentication key needed to connect
        clients to this kernel. By default, this file will be created in the security dir
        of the current profile, but can be specified by absolute path.
        """)
    _connection_file_written = Bool(False)

    transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp', config=True)
    kernel_name = Unicode()

    ip = Unicode(config=True,
        help="""Set the kernel's IP address [default localhost].
        If the IP address is something other than localhost, then
        Consoles on other machines will be able to connect
        to the Kernel, so be careful!"""
    )

    def _ip_default(self):
        if self.transport == 'ipc':
            if self.connection_file:
                return os.path.splitext(self.connection_file)[0] + '-ipc'
            else:
                return 'kernel-ipc'
        else:
            return localhost()

    @observe('ip')
    def _ip_changed(self, change):
        if change['new'] == '*':
            self.ip = '0.0.0.0'

    # protected traits

    hb_port = Integer(0, config=True,
            help="set the heartbeat port [default: random]")
    shell_port = Integer(0, config=True,
            help="set the shell (ROUTER) port [default: random]")
    iopub_port = Integer(0, config=True,
            help="set the iopub (PUB) port [default: random]")
    stdin_port = Integer(0, config=True,
            help="set the stdin (ROUTER) port [default: random]")
    control_port = Integer(0, config=True,
            help="set the control (ROUTER) port [default: random]")

    # names of the ports with random assignment
    _random_port_names = None

    @property
    def ports(self):
        return [getattr(self, name) for name in port_names]

    # The Session to use for communication with the kernel.
    session = Instance('jupyter_client.session.Session')
    def _session_default(self):
        from jupyter_client.session import Session
        return Session(parent=self)

    #--------------------------------------------------------------------------
    # Connection and ipc file management
    #--------------------------------------------------------------------------

    def get_connection_info(self, session=False):
        """Return the connection info as a dict

        Parameters
        ----------
        session : bool [default: False]
            If True, our session object will be included in the connection info.
            If False (default), the configuration parameters of our session object
            will be included, rather than the session object itself.

        Returns
        -------
        connect_info : dict
            dictionary of connection information.
        """
        info = dict(
            transport=self.transport,
            ip=self.ip,
            shell_port=self.shell_port,
            iopub_port=self.iopub_port,
            stdin_port=self.stdin_port,
            hb_port=self.hb_port,
            control_port=self.control_port,
        )
        if session:
            # add *clone* of my session,
            # so that state such as digest_history is not shared.
            info['session'] = self.session.clone()
        else:
            # add session info
            info.update(dict(
                signature_scheme=self.session.signature_scheme,
                key=self.session.key,
            ))
        return info

    # factory for blocking clients
    blocking_class = Type(klass=object, default_value='jupyter_client.BlockingKernelClient')
    def blocking_client(self):
        """Make a blocking client connected to my kernel"""
        info = self.get_connection_info()
        info['parent'] = self
        bc = self.blocking_class(**info)
        bc.session.key = self.session.key
        return bc

    def cleanup_connection_file(self):
        """Cleanup connection file *if we wrote it*

        Will not raise if the connection file was already removed somehow.
        """
        if self._connection_file_written:
            # cleanup connection files on full shutdown of kernel we started
            self._connection_file_written = False
            try:
                os.remove(self.connection_file)
            except (IOError, OSError, AttributeError):
                pass

    def cleanup_ipc_files(self):
        """Cleanup ipc files if we wrote them."""
        if self.transport != 'ipc':
            return
        for port in self.ports:
            ipcfile = "%s-%i" % (self.ip, port)
            try:
                os.remove(ipcfile)
            except (IOError, OSError):
                pass

    def _record_random_port_names(self):
        """Records which of the ports are randomly assigned.

        Records on first invocation, if the transport is tcp.
        Does nothing on later invocations."""

        if self.transport != 'tcp':
            return
        if self._random_port_names is not None:
            return

        self._random_port_names = []
        for name in port_names:
            if getattr(self, name) <= 0:
                self._random_port_names.append(name)

    def cleanup_random_ports(self):
        """Forgets randomly assigned port numbers and cleans up the connection file.

        Does nothing if no port numbers have been randomly assigned.
        In particular, does nothing unless the transport is tcp.
        """

        if not self._random_port_names:
            return

        for name in self._random_port_names:
            setattr(self, name, 0)

        self.cleanup_connection_file()

    def write_connection_file(self):
        """Write connection info to JSON dict in self.connection_file."""
        if self._connection_file_written and os.path.exists(self.connection_file):
            return

        self.connection_file, cfg = write_connection_file(self.connection_file,
            transport=self.transport, ip=self.ip, key=self.session.key,
            stdin_port=self.stdin_port, iopub_port=self.iopub_port,
            shell_port=self.shell_port, hb_port=self.hb_port,
            control_port=self.control_port,
            signature_scheme=self.session.signature_scheme,
            kernel_name=self.kernel_name
        )
        # write_connection_file also sets default ports:
        self._record_random_port_names()
        for name in port_names:
            setattr(self, name, cfg[name])

        self._connection_file_written = True

    def load_connection_file(self, connection_file=None):
        """Load connection info from JSON dict in self.connection_file.

        Parameters
        ----------
        connection_file: unicode, optional
            Path to connection file to load.
            If unspecified, use self.connection_file
        """
        if connection_file is None:
            connection_file = self.connection_file
        self.log.debug(u"Loading connection file %s", connection_file)
        with open(connection_file) as f:
            info = json.load(f)
        self.load_connection_info(info)

    def load_connection_info(self, info):
        """Load connection info from a dict containing connection info.

        Typically this data comes from a connection file
        and is called by load_connection_file.

        Parameters
        ----------
        info: dict
            Dictionary containing connection_info.
            See the connection_file spec for details.
        """
        self.transport = info.get('transport', self.transport)
        self.ip = info.get('ip', self._ip_default())

        self._record_random_port_names()
        for name in port_names:
            if getattr(self, name) == 0 and name in info:
                # not overridden by config or cl_args
                setattr(self, name, info[name])

        if 'key' in info:
            self.session.key = cast_bytes(info['key'])
        if 'signature_scheme' in info:
            self.session.signature_scheme = info['signature_scheme']

    #--------------------------------------------------------------------------
    # Creating connected sockets
    #--------------------------------------------------------------------------

    def _make_url(self, channel):
        """Make a ZeroMQ URL for a given channel."""
        transport = self.transport
        ip = self.ip
        port = getattr(self, '%s_port' % channel)

        if transport == 'tcp':
            return "tcp://%s:%i" % (ip, port)
        else:
            return "%s://%s-%s" % (transport, ip, port)

    def _create_connected_socket(self, channel, identity=None):
        """Create a zmq Socket and connect it to the kernel."""
        url = self._make_url(channel)
        socket_type = channel_socket_types[channel]
        self.log.debug("Connecting to: %s" % url)
        sock = self.context.socket(socket_type)
        # set linger to 1s to prevent hangs at exit
        sock.linger = 1000
        if identity:
            sock.identity = identity
        sock.connect(url)
        return sock

    def connect_iopub(self, identity=None):
        """return zmq Socket connected to the IOPub channel"""
        sock = self._create_connected_socket('iopub', identity=identity)
        sock.setsockopt(zmq.SUBSCRIBE, b'')
        return sock

    def connect_shell(self, identity=None):
        """return zmq Socket connected to the Shell channel"""
        return self._create_connected_socket('shell', identity=identity)

    def connect_stdin(self, identity=None):
        """return zmq Socket connected to the StdIn channel"""
        return self._create_connected_socket('stdin', identity=identity)

    def connect_hb(self, identity=None):
        """return zmq Socket connected to the Heartbeat channel"""
        return self._create_connected_socket('hb', identity=identity)

    def connect_control(self, identity=None):
        """return zmq Socket connected to the Control channel"""
        return self._create_connected_socket('control', identity=identity)
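A minimal usage sketch of the connection-file machinery above, assuming jupyter_client is installed and a running kernel has already written a connection file (the path below is hypothetical):

from jupyter_client import BlockingKernelClient

client = BlockingKernelClient()
client.load_connection_file('kernel-1234.json')  # hypothetical path
client.start_channels()                          # opens shell/iopub/stdin/hb sockets
client.execute('print("hello")')
reply = client.get_shell_msg(timeout=5)
print(reply['content']['status'])                # 'ok' on success
client.stop_channels()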
Example #27
class HTTPKernelManager(AsyncKernelManager):
    """Manages a single kernel remotely via a Gateway Server. """

    kernel_id = None
    kernel = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.base_endpoint = url_path_join(
            GatewayClient.instance().url,
            GatewayClient.instance().kernels_endpoint)
        self.kernel = None

    def _get_kernel_endpoint_url(self, kernel_id=None):
        """Builds a url for the kernels endpoint

        Parameters
        ----------
        kernel_id: kernel UUID (optional)
        """
        if kernel_id:
            return url_path_join(self.base_endpoint,
                                 url_escape(str(kernel_id)))

        return self.base_endpoint

    @property
    def has_kernel(self):
        """Has a kernel been started that we are managing."""
        return self.kernel is not None

    client_class = DottedObjectName(
        'elyra.pipeline.http_kernel_manager.HTTPKernelClient')
    client_factory = Type(
        klass='elyra.pipeline.http_kernel_manager.HTTPKernelClient')

    # --------------------------------------------------------------------------
    # create a Client connected to our Kernel
    # --------------------------------------------------------------------------

    def client(self, **kwargs):
        """Create a client configured to connect to our kernel"""
        kw = {}
        kw.update(self.get_connection_info(session=True))
        kw.update(dict(
            connection_file=self.connection_file,
            parent=self,
        ))
        kw['kernel_id'] = self.kernel_id

        # add kwargs last, for manual overrides
        kw.update(kwargs)
        return self.client_factory(**kw)

    async def get_kernel(self, kernel_id):
        """Get kernel for kernel_id.

        Parameters
        ----------
        kernel_id : uuid
            The uuid of the kernel.
        """
        kernel_url = self._get_kernel_endpoint_url(kernel_id)
        self.log.debug("Request kernel at: %s" % kernel_url)
        try:
            response = await gateway_request(kernel_url, method='GET')
        except web.HTTPError as error:
            if error.status_code == 404:
                self.log.warning("Kernel not found at: %s" % kernel_url)
                kernel = None
            else:
                raise
        else:
            kernel = json_decode(response.body)
        self.log.debug("Kernel retrieved: %s" % kernel)
        return kernel

    # --------------------------------------------------------------------------
    # Kernel management
    # --------------------------------------------------------------------------

    async def start_kernel(self, **kwargs):
        """Starts a kernel via HTTP in an asynchronous manner.

        Parameters
        ----------
        `**kwargs` : optional
             keyword arguments that are passed down to build the kernel_cmd
             and to launch the kernel (e.g. Popen kwargs).
        """
        kernel_id = kwargs.get('kernel_id')

        if kernel_id is None:
            kernel_name = kwargs.get('kernel_name', 'python3')
            kernel_url = self._get_kernel_endpoint_url()
            self.log.debug("Request new kernel at: %s" % kernel_url)

            # Let KERNEL_USERNAME take precedence over the http_user config option.
            if os.environ.get('KERNEL_USERNAME') is None and GatewayClient.instance().http_user:
                os.environ['KERNEL_USERNAME'] = GatewayClient.instance().http_user

            kernel_env = {
                k: v
                for (k, v) in dict(os.environ).items()
                if k.startswith('KERNEL_')
                or k in GatewayClient.instance().env_whitelist.split(",")
            }

            # Add any env entries in this request (guard against a missing 'env' key)
            kernel_env.update(kwargs.get('env') or {})

            # Convey the full path to where this notebook file is located.
            if kwargs.get('cwd') is not None and kernel_env.get(
                    'KERNEL_WORKING_DIR') is None:
                kernel_env['KERNEL_WORKING_DIR'] = kwargs['cwd']

            json_body = json_encode({'name': kernel_name, 'env': kernel_env})

            response = await gateway_request(kernel_url,
                                             method='POST',
                                             body=json_body)
            self.kernel = json_decode(response.body)
            self.kernel_id = self.kernel['id']
            self.log.info(
                "HTTPKernelManager started kernel: {}, args: {}".format(
                    self.kernel_id, kwargs))
        else:
            self.kernel = await self.get_kernel(kernel_id)
            if self.kernel is None:
                raise RuntimeError(
                    "Kernel '{}' was not found on the Gateway server.".format(kernel_id))
            self.kernel_id = self.kernel['id']
            self.log.info("HTTPKernelManager using existing kernel: {}".format(
                self.kernel_id))

    async def shutdown_kernel(self, now=False, restart=False):
        """Attempts to stop the kernel process cleanly via HTTP. """

        if self.has_kernel:
            kernel_url = self._get_kernel_endpoint_url(self.kernel_id)
            self.log.debug("Request shutdown kernel at: %s", kernel_url)
            response = await gateway_request(kernel_url, method='DELETE')
            self.log.debug("Shutdown kernel response: %d %s", response.code,
                           response.reason)

    async def restart_kernel(self, **kw):
        """Restarts a kernel via HTTP.  """
        if self.has_kernel:
            kernel_url = self._get_kernel_endpoint_url(
                self.kernel_id) + '/restart'
            self.log.debug("Request restart kernel at: %s", kernel_url)
            response = await gateway_request(kernel_url,
                                             method='POST',
                                             body=json_encode({}))
            self.log.debug("Restart kernel response: %d %s", response.code,
                           response.reason)

    async def interrupt_kernel(self):
        """Interrupts the kernel via an HTTP request. """
        if self.has_kernel:
            kernel_url = self._get_kernel_endpoint_url(
                self.kernel_id) + '/interrupt'
            self.log.debug("Request interrupt kernel at: %s", kernel_url)
            response = await gateway_request(kernel_url,
                                             method='POST',
                                             body=json_encode({}))
            self.log.debug("Interrupt kernel response: %d %s", response.code,
                           response.reason)

    async def is_alive(self):
        """Is the kernel process still running?"""
        if self.has_kernel:
            # Go ahead and issue a request to get the kernel
            self.kernel = await self.get_kernel(self.kernel_id)
            return True
        else:  # we don't have a kernel
            return False

    def cleanup_resources(self, restart=False):
        """Clean up resources when the kernel is shut down"""
        pass
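A hedged sketch of driving the manager above; it assumes the surrounding module is importable and that GatewayClient is configured to reach a running Gateway server (e.g. via its url setting):

import asyncio

async def demo():
    km = HTTPKernelManager()
    await km.start_kernel(kernel_name='python3')
    print('alive:', await km.is_alive())
    await km.shutdown_kernel()

asyncio.run(demo())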
Example #28
class KernelManager(ConnectionFileMixin):
    """Manages a single kernel in a subprocess on this host.

    This version starts kernels with Popen.
    """

    _created_context = Bool(False)

    # The PyZMQ Context to use for communication with the kernel.
    context = Instance(zmq.Context)
    def _context_default(self):
        self._created_context = True
        return zmq.Context()

    # the class to create with our `client` method
    client_class = DottedObjectName('jupyter_client.blocking.BlockingKernelClient')
    client_factory = Type(klass='jupyter_client.KernelClient')
    def _client_factory_default(self):
        return import_item(self.client_class)

    @observe('client_class')
    def _client_class_changed(self, change):
        self.client_factory = import_item(str(change['new']))

    # The kernel process with which the KernelManager is communicating.
    # generally a Popen instance
    kernel = Any()

    kernel_spec_manager = Instance(kernelspec.KernelSpecManager)

    def _kernel_spec_manager_default(self):
        return kernelspec.KernelSpecManager(data_dir=self.data_dir)

    @observe('kernel_spec_manager')
    @observe_compat
    def _kernel_spec_manager_changed(self, change):
        self._kernel_spec = None

    shutdown_wait_time = Float(
        5.0, config=True,
        help="Time to wait for a kernel to terminate before killing it, "
             "in seconds.")

    kernel_name = Unicode(kernelspec.NATIVE_KERNEL_NAME)

    @observe('kernel_name')
    def _kernel_name_changed(self, change):
        self._kernel_spec = None
        if change['new'] == 'python':
            self.kernel_name = kernelspec.NATIVE_KERNEL_NAME

    _kernel_spec = None

    @property
    def kernel_spec(self):
        if self._kernel_spec is None and self.kernel_name != '':
            self._kernel_spec = self.kernel_spec_manager.get_kernel_spec(self.kernel_name)
        return self._kernel_spec

    kernel_cmd = List(Unicode(), config=True,
        help="""DEPRECATED: Use kernel_name instead.

        The Popen Command to launch the kernel.
        Override this if you have a custom kernel.
        If kernel_cmd is specified in a configuration file,
        Jupyter does not pass any arguments to the kernel,
        because it cannot make any assumptions about the
        arguments that the kernel understands. In particular,
        this means that the kernel does not receive the
        option --debug if it is given on the Jupyter command line.
        """
    )

    def _kernel_cmd_changed(self, name, old, new):
        warnings.warn("Setting kernel_cmd is deprecated, use kernel_spec to "
                      "start different kernels.")

    cache_ports = Bool(help='True if the MultiKernelManager should cache ports for this KernelManager instance')

    @default('cache_ports')
    def _default_cache_ports(self):
        return self.transport == 'tcp'

    @property
    def ipykernel(self):
        return self.kernel_name in {'python', 'python2', 'python3'}

    # Protected traits
    _launch_args = Any()
    _control_socket = Any()

    _restarter = Any()

    autorestart = Bool(True, config=True,
        help="""Should we autorestart the kernel if it dies."""
    )

    shutting_down = False

    def __del__(self):
        self._close_control_socket()
        self.cleanup_connection_file()

    #--------------------------------------------------------------------------
    # Kernel restarter
    #--------------------------------------------------------------------------

    def start_restarter(self):
        pass

    def stop_restarter(self):
        pass

    def add_restart_callback(self, callback, event='restart'):
        """register a callback to be called when a kernel is restarted"""
        if self._restarter is None:
            return
        self._restarter.add_callback(callback, event)

    def remove_restart_callback(self, callback, event='restart'):
        """unregister a callback to be called when a kernel is restarted"""
        if self._restarter is None:
            return
        self._restarter.remove_callback(callback, event)

    #--------------------------------------------------------------------------
    # create a Client connected to our Kernel
    #--------------------------------------------------------------------------

    def client(self, **kwargs):
        """Create a client configured to connect to our kernel"""
        kw = {}
        kw.update(self.get_connection_info(session=True))
        kw.update(dict(
            connection_file=self.connection_file,
            parent=self,
        ))

        # add kwargs last, for manual overrides
        kw.update(kwargs)
        return self.client_factory(**kw)

    #--------------------------------------------------------------------------
    # Kernel management
    #--------------------------------------------------------------------------

    def format_kernel_cmd(self, extra_arguments=None):
        """replace templated args (e.g. {connection_file})"""
        extra_arguments = extra_arguments or []
        if self.kernel_cmd:
            cmd = self.kernel_cmd + extra_arguments
        else:
            cmd = self.kernel_spec.argv + extra_arguments

        if cmd and cmd[0] in {'python',
                              'python%i' % sys.version_info[0],
                              'python%i.%i' % sys.version_info[:2]}:
            # executable is 'python' or 'python3', use sys.executable.
            # These will typically be the same,
            # but if the current process is in an env
            # and has been launched by abspath without
            # activating the env, python on PATH may not be sys.executable,
            # but it should be.
            cmd[0] = sys.executable

        # Make sure to use the realpath for the connection_file
        # On windows, when running with the store python, the connection_file path
        # is not usable by non python kernels because the path is being rerouted when
        # inside of a store app.
        # See this bug here: https://bugs.python.org/issue41196
        ns = dict(connection_file=os.path.realpath(self.connection_file),
                  prefix=sys.prefix,
                 )

        if self.kernel_spec:
            ns["resource_dir"] = self.kernel_spec.resource_dir

        ns.update(self._launch_args)

        pat = re.compile(r'\{([A-Za-z0-9_]+)\}')
        def from_ns(match):
            """Get the key out of ns if it's there, otherwise no change."""
            return ns.get(match.group(1), match.group())

        return [ pat.sub(from_ns, arg) for arg in cmd ]
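        # Illustrative sketch (comments only): with a kernelspec argv such as
        #   ["python", "-m", "ipykernel_launcher", "-f", "{connection_file}"]
        # this method replaces "python" with sys.executable and expands the
        # {connection_file} template, yielding e.g. (hypothetical paths):
        #   ["/usr/bin/python3", "-m", "ipykernel_launcher",
        #    "-f", "/run/user/1000/jupyter/kernel-1234.json"]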

    def _launch_kernel(self, kernel_cmd, **kw):
        """actually launch the kernel

        override in a subclass to launch kernel subprocesses differently
        """
        return launch_kernel(kernel_cmd, **kw)

    # Control socket used for polite kernel shutdown

    def _connect_control_socket(self):
        if self._control_socket is None:
            self._control_socket = self._create_connected_socket('control')
            self._control_socket.linger = 100

    def _close_control_socket(self):
        if self._control_socket is None:
            return
        self._control_socket.close()
        self._control_socket = None

    def pre_start_kernel(self, **kw):
        """Prepares a kernel for startup in a separate process.

        If random ports (port=0) are being used, this method must be called
        before the channels are created.

        Parameters
        ----------
        `**kw` : optional
             keyword arguments that are passed down to build the kernel_cmd
             and to launch the kernel (e.g. Popen kwargs).
        """
        self.shutting_down = False
        if self.transport == 'tcp' and not is_local_ip(self.ip):
            raise RuntimeError("Can only launch a kernel on a local interface. "
                               "This one is not: %s."
                               "Make sure that the '*_address' attributes are "
                               "configured properly. "
                               "Currently valid addresses are: %s" % (self.ip, local_ips())
                               )

        # write connection file / get default ports
        self.write_connection_file()

        # save kwargs for use in restart
        self._launch_args = kw.copy()
        # build the Popen cmd
        extra_arguments = kw.pop('extra_arguments', [])
        kernel_cmd = self.format_kernel_cmd(extra_arguments=extra_arguments)
        env = kw.pop('env', os.environ).copy()
        # Don't allow PYTHONEXECUTABLE to be passed to kernel process.
        # If set, it can bork all the things.
        env.pop('PYTHONEXECUTABLE', None)
        if not self.kernel_cmd:
            # If kernel_cmd has been set manually, don't refer to a kernel spec.
            # Environment variables from the kernel spec are merged into the
            # subprocess environment (after template substitution).
            env.update(self._get_env_substitutions(self.kernel_spec.env, env))

        kw['env'] = env
        return kernel_cmd, kw

    def _get_env_substitutions(self, templated_env, substitution_values):
        """ Walks env entries in templated_env and applies possible substitutions from current env
            (represented by substitution_values).
            Returns the substituted list of env entries.
        """
        substituted_env = {}
        if templated_env:
            from string import Template

            # For each templated env entry, fill any templated references
            # matching names of env variables with those values and build
            # new dict with substitutions.
            for k, v in templated_env.items():
                substituted_env.update({k: Template(v).safe_substitute(substitution_values)})
        return substituted_env
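        # Illustrative sketch (comments only): given templated_env of
        #   {"PYTHONPATH": "${HOME}/lib"}
        # and substitution_values containing HOME="/home/alice" (hypothetical),
        # safe_substitute() yields {"PYTHONPATH": "/home/alice/lib"}; unknown
        # ${...} references are left untouched rather than raising KeyError.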

    def post_start_kernel(self, **kw):
        self.start_restarter()
        self._connect_control_socket()

    def start_kernel(self, **kw):
        """Starts a kernel on this host in a separate process.

        If random ports (port=0) are being used, this method must be called
        before the channels are created.

        Parameters
        ----------
        `**kw` : optional
             keyword arguments that are passed down to build the kernel_cmd
             and to launch the kernel (e.g. Popen kwargs).
        """
        kernel_cmd, kw = self.pre_start_kernel(**kw)

        # launch the kernel subprocess
        self.log.debug("Starting kernel: %s", kernel_cmd)
        self.kernel = self._launch_kernel(kernel_cmd, **kw)
        self.post_start_kernel(**kw)

    def request_shutdown(self, restart=False):
        """Send a shutdown request via control channel
        """
        content = dict(restart=restart)
        msg = self.session.msg("shutdown_request", content=content)
        # ensure control socket is connected
        self._connect_control_socket()
        self.session.send(self._control_socket, msg)

    def finish_shutdown(self, waittime=None, pollinterval=0.1):
        """Wait for kernel shutdown, then kill process if it doesn't shutdown.

        This does not send shutdown requests - use :meth:`request_shutdown`
        first.
        """
        if waittime is None:
            waittime = max(self.shutdown_wait_time, 0)
        for i in range(int(waittime/pollinterval)):
            if self.is_alive():
                time.sleep(pollinterval)
            else:
                # If there's still a proc, wait and clear
                if self.has_kernel:
                    self.kernel.wait()
                    self.kernel = None
                break
        else:
            # OK, we've waited long enough.
            if self.has_kernel:
                self.log.debug("Kernel is taking too long to finish, killing")
                self._kill_kernel()

    def cleanup_resources(self, restart=False):
        """Clean up resources when the kernel is shut down"""
        if not restart:
            self.cleanup_connection_file()

        self.cleanup_ipc_files()
        self._close_control_socket()
        self.session.parent = None

        if self._created_context and not restart:
            self.context.destroy(linger=100)

    def cleanup(self, connection_file=True):
        """Clean up resources when the kernel is shut down"""
        warnings.warn("Method cleanup(connection_file=True) is deprecated, use cleanup_resources(restart=False).",
                      FutureWarning)
        self.cleanup_resources(restart=not connection_file)

    def shutdown_kernel(self, now=False, restart=False):
        """Attempts to stop the kernel process cleanly.

        This attempts to shut down the kernel cleanly by:

        1. Sending it a shutdown message over the control channel.
        2. If that fails, the kernel is shutdown forcibly by sending it
           a signal.

        Parameters
        ----------
        now : bool
            Should the kernel be forcibly killed *now*? This skips the
            first, nice shutdown attempt.
        restart: bool
            Will this kernel be restarted after it is shutdown. When this
            is True, connection files will not be cleaned up.
        """
        self.shutting_down = True  # Used by restarter to prevent race condition
        # Stop monitoring for restarting while we shutdown.
        self.stop_restarter()

        if now:
            self._kill_kernel()
        else:
            self.request_shutdown(restart=restart)
            # Don't send any additional kernel kill messages immediately, to give
            # the kernel a chance to properly execute shutdown actions. Wait for
            # at most `shutdown_wait_time` seconds, checking every 0.1s.
            self.finish_shutdown()

        # In 6.1.5, a new method, cleanup_resources(), was introduced to address
        # a leak issue (https://github.com/jupyter/jupyter_client/pull/548) and
        # replaced the existing cleanup() method.  However, that method introduction
        # breaks subclass implementations that override cleanup() since it would
        # circumvent cleanup() functionality implemented in subclasses.
        # By detecting if the current instance overrides cleanup(), we can determine
        # if the deprecated path of calling cleanup() should be performed - which avoids
        # unnecessary deprecation warnings in a majority of configurations in which
        # subclassed KernelManager instances are not in use.
        # Note: because subclasses may have already implemented cleanup_resources()
        # but need to support older jupyter_clients, we should only take the deprecated
        # path if cleanup() is overridden but cleanup_resources() is not.

        overrides_cleanup = type(self).cleanup is not KernelManager.cleanup
        overrides_cleanup_resources = type(self).cleanup_resources is not KernelManager.cleanup_resources

        if overrides_cleanup and not overrides_cleanup_resources:
            self.cleanup(connection_file=not restart)
        else:
            self.cleanup_resources(restart=restart)

    def restart_kernel(self, now=False, newports=False, **kw):
        """Restarts a kernel with the arguments that were used to launch it.

        Parameters
        ----------
        now : bool, optional
            If True, the kernel is forcefully restarted *immediately*, without
            having a chance to do any cleanup action.  Otherwise the kernel is
            given up to `shutdown_wait_time` seconds to clean up before a
            forceful restart is issued.

            In all cases the kernel is restarted, the only difference is whether
            it is given a chance to perform a clean shutdown or not.

        newports : bool, optional
            If the old kernel was launched with random ports, this flag decides
            whether the same ports and connection file will be used again.
            If False, the same ports and connection file are used. This is
            the default. If True, new random port numbers are chosen and a
            new connection file is written. It is still possible that the newly
            chosen random port numbers happen to be the same as the old ones.

        `**kw` : optional
            Any options specified here will overwrite those used to launch the
            kernel.
        """
        if self._launch_args is None:
            raise RuntimeError("Cannot restart the kernel. "
                               "No previous call to 'start_kernel'.")
        else:
            # Stop currently running kernel.
            self.shutdown_kernel(now=now, restart=True)

            if newports:
                self.cleanup_random_ports()

            # Start new kernel.
            self._launch_args.update(kw)
            self.start_kernel(**self._launch_args)

    @property
    def has_kernel(self):
        """Has a kernel been started that we are managing."""
        return self.kernel is not None

    def _kill_kernel(self):
        """Kill the running kernel.

        This is a private method, callers should use shutdown_kernel(now=True).
        """
        if self.has_kernel:
            # Signal the kernel to terminate (sends SIGKILL on Unix and calls
            # TerminateProcess() on Win32).
            try:
                if hasattr(signal, 'SIGKILL'):
                    self.signal_kernel(signal.SIGKILL)
                else:
                    self.kernel.kill()
            except OSError as e:
                # In Windows, we will get an Access Denied error if the process
                # has already terminated. Ignore it.
                if sys.platform == 'win32':
                    if e.winerror != 5:
                        raise
                # On Unix, we may get an ESRCH error if the process has already
                # terminated. Ignore it.
                else:
                    from errno import ESRCH
                    if e.errno != ESRCH:
                        raise

            # Block until the kernel terminates.
            self.kernel.wait()
            self.kernel = None

    def interrupt_kernel(self):
        """Interrupts the kernel by sending it a signal.

        Unlike ``signal_kernel``, this operation is well supported on all
        platforms.
        """
        if self.has_kernel:
            interrupt_mode = self.kernel_spec.interrupt_mode
            if interrupt_mode == 'signal':
                if sys.platform == 'win32':
                    from .win_interrupt import send_interrupt
                    send_interrupt(self.kernel.win32_interrupt_event)
                else:
                    self.signal_kernel(signal.SIGINT)

            elif interrupt_mode == 'message':
                msg = self.session.msg("interrupt_request", content={})
                self._connect_control_socket()
                self.session.send(self._control_socket, msg)
        else:
            raise RuntimeError("Cannot interrupt kernel. No kernel is running!")

    def signal_kernel(self, signum):
        """Sends a signal to the process group of the kernel (this
        usually includes the kernel and any subprocesses spawned by
        the kernel).

        Note that since only SIGTERM is supported on Windows, this function is
        only useful on Unix systems.
        """
        if self.has_kernel:
            if hasattr(os, "getpgid") and hasattr(os, "killpg"):
                try:
                    pgid = os.getpgid(self.kernel.pid)
                    os.killpg(pgid, signum)
                    return
                except OSError:
                    pass
            self.kernel.send_signal(signum)
        else:
            raise RuntimeError("Cannot signal kernel. No kernel is running!")

    def is_alive(self):
        """Is the kernel process still running?"""
        if self.has_kernel:
            if self.kernel.poll() is None:
                return True
            else:
                return False
        else:
            # we don't have a kernel
            return False
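A minimal usage sketch for the subprocess-based manager above, assuming jupyter_client is installed and a 'python3' kernelspec is available:

from jupyter_client import KernelManager

km = KernelManager(kernel_name='python3')
km.start_kernel()                  # writes a connection file and launches Popen
print('alive:', km.is_alive())     # True while the subprocess is running
km.interrupt_kernel()              # signal- or message-based, per the kernelspec
km.shutdown_kernel()               # polite shutdown_request, then kill if needed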
Example #29
class KernelSpecManager(LoggingConfigurable):

    kernel_spec_class = Type(
        KernelSpec,
        config=True,
        help="""The kernel spec class.  This is configurable to allow
        subclassing of the KernelSpecManager for customized behavior.
        """)

    ensure_native_kernel = Bool(
        True,
        config=True,
        help="""If there is no Python kernelspec registered and the IPython
        kernel is available, ensure it is added to the spec list.
        """)

    data_dir = Unicode()

    def _data_dir_default(self):
        return jupyter_data_dir()

    user_kernel_dir = Unicode()

    def _user_kernel_dir_default(self):
        return pjoin(self.data_dir, 'kernels')

    whitelist = Set(config=True,
                    help="""Whitelist of allowed kernel names.

        By default, all installed kernels are allowed.
        """)
    kernel_dirs = List(
        help=
        "List of kernel directories to search. Later ones take priority over earlier."
    )

    def _kernel_dirs_default(self):
        dirs = jupyter_path('kernels')
        # At some point, we should stop adding .ipython/kernels to the path,
        # but the cost to keeping it is very small.
        try:
            from IPython.paths import get_ipython_dir
        except ImportError:
            try:
                from IPython.utils.path import get_ipython_dir
            except ImportError:
                # no IPython, no ipython dir
                get_ipython_dir = None
        if get_ipython_dir is not None:
            dirs.append(os.path.join(get_ipython_dir(), 'kernels'))
        return dirs

    def find_kernel_specs(self):
        """Returns a dict mapping kernel names to resource directories."""
        d = {}
        for kernel_dir in self.kernel_dirs:
            kernels = _list_kernels_in(kernel_dir)
            for kname, spec in kernels.items():
                if kname not in d:
                    self.log.debug("Found kernel %s in %s", kname, kernel_dir)
                    d[kname] = spec

        if self.ensure_native_kernel and NATIVE_KERNEL_NAME not in d:
            try:
                from ipykernel.kernelspec import RESOURCES
                self.log.debug("Native kernel (%s) available from %s",
                               NATIVE_KERNEL_NAME, RESOURCES)
                d[NATIVE_KERNEL_NAME] = RESOURCES
            except ImportError:
                self.log.warning("Native kernel (%s) is not available",
                                 NATIVE_KERNEL_NAME)

        if self.whitelist:
            # filter if there's a whitelist
            d = {
                name: spec
                for name, spec in d.items() if name in self.whitelist
            }
        return d
        # TODO: Caching?

    def _get_kernel_spec_by_name(self, kernel_name, resource_dir):
        """ Returns a :class:`KernelSpec` instance for a given kernel_name
        and resource_dir.
        """
        if kernel_name == NATIVE_KERNEL_NAME:
            try:
                from ipykernel.kernelspec import RESOURCES, get_kernel_dict
            except ImportError:
                # It should be impossible to reach this, but let's play it safe
                pass
            else:
                if resource_dir == RESOURCES:
                    return self.kernel_spec_class(resource_dir=resource_dir,
                                                  **get_kernel_dict())

        return self.kernel_spec_class.from_resource_dir(resource_dir)

    def get_kernel_spec(self, kernel_name):
        """Returns a :class:`KernelSpec` instance for the given kernel_name.

        Raises :exc:`NoSuchKernel` if the given kernel name is not found.
        """
        d = self.find_kernel_specs()
        try:
            resource_dir = d[kernel_name.lower()]
        except KeyError:
            raise NoSuchKernel(kernel_name)

        return self._get_kernel_spec_by_name(kernel_name, resource_dir)

    def get_all_specs(self):
        """Returns a dict mapping kernel names to kernelspecs.

        Returns a dict of the form::

            {
              'kernel_name': {
                'resource_dir': '/path/to/kernel_name',
                'spec': {"the spec itself": ...}
              },
              ...
            }
        """
        d = self.find_kernel_specs()
        return {
            kname: {
                "resource_dir": d[kname],
                "spec": self._get_kernel_spec_by_name(kname,
                                                      d[kname]).to_dict()
            }
            for kname in d
        }

    def remove_kernel_spec(self, name):
        """Remove a kernel spec directory by name.
        
        Returns the path that was deleted.
        """
        save_native = self.ensure_native_kernel
        try:
            self.ensure_native_kernel = False
            specs = self.find_kernel_specs()
        finally:
            self.ensure_native_kernel = save_native
        spec_dir = specs[name]
        self.log.debug("Removing %s", spec_dir)
        if os.path.islink(spec_dir):
            os.remove(spec_dir)
        else:
            shutil.rmtree(spec_dir)
        return spec_dir

    def _get_destination_dir(self, kernel_name, user=False, prefix=None):
        if user:
            return os.path.join(self.user_kernel_dir, kernel_name)
        elif prefix:
            return os.path.join(os.path.abspath(prefix), 'share', 'jupyter',
                                'kernels', kernel_name)
        else:
            return os.path.join(SYSTEM_JUPYTER_PATH[0], 'kernels', kernel_name)

    def install_kernel_spec(self,
                            source_dir,
                            kernel_name=None,
                            user=False,
                            replace=None,
                            prefix=None):
        """Install a kernel spec by copying its directory.

        If ``kernel_name`` is not given, the basename of ``source_dir`` will
        be used.

        If ``user`` is False, it will attempt to install into the systemwide
        kernel registry. If the process does not have appropriate permissions,
        an :exc:`OSError` will be raised.
        
        If ``prefix`` is given, the kernelspec will be installed to
        PREFIX/share/jupyter/kernels/KERNEL_NAME. This can be sys.prefix
        for installation inside virtual or conda envs.
        """
        source_dir = source_dir.rstrip('/\\')
        if not kernel_name:
            kernel_name = os.path.basename(source_dir)
        kernel_name = kernel_name.lower()
        if not _is_valid_kernel_name(kernel_name):
            raise ValueError("Invalid kernel name %r.  %s" %
                             (kernel_name, _kernel_name_description))

        if user and prefix:
            raise ValueError(
                "Can't specify both user and prefix. Please choose one or the other."
            )

        if replace is not None:
            warnings.warn(
                "replace is ignored. Installing a kernelspec always replaces an existing installation",
                DeprecationWarning,
                stacklevel=2,
            )

        destination = self._get_destination_dir(kernel_name,
                                                user=user,
                                                prefix=prefix)
        self.log.debug('Installing kernelspec in %s', destination)

        kernel_dir = os.path.dirname(destination)
        if kernel_dir not in self.kernel_dirs:
            self.log.warning(
                "Installing to %s, which is not in %s. The kernelspec may not be found.",
                kernel_dir,
                self.kernel_dirs,
            )

        if os.path.isdir(destination):
            self.log.info('Removing existing kernelspec in %s', destination)
            shutil.rmtree(destination)

        shutil.copytree(source_dir, destination)
        self.log.info('Installed kernelspec %s in %s', kernel_name,
                      destination)
        return destination

    def install_native_kernel_spec(self, user=False):
        """DEPRECATED: Use ipykernel.kenelspec.install"""
        warnings.warn(
            "install_native_kernel_spec is deprecated."
            " Use ipykernel.kernelspec import install.",
            stacklevel=2)
        from ipykernel.kernelspec import install
        install(self, user=user)
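A minimal sketch of the spec-discovery API above, assuming jupyter_client is installed:

from jupyter_client.kernelspec import KernelSpecManager

ksm = KernelSpecManager()
print(ksm.find_kernel_specs())          # e.g. {'python3': '/path/to/resource_dir'}
spec = ksm.get_kernel_spec('python3')   # raises NoSuchKernel if not installed
print(spec.display_name, spec.argv)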
Example #30
class MetadataManager(LoggingConfigurable):
    """Manages metadata instances"""

    # System-owned namespaces
    NAMESPACE_RUNTIMES = "runtimes"
    NAMESPACE_CODE_SNIPPETS = "code-snippets"
    NAMESPACE_RUNTIME_IMAGES = "runtime-images"

    metadata_store_class = Type(default_value=FileMetadataStore, config=True,
                                klass=MetadataStore,
                                help="""The metadata store class.  This is configurable to allow subclassing of
                                the MetadataStore for customized behavior.""")

    def __init__(self, namespace: str, **kwargs: Any):
        """
        Generic object to manage metadata instances.
        :param namespace (str): the partition where metadata instances are stored
        :param kwargs: additional arguments to be used to instantiate a metadata manager
        Keyword Args:
            metadata_store_class (str): the name of the MetadataStore subclass to use for storing managed instances
        """
        super(MetadataManager, self).__init__(**kwargs)

        self.schema_mgr = SchemaManager.instance()
        self.schema_mgr.validate_namespace(namespace)
        self.namespace = namespace
        self.metadata_store = self.metadata_store_class(namespace, **kwargs)

    def namespace_exists(self) -> bool:
        """Returns True if the namespace for this instance exists"""
        return self.metadata_store.namespace_exists()

    def get_all(self, include_invalid: bool = False) -> List[Metadata]:
        """Returns all metadata instances in summary form (name, display_name, location)"""

        instances = []
        instance_list = self.metadata_store.fetch_instances(include_invalid=include_invalid)
        for metadata_dict in instance_list:
            # validate the instance prior to return, include invalid instances as appropriate
            try:
                metadata = Metadata.from_dict(self.namespace, metadata_dict)
                metadata.post_load()  # Allow class instances to handle loads
                # if we're including invalid and there was an issue on retrieval, add it to the list
                if include_invalid and metadata.reason:
                    # If no schema-name is present, set to '{unknown}' since we can't make that determination.
                    if not metadata.schema_name:
                        metadata.schema_name = '{unknown}'
                else:  # go ahead and validate against the schema
                    self.validate(metadata.name, metadata)
                instances.append(metadata)
            except Exception as ex:  # Ignore ValidationError and others when fetching all instances
                # Since we may not have a metadata instance due to a failure during `from_dict()`,
                # instantiate a bad instance directly to use in the message and invalid result.
                invalid_instance = Metadata(**metadata_dict)
                self.log.debug("Fetch of instance '{}' of namespace '{}' encountered an exception: {}".
                               format(invalid_instance.name, self.namespace, ex))
                if include_invalid:
                    invalid_instance.reason = ex.__class__.__name__
                    instances.append(invalid_instance)
        return instances

    def get(self, name: str) -> Metadata:
        """Returns the metadata instance corresponding to the given name"""
        instance_list = self.metadata_store.fetch_instances(name=name)
        metadata_dict = instance_list[0]
        metadata = Metadata.from_dict(self.namespace, metadata_dict)

        # Validate the instance on load
        self.validate(name, metadata)

        # Allow class instances to alter instance
        metadata.post_load()

        return metadata

    def create(self, name: str, metadata: Metadata) -> Metadata:
        """Creates the given metadata, returning the created instance"""
        return self._save(name, metadata)

    def update(self, name: str, metadata: Metadata) -> Metadata:
        """Updates the given metadata, returning the updated instance"""
        return self._save(name, metadata, for_update=True)

    def remove(self, name: str) -> None:
        """Removes the metadata instance corresponding to the given name"""

        instance_list = self.metadata_store.fetch_instances(name=name)
        metadata_dict = instance_list[0]

        self.log.debug("Removing metadata resource '{}' from namespace '{}'.".format(name, self.namespace))

        metadata = Metadata.from_dict(self.namespace, metadata_dict)
        metadata.pre_delete()  # Allow class instances to handle delete

        self.metadata_store.delete_instance(metadata_dict)

    def validate(self, name: str, metadata: Metadata) -> None:
        """Validate metadata against its schema.

        Ensure metadata is valid based on its schema.  If invalid or schema
        is not found, ValidationError will be raised.
        """
        metadata_dict = metadata.to_dict()
        schema_name = metadata_dict.get('schema_name')
        if not schema_name:
            raise ValueError("Instance '{}' in the {} namespace is missing a 'schema_name' field!".
                             format(name, self.namespace))

        schema = self._get_schema(schema_name)  # returns a value or throws
        try:
            validate(instance=metadata_dict, schema=schema, format_checker=draft7_format_checker)
        except ValidationError as ve:
            # Because validation errors are so verbose, only provide the first line.
            first_line = str(ve).partition('\n')[0]
            msg = "Validation failed for instance '{}' using the {} schema with error: {}.".\
                format(name, schema_name, first_line)
            self.log.error(msg)
            raise ValidationError(msg) from ve

    @staticmethod
    def _get_normalized_name(name: str) -> str:
        # lowercase and replaces spaces with underscore
        name = re.sub('\\s+', '_', name.lower())
        # remove all invalid characters
        name = re.sub('[^a-z0-9-_]+', '', name)
        # begin with alpha (guard against an empty result after stripping)
        if not name or not name[0].isalpha():
            name = 'a_' + name
        # end with alpha numeric
        if not name[-1].isalnum():
            name = name + '_0'
        return name
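        # Illustrative sketch (comments only): _get_normalized_name("My Runtime!")
        # lowercases to "my runtime!", replaces whitespace to give "my_runtime!",
        # strips invalid characters to "my_runtime", and both boundary checks
        # pass, so the result is "my_runtime".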

    def _get_schema(self, schema_name: str) -> dict:
        """Loads the schema based on the schema_name and returns the loaded schema json.
           Throws ValidationError if schema file is not present.
        """
        schema_json = self.schema_mgr.get_schema(self.namespace, schema_name)
        if schema_json is None:
            schema_file = os.path.join(os.path.dirname(__file__), 'schemas', schema_name + '.json')
            if not os.path.exists(schema_file):
                self.log.error("The file for schema '{}' is missing from its expected location: '{}'".
                               format(schema_name, schema_file))
                raise SchemaNotFoundError("The file for schema '{}' is missing!".format(schema_name))
            with io.open(schema_file, 'r', encoding='utf-8') as f:
                schema_json = json.load(f)
            self.schema_mgr.add_schema(self.namespace, schema_name, schema_json)

        return schema_json

    def _save(self, name: str, metadata: Metadata, for_update: bool = False) -> Metadata:
        if not metadata:
            raise ValueError("An instance of class 'Metadata' was not provided.")

        if not isinstance(metadata, Metadata):
            raise TypeError("'metadata' is not an instance of class 'Metadata'.")

        if not name and not for_update:  # name is derived from display_name only on creates
            if metadata.display_name:
                name = self._get_normalized_name(metadata.display_name)
                metadata.name = name

        if not name:  # At this point, name must be set
            raise ValueError('Name of metadata was not provided.')

        match = re.search("^[a-z]([a-z0-9-_]*[a-z0-9])?$", name)
        if match is None:
            raise ValueError("Name of metadata must be lowercase alphanumeric, beginning with alpha and can include "
                             "embedded hyphens ('-') and underscores ('_').")

        # Allow class instances to handle saves
        metadata.pre_save(for_update=for_update)

        self._apply_defaults(metadata)

        # Validate the metadata prior to storage then store the instance.
        self.validate(name, metadata)

        metadata_dict = self.metadata_store.store_instance(name, metadata.prepare_write(), for_update=for_update)

        return Metadata.from_dict(self.namespace, metadata_dict)

    def _apply_defaults(self, metadata: Metadata) -> None:
        """If a given property has a default value defined, and that property is not currently represented,

        assign it the default value.
        """

        # Get the schema and build a dict consisting of properties and their default values (for those
        # properties that have defaults).  Then walk the metadata instance looking for missing properties
        # and assign the corresponding default value.  Note that we do not consider existing properties with
        # values of None for default replacement since that may be intentional (although those values will
        # likely fail subsequent validation).

        schema = self.schema_mgr.get_schema(self.namespace, metadata.schema_name)

        meta_properties = schema['properties']['metadata']['properties']
        property_defaults = {}
        for name, prop in meta_properties.items():
            if 'default' in prop:
                property_defaults[name] = prop['default']

        if property_defaults:  # schema defines defaulted properties
            instance_properties = metadata.metadata
            for name, default in property_defaults.items():
                if name not in instance_properties:
                    instance_properties[name] = default
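A hedged usage sketch for the manager above; the schema name and metadata fields below are illustrative assumptions, since valid values depend on the schemas installed with the surrounding package:

mgr = MetadataManager(namespace=MetadataManager.NAMESPACE_RUNTIME_IMAGES)

md = Metadata(schema_name='runtime-image',                   # assumed schema name
              display_name='My Image',
              metadata={'image_name': 'example/image:1.0'})  # assumed field

created = mgr.create(None, md)   # name is derived from display_name on create
print(created.name)              # e.g. 'my_image'
print([m.name for m in mgr.get_all()])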