Example #1
class Scheduler(SessionFactory):
    client_stream = Instance(zmqstream.ZMQStream,
                             allow_none=True)  # client-facing stream
    engine_stream = Instance(zmqstream.ZMQStream,
                             allow_none=True)  # engine-facing stream
    notifier_stream = Instance(zmqstream.ZMQStream,
                               allow_none=True)  # hub-facing sub stream
    mon_stream = Instance(zmqstream.ZMQStream,
                          allow_none=True)  # hub-facing pub stream
    query_stream = Instance(zmqstream.ZMQStream,
                            allow_none=True)  # hub-facing DEALER stream

    all_completed = Set()  # set of all completed tasks
    all_failed = Set()  # set of all failed tasks
    all_done = Set()  # set of all finished tasks=union(completed,failed)
    all_ids = Set()  # set of all submitted task IDs

    # ZMQ identity. This should just be self.session.session, but ensure Bytes.
    ident = CBytes()

    def _ident_default(self):
        return self.session.bsession

    def start(self):
        self.engine_stream.on_recv(self.dispatch_result, copy=False)
        self.client_stream.on_recv(self.dispatch_submission, copy=False)

    def resume_receiving(self):
        """Resume accepting jobs."""
        self.client_stream.on_recv(self.dispatch_submission, copy=False)

    def stop_receiving(self):
        """Stop accepting jobs while there are no engines.
        Leave them in the ZMQ queue."""
        self.client_stream.on_recv(None)

    def dispatch_result(self, raw_msg):
        raise NotImplementedError("Implement in subclasses")

    def dispatch_submission(self, raw_msg):
        raise NotImplementedError("Implement in subclasses")

    def append_new_msg_id_to_msg(self, new_id, target_id, idents, msg):
        new_idents = [cast_bytes(target_id)] + idents
        msg['header']['msg_id'] = new_id
        new_msg_list = self.session.serialize(msg, ident=new_idents)
        new_msg_list.extend(msg['buffers'])
        return new_msg_list

    def get_new_msg_id(self, original_msg_id, outgoing_id):
        return f'{original_msg_id}_{outgoing_id if isinstance(outgoing_id, str) else outgoing_id.decode("utf8")}'
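
A standalone sketch of the ID-composition logic in get_new_msg_id; compose_msg_id is a hypothetical helper mirroring the method above, not part of ipyparallel's API:

def compose_msg_id(original_msg_id, outgoing_id):
    # bytes identities get decoded; str identities pass through unchanged
    suffix = outgoing_id if isinstance(outgoing_id, str) else outgoing_id.decode("utf8")
    return f"{original_msg_id}_{suffix}"

assert compose_msg_id("abc123", b"engine-0") == "abc123_engine-0"
assert compose_msg_id("abc123", "engine-1") == "abc123_engine-1"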
Example #2
class ExtractAttachmentsPreprocessor(Preprocessor):
    "Extracts all of the outputs from the notebook file."
    output_filename_template = Unicode("attach_{cell_index}_{name}").tag(
        config=True)
    extract_output_types = Set(
        {'image/png', 'image/jpeg', 'image/svg+xml',
         'application/pdf'}).tag(config=True)

    def preprocess_cell(self, cell, resources, cell_index):
        output_files_dir = resources.get('output_files_dir', None)
        if not isinstance(resources['outputs'], dict):
            resources['outputs'] = {}

        for name, attach in cell.get("attachments", {}).items():
            for mime, data in attach.items():
                if mime not in self.extract_output_types:
                    continue
                # Binary files are base64-encoded, SVG is already XML
                if mime in {'image/png', 'image/jpeg', 'application/pdf'}:
                    data = a2b_base64(data)
                elif sys.platform == 'win32':
                    data = data.replace('\n', '\r\n').encode("UTF-8")
                else:
                    data = data.encode("UTF-8")
                filename = self.output_filename_template.format(
                    cell_index=cell_index, name=name)
                if output_files_dir is not None:
                    filename = os.path.join(output_files_dir, filename)
                if name.endswith(".gif") and mime == "image/png":
                    filename = filename.replace(".gif", ".png")
                resources['outputs'][filename] = data
                attach_str = "attachment:" + name
                if attach_str in cell.source:
                    cell.source = cell.source.replace(attach_str, filename)

        return cell, resources
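
A quick usage sketch, assuming nbformat is installed and the class above is in scope with its module's imports (os, sys, a2b_base64); the PNG payload and output directory are illustrative:

from base64 import b64encode

import nbformat.v4 as nbf

cell = nbf.new_markdown_cell("![pixel](attachment:pixel.png)")
cell["attachments"] = {"pixel.png": {"image/png": b64encode(b"\x89PNG...").decode()}}

resources = {"outputs": {}, "output_files_dir": "attachments"}
cell, resources = ExtractAttachmentsPreprocessor().preprocess_cell(cell, resources, 0)

print(cell.source)                 # image reference now points at the extracted file
print(list(resources["outputs"]))  # e.g. ['attachments/attach_0_pixel.png']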
Example #3
class Bar(Configurable):

    b = Integer(0, help="The integer b.").tag(config=True)
    enabled = Bool(True, help="Enable bar.").tag(config=True)
    tb = Tuple(()).tag(config=True, multiplicity='*')
    aset = Set().tag(config=True, multiplicity='+')
    bdict = Dict().tag(config=True)
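
The .tag(config=True) calls make each trait configurable; a minimal programmatic sketch, assuming Bar and traitlets are importable (the multiplicity tags only affect command-line parsing):

from traitlets.config import Config

c = Config()
c.Bar.b = 3
c.Bar.aset = {"x", "y"}
c.Bar.bdict = {"key": "value"}

bar = Bar(config=c)
print(bar.b, bar.enabled, sorted(bar.aset))  # 3 True ['x', 'y']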
Example #4
class TreeFilter(TreeControler):
    visible = Set(allow_none=True)

    def __init__(self, ref, attr, func, **kwargs):
        self._cbs = []
        super().__init__(ref, attr, func, **kwargs)
        self.tree = self.ref

    def ref_changed(self, change):
        super().ref_changed(change)
        self.tree = self.ref
        self.update(None)

    def update(self, change):
        super().update(change)
        if self.ref is not None and self.monitored is not None:
            self.visible = self.ref.reduce(self.apply)
        else:
            self.visible = None
        for func in self._cbs:
            func(self.visible)

    def on(self, func):
        self._cbs.append(func)

    def off(self, func):
        if func is None:
            self._cbs = []
        else:
            self._cbs.remove(func)
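
on()/off() form a small callback registry that update() fans the new visible set out to; a standalone sketch of the same pattern (TreeControler itself is not shown in this example):

class CallbackRegistry:
    def __init__(self):
        self._cbs = []

    def on(self, func):
        self._cbs.append(func)

    def off(self, func=None):
        if func is None:
            self._cbs = []  # passing None clears every registered callback
        else:
            self._cbs.remove(func)

    def fire(self, value):
        for func in self._cbs:
            func(value)

reg = CallbackRegistry()
reg.on(print)
reg.fire({"node-1", "node-2"})  # every registered callback receives the value
reg.off(print)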
Example #5
class TagExtractPreprocessor(nbconvert.preprocessors.Preprocessor):

    extract_cell_tags = Set(
        Unicode(),
        default_value=set(),
        help=("Tags indicating which cells are to be removed; "
              "matches tags in `cell.metadata.tags`.")).tag(config=True)

    def find_matching_tags(self, cell):
        return self.extract_cell_tags.intersection(
            cell.get('metadata', {}).get('tags', []))

    def preprocess(self, nb, resources):
        if not self.extract_cell_tags:
            return nb, resources

        # Filter out cells that meet the conditions
        new_cells = []
        extracted_by_tag = resources.setdefault('extracted_by_tag', {})

        for cell in nb.cells:
            tags = self.find_matching_tags(cell)
            if tags:
                for tag in tags:
                    extracted_by_tag.setdefault(tag, []).append(cell['source'])
            else:
                new_cells.append(cell)

        nb.cells = new_cells
        return nb, resources
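
A usage sketch, assuming nbformat is available: cells tagged "solution" are removed from the notebook and their sources collected in resources['extracted_by_tag']:

import nbformat.v4 as nbf

nb = nbf.new_notebook(cells=[
    nbf.new_code_cell("x = 1"),
    nbf.new_code_cell("x = 42", metadata={"tags": ["solution"]}),
])

pre = TagExtractPreprocessor(extract_cell_tags={"solution"})
nb, resources = pre.preprocess(nb, {})

print(len(nb.cells))                  # 1
print(resources["extracted_by_tag"])  # {'solution': ['x = 42']}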
Example #6
class RemoveLessonCells(DsCreatePreprocessor):

    description = '''
    RemoveLessonCells removes cells that do not contain a tag included in the ``solution_tags`` variable.

    ``solution_tags`` is a configurable variable. Defaults to {'#__SOLUTION__', '#==SOLUTION==', '__SOLUTION__', '==SOLUTION=='}.
    '''

    solution_tags = Set(
        {'#__SOLUTION__', '#==SOLUTION==', '__SOLUTION__', '==SOLUTION=='},
        help=("Tags indicating which cells are to be removed")).tag(
            config=True)

    def is_solution(self, cell):
        """
        Checks that a cell has a solution tag. 
        """

        lines = set(cell.source.split("\n"))
        lines = {line.strip().replace(' ', '') for line in lines}

        return self.solution_tags.intersection(lines)

    def preprocess(self, nb, resources):

        nb_copy = deepcopy(nb)

        # Skip preprocessing if the list of patterns is empty
        if not self.solution_tags:
            return nb, resources

        # Filter out cells that meet the conditions
        cells = []
        for cell in nb_copy.cells:
            if self.is_solution(cell) or cell.cell_type == 'markdown':
                cells.append(self.preprocess_cell(cell))

        if len(nb_copy.cells) == len(cells):
            warn(
                "No lesson cells were found in the notebook!"
                " Double check the solution tag placement and formatting if this is not correct.",
                UserWarning)

        nb_copy.cells = cells

        return nb_copy, resources

    def preprocess_cell(self, cell):
        """
        Removes the solution tag from the solution cells.
        """

        lines = cell.source.split('\n')
        no_tags = [
            line for line in lines
            if line.strip().replace(' ', '') not in self.solution_tags
        ]
        cell.source = '\n'.join(no_tags)
        return cell
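
A small check of the tag handling above, assuming nbformat and the DsCreatePreprocessor base are importable; note that is_solution() ignores indentation and internal spaces when matching tags:

import nbformat.v4 as nbf

pre = RemoveLessonCells()
solution = nbf.new_code_cell("# __SOLUTION__ \nanswer = 42")
lesson = nbf.new_code_cell("answer = None")

print(bool(pre.is_solution(solution)), bool(pre.is_solution(lesson)))  # True False
print(pre.preprocess_cell(solution).source)  # tag line stripped: 'answer = 42'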
Example #7
class Bar(Configurable):

    b = Integer(0, help="The integer b.").tag(config=True)
    enabled = Bool(True, help="Enable bar.").tag(config=True)
    tb = Tuple(()).tag(config=True, multiplicity="*")
    aset = Set().tag(config=True, multiplicity="+")
    bdict = Dict().tag(config=True)
    idict = Dict(value_trait=Integer()).tag(config=True)
    key_dict = Dict(per_key_traits={"i": Integer(), "b": Bytes()}).tag(config=True)
Example #8
class FileWhitelistMixin(LoggingConfigurable):
    """
    """

    #: The path of the whitelist file.
    whitelist_file = Unicode(config=True)

    #: When the file was last modified, so that we can reload appropriately.
    _whitelist_file_last_modified = Float()

    #: Cached whitelist to return every time the file hasn't changed.
    _whitelist = Set()

    @property
    def whitelist(self):
        """Returns the whitelist for the approved users.
        """
        # Note: we return a copy because other code in jupyterhub
        # will modify the set. We don't want our cache to be modified.
        try:
            cur_mtime = os.path.getmtime(self.whitelist_file)
            if cur_mtime <= self._whitelist_file_last_modified:
                # File unchanged since we last read it;
                # keep using the current cached whitelist.
                return set(self._whitelist)

            self.log.info("Whitelist file more recent than the old one. "
                          "Updating whitelist.")

            with open(self.whitelist_file, "r") as f:
                whitelisted_users = set(self.normalize_username(x.strip())
                                        for x in f.readlines()
                                        if not x.strip().startswith("#"))
        except FileNotFoundError:
            # empty set means everybody is allowed
            return set()
        except Exception:
            # For other exceptions, assume the file is broken, log it
            # and return what we have.
            self.log.exception("Unable to access whitelist.")
            return set(self._whitelist)

        self._whitelist = whitelisted_users
        self._whitelist_file_last_modified = cur_mtime

        return set(self._whitelist)

    @whitelist.setter
    def whitelist(self, value):
        """Dummy setter that does nothing.
        Jupyterhub assumes it can normalize the names and set them
        back. We can't let it perform this operation. """
        pass
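
A usage sketch: the mixin expects a normalize_username() method, which JupyterHub's Authenticator normally provides, so a stand-in is defined here:

import tempfile

class DummyAuthenticator(FileWhitelistMixin):
    def normalize_username(self, name):
        return name.lower()

with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("# comment lines are skipped\nAlice\nbob\n")

auth = DummyAuthenticator(whitelist_file=f.name)
print(auth.whitelist)  # {'alice', 'bob'}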
Example #9
class CASLocalAuthenticator(LocalAuthenticator):
    """
    Validate a CAS service ticket and optionally check for the presence of an
    authorization attribute.
    """
    cas_login_url = Unicode(
        config=True,
        help="""The CAS URL to redirect unauthenticated users to.""")

    cas_logout_url = Unicode(
        config=True,
        help="""The CAS URL for logging out an authenticated user.""")

    cas_service_url = Unicode(
        allow_none=True,
        default_value=None,
        config=True,
        help=
        """The service URL the CAS server will redirect the browser back to on successful authentication."""
    )

    cas_client_ca_certs = Unicode(
        allow_none=True,
        default_value=None,
        config=True,
        help=
        """Path to CA certificates the CAS client will trust when validating a service ticket."""
    )

    cas_service_validate_url = Unicode(
        config=True,
        help="""The CAS endpoint for validating service tickets.""")

    cas_required_attribs = Set(
        help=
        "A set of attribute name and value tuples a user must have to be allowed access."
    ).tag(config=True)

    def get_handlers(self, app):
        return [
            (r'/login', CASLoginHandler),
            (r'/logout', CASLogoutHandler),
        ]

    @gen.coroutine
    def authenticate(self, *args):
        raise NotImplementedError()
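
A typical jupyterhub_config.py sketch for this authenticator (the `c` object is provided by JupyterHub's config loader; the URLs and attribute tuple are placeholders):

c.JupyterHub.authenticator_class = CASLocalAuthenticator  # or its dotted import path
c.CASLocalAuthenticator.cas_login_url = "https://cas.example.edu/cas/login"
c.CASLocalAuthenticator.cas_logout_url = "https://cas.example.edu/cas/logout"
c.CASLocalAuthenticator.cas_service_validate_url = "https://cas.example.edu/cas/serviceValidate"
c.CASLocalAuthenticator.cas_required_attribs = {("memberOf", "jupyterhub-users")}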
Example #10
class ClearOutputPreprocessor(Preprocessor):
    """
    Removes the output from all code cells in a notebook.
    """

    remove_metadata_fields = Set({"collapsed", "scrolled"}).tag(config=True)

    def preprocess_cell(self, cell, resources, cell_index):
        """
        Apply a transformation on each cell. See base.py for details.
        """
        if cell.cell_type == "code":
            cell.outputs = []
            cell.execution_count = None
            # Remove metadata associated with output
            if "metadata" in cell:
                for field in self.remove_metadata_fields:
                    cell.metadata.pop(field, None)
        return cell, resources
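
A usage sketch with nbformat: outputs, the execution count, and the configured metadata fields are wiped:

import nbformat.v4 as nbf

cell = nbf.new_code_cell(
    "1 + 1",
    execution_count=1,
    outputs=[nbf.new_output("execute_result", {"text/plain": "2"}, execution_count=1)],
)
cell.metadata["scrolled"] = True

cell, _ = ClearOutputPreprocessor().preprocess_cell(cell, {}, 0)
print(cell.outputs, cell.execution_count, dict(cell.metadata))  # [] None {}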
Example #11
class AddCellIndex(DsCreatePreprocessor):

    description = '''
    AddCellIndex adds a metadata.index variable to a notebook and determines if a cell is a solution cell.
    This preprocessor is used primarily for ``--inline`` splits.
    '''
    index = Int(0)

    solution_tags = Set(
        {'#__SOLUTION__', '#==SOLUTION==', '__SOLUTION__', '==SOLUTION=='},
        help=("Tags indicating which cells are to be removed")).tag(
            config=True)

    def preprocess(self, nb, resources):

        nb_copy = deepcopy(nb)

        # Index every cell and flag whether it is a solution cell
        nb_copy.cells = [
            self.preprocess_cell(cell, resources, index)[0]
            for index, cell in enumerate(nb_copy.cells)
        ]

        return nb_copy, resources

    def preprocess_cell(self, cell, resources, cell_index):
        """
        No transformation is applied.
        """
        lines = set(cell.source.split("\n"))
        if self.solution_tags.intersection(lines):
            cell['metadata']['solution'] = True
        else:
            cell['metadata']['solution'] = False

        cell['metadata']['index'] = self.index
        self.index += 1
        return cell, resources
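
A quick check of the indexing behaviour, again assuming nbformat and the DsCreatePreprocessor base are importable; every cell gains metadata.index and metadata.solution:

import nbformat.v4 as nbf

nb = nbf.new_notebook(cells=[
    nbf.new_code_cell("answer = None"),
    nbf.new_code_cell("__SOLUTION__\nanswer = 42"),
])

nb, _ = AddCellIndex().preprocess(nb, {})
print([(c.metadata.index, c.metadata.solution) for c in nb.cells])
# [(0, False), (1, True)]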
Example #12
class NBViewer(Application):

    name = Unicode("NBViewer")

    aliases = Dict({
        "base-url": "NBViewer.base_url",
        "binder-base-url": "NBViewer.binder_base_url",
        "cache-expiry-max": "NBViewer.cache_expiry_max",
        "cache-expiry-min": "NBViewer.cache_expiry_min",
        "config-file": "NBViewer.config_file",
        "content-security-policy": "NBViewer.content_security_policy",
        "default-format": "NBViewer.default_format",
        "frontpage": "NBViewer.frontpage",
        "host": "NBViewer.host",
        "ipywidgets-base-url": "NBViewer.ipywidgets_base_url",
        "jupyter-js-widgets-version": "NBViewer.jupyter_js_widgets_version",
        "jupyter-widgets-html-manager-version":
        "NBViewer.jupyter_widgets_html_manager_version",
        "localfiles": "NBViewer.localfiles",
        "log-level": "Application.log_level",
        "mathjax-url": "NBViewer.mathjax_url",
        "mc-threads": "NBViewer.mc_threads",
        "port": "NBViewer.port",
        "processes": "NBViewer.processes",
        "provider-rewrites": "NBViewer.provider_rewrites",
        "providers": "NBViewer.providers",
        "proxy-host": "NBViewer.proxy_host",
        "proxy-port": "NBViewer.proxy_port",
        "rate-limit": "NBViewer.rate_limit",
        "rate-limit-interval": "NBViewer.rate_limit_interval",
        "render-timeout": "NBViewer.render_timeout",
        "sslcert": "NBViewer.sslcert",
        "sslkey": "NBViewer.sslkey",
        "static-path": "NBViewer.static_path",
        "static-url-prefix": "NBViewer.static_url_prefix",
        "statsd-host": "NBViewer.statsd_host",
        "statsd-port": "NBViewer.statsd_port",
        "statsd-prefix": "NBViewer.statsd_prefix",
        "template-path": "NBViewer.template_path",
        "threads": "NBViewer.threads",
    })

    flags = Dict({
        "debug": (
            {
                "Application": {
                    "log_level": logging.DEBUG
                }
            },
            "Set log-level to debug, for the most verbose logging.",
        ),
        "generate-config": (
            {
                "NBViewer": {
                    "generate_config": True
                }
            },
            "Generate default config file.",
        ),
        "localfile-any-user": (
            {
                "NBViewer": {
                    "localfile_any_user": True
                }
            },
            "Also serve files that are not readable by 'Other' on the local file system.",
        ),
        "localfile-follow-symlinks": (
            {
                "NBViewer": {
                    "localfile_follow_symlinks": True
                }
            },
            "Resolve/follow symbolic links to their target file using realpath.",
        ),
        "no-cache": ({
            "NBViewer": {
                "no_cache": True
            }
        }, "Do not cache results."),
        "no-check-certificate": (
            {
                "NBViewer": {
                    "no_check_certificate": True
                }
            },
            "Do not validate SSL certificates.",
        ),
        "y": (
            {
                "NBViewer": {
                    "answer_yes": True
                }
            },
            "Answer yes to any questions (e.g. confirm overwrite).",
        ),
        "yes": (
            {
                "NBViewer": {
                    "answer_yes": True
                }
            },
            "Answer yes to any questions (e.g. confirm overwrite).",
        ),
    })

    # Use this to insert custom configuration of handlers for NBViewer extensions
    handler_settings = Dict().tag(config=True)

    create_handler = Unicode(
        default_value="nbviewer.handlers.CreateHandler",
        help="The Tornado handler to use for creation via frontpage form.",
    ).tag(config=True)
    custom404_handler = Unicode(
        default_value="nbviewer.handlers.Custom404",
        help="The Tornado handler to use for rendering 404 templates.",
    ).tag(config=True)
    faq_handler = Unicode(
        default_value="nbviewer.handlers.FAQHandler",
        help=
        "The Tornado handler to use for rendering and viewing the FAQ section.",
    ).tag(config=True)
    gist_handler = Unicode(
        default_value="nbviewer.providers.gist.handlers.GistHandler",
        help=
        "The Tornado handler to use for viewing notebooks stored as GitHub Gists",
    ).tag(config=True)
    github_blob_handler = Unicode(
        default_value="nbviewer.providers.github.handlers.GitHubBlobHandler",
        help=
        "The Tornado handler to use for viewing notebooks stored as blobs on GitHub",
    ).tag(config=True)
    github_tree_handler = Unicode(
        default_value="nbviewer.providers.github.handlers.GitHubTreeHandler",
        help="The Tornado handler to use for viewing directory trees on GitHub",
    ).tag(config=True)
    github_user_handler = Unicode(
        default_value="nbviewer.providers.github.handlers.GitHubUserHandler",
        help=
        "The Tornado handler to use for viewing all of a user's repositories on GitHub.",
    ).tag(config=True)
    index_handler = Unicode(
        default_value="nbviewer.handlers.IndexHandler",
        help="The Tornado handler to use for rendering the frontpage section.",
    ).tag(config=True)
    local_handler = Unicode(
        default_value="nbviewer.providers.local.handlers.LocalFileHandler",
        help=
        "The Tornado handler to use for viewing notebooks found on a local filesystem",
    ).tag(config=True)
    url_handler = Unicode(
        default_value="nbviewer.providers.url.handlers.URLHandler",
        help=
        "The Tornado handler to use for viewing notebooks accessed via URL",
    ).tag(config=True)
    user_gists_handler = Unicode(
        default_value="nbviewer.providers.gist.handlers.UserGistsHandler",
        help=
        "The Tornado handler to use for viewing directory containing all of a user's Gists",
    ).tag(config=True)

    answer_yes = Bool(
        default_value=False,
        help="Answer yes to any questions (e.g. confirm overwrite).",
    ).tag(config=True)

    # base_url specified by the user
    base_url = Unicode(default_value="/",
                       help="URL base for the server").tag(config=True)

    binder_base_url = Unicode(
        default_value="https://mybinder.org/v2",
        help="URL base for binder notebook execution service.",
    ).tag(config=True)

    cache_expiry_max = Int(
        default_value=2 * 60 * 60,
        help="Maximum cache expiry (seconds).").tag(config=True)

    cache_expiry_min = Int(
        default_value=10 * 60,
        help="Minimum cache expiry (seconds).").tag(config=True)

    client = Any().tag(config=True)

    @default("client")
    def _default_client(self):
        client = HTTPClientClass(log=self.log)
        client.cache = self.cache
        return client

    config_file = Unicode(default_value="nbviewer_config.py",
                          help="The config file to load.").tag(config=True)

    content_security_policy = Unicode(
        default_value="connect-src 'none';",
        help="Content-Security-Policy header setting.",
    ).tag(config=True)

    default_format = Unicode(
        default_value="html",
        help="Format to use for legacy / URLs.").tag(config=True)

    frontpage = Unicode(
        default_value=FRONTPAGE_JSON,
        help="Path to json file containing frontpage content.",
    ).tag(config=True)

    generate_config = Bool(
        default_value=False,
        help="Generate default config file.").tag(config=True)

    host = Unicode(help="Run on the given interface.").tag(config=True)

    @default("host")
    def _default_host(self):
        return self.default_endpoint["host"]

    index = Any().tag(config=True)

    @default("index")
    def _load_index(self):
        if os.environ.get("NBINDEX_PORT"):
            self.log.info("Indexing notebooks")
            tcp_index = os.environ.get("NBINDEX_PORT")
            index_url = tcp_index.split("tcp://")[1]
            index_host, index_port = index_url.split(":")
        else:
            self.log.info("Not indexing notebooks")
            indexer = NoSearch()
        return indexer

    ipywidgets_base_url = Unicode(
        default_value="https://unpkg.com/",
        help="URL base for ipywidgets JS package.").tag(config=True)

    jupyter_js_widgets_version = Unicode(
        default_value="*",
        help="Version specifier for jupyter-js-widgets JS package.").tag(
            config=True)

    jupyter_widgets_html_manager_version = Unicode(
        default_value="*",
        help="Version specifier for @jupyter-widgets/html-manager JS package.",
    ).tag(config=True)

    localfile_any_user = Bool(
        default_value=False,
        help=
        "Also serve files that are not readable by 'Other' on the local file system.",
    ).tag(config=True)

    localfile_follow_symlinks = Bool(
        default_value=False,
        help=
        "Resolve/follow symbolic links to their target file using realpath.",
    ).tag(config=True)

    localfiles = Unicode(
        default_value="",
        help=
        "Allow to serve local files under /localfile/* this can be a security risk.",
    ).tag(config=True)

    mathjax_url = Unicode(
        default_value="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/",
        help="URL base for mathjax package.",
    ).tag(config=True)

    # cache frontpage links for the maximum allowed time
    max_cache_uris = Set().tag(config=True)

    @default("max_cache_uris")
    def _load_max_cache_uris(self):
        max_cache_uris = {""}
        for section in self.frontpage_setup["sections"]:
            for link in section["links"]:
                max_cache_uris.add("/" + link["target"])
        return max_cache_uris

    mc_threads = Int(
        default_value=1,
        help="Number of threads to use for Async Memcache.").tag(config=True)

    no_cache = Bool(default_value=False,
                    help="Do not cache results.").tag(config=True)

    no_check_certificate = Bool(
        default_value=False,
        help="Do not validate SSL certificates.").tag(config=True)

    port = Int(help="Run on the given port.").tag(config=True)

    @default("port")
    def _default_port(self):
        return self.default_endpoint["port"]

    processes = Int(
        default_value=0,
        help="Use processes instead of threads for rendering.").tag(
            config=True)

    provider_rewrites = List(
        trait=Unicode,
        default_value=default_rewrites,
        help="Full dotted package(s) that provide `uri_rewrites`.",
    ).tag(config=True)

    providers = List(
        trait=Unicode,
        default_value=default_providers,
        help="Full dotted package(s) that provide `default_handlers`.",
    ).tag(config=True)

    proxy_host = Unicode(default_value="",
                         help="The proxy URL.").tag(config=True)

    proxy_port = Int(default_value=-1, help="The proxy port.").tag(config=True)

    rate_limit = Int(
        default_value=60,
        help=
        "Number of requests to allow in rate_limit_interval before limiting. Only requests that trigger a new render are counted.",
    ).tag(config=True)

    rate_limit_interval = Int(
        default_value=600,
        help="Interval (in seconds) for rate limiting.").tag(config=True)

    render_timeout = Int(
        default_value=15,
        help=
        "Time to wait for a render to complete before showing the 'Working...' page.",
    ).tag(config=True)

    sslcert = Unicode(help="Path to ssl .crt file.").tag(config=True)

    sslkey = Unicode(help="Path to ssl .key file.").tag(config=True)

    static_path = Unicode(
        default_value=os.environ.get("NBVIEWER_STATIC_PATH", ""),
        help="Custom path for loading additional static files.",
    ).tag(config=True)

    static_url_prefix = Unicode(default_value="/static/").tag(config=True)
    # Not exposed to end user for configuration, since needs to access base_url
    _static_url_prefix = Unicode()

    @default("_static_url_prefix")
    def _load_static_url_prefix(self):
        # Last '/' ensures that NBViewer still works regardless of whether user chooses e.g. '/static2/' or '/static2' as their custom prefix
        return url_path_join(self._base_url, self.static_url_prefix, "/")

    statsd_host = Unicode(
        default_value="",
        help="Host running statsd to send metrics to.").tag(config=True)

    statsd_port = Int(
        default_value=8125,
        help="Port on which statsd is listening for metrics on statsd_host.",
    ).tag(config=True)

    statsd_prefix = Unicode(
        default_value="nbviewer",
        help="Prefix to use for naming metrics sent to statsd.",
    ).tag(config=True)

    template_path = Unicode(
        default_value=os.environ.get("NBVIEWER_TEMPLATE_PATH", ""),
        help=
        "Custom template path for the nbviewer app (not rendered notebooks).",
    ).tag(config=True)

    threads = Int(
        default_value=1,
        help="Number of threads to use for rendering.").tag(config=True)

    # prefer the JupyterHub defined service prefix over the CLI
    @cached_property
    def _base_url(self):
        return os.getenv("JUPYTERHUB_SERVICE_PREFIX", self.base_url)

    @cached_property
    def cache(self):
        memcache_urls = os.environ.get("MEMCACHIER_SERVERS",
                                       os.environ.get("MEMCACHE_SERVERS"))
        # Handle linked Docker containers
        if os.environ.get("NBCACHE_PORT"):
            tcp_memcache = os.environ.get("NBCACHE_PORT")
            memcache_urls = tcp_memcache.split("tcp://")[1]
        if self.no_cache:
            self.log.info("Not using cache")
            cache = MockCache()
        elif pylibmc and memcache_urls:
            # setup memcache
            mc_pool = ThreadPoolExecutor(self.mc_threads)
            kwargs = dict(pool=mc_pool)
            username = os.environ.get("MEMCACHIER_USERNAME", "")
            password = os.environ.get("MEMCACHIER_PASSWORD", "")
            if username and password:
                kwargs["binary"] = True
                kwargs["username"] = username
                kwargs["password"] = password
                self.log.info("Using SASL memcache")
            else:
                self.log.info("Using plain memcache")

            cache = AsyncMultipartMemcache(memcache_urls.split(","), **kwargs)
        else:
            self.log.info("Using in-memory cache")
            cache = DummyAsyncCache()

        return cache

    @cached_property
    def default_endpoint(self):
        # check if JupyterHub service options are available to use as defaults
        if "JUPYTERHUB_SERVICE_URL" in os.environ:
            url = urlparse(os.environ["JUPYTERHUB_SERVICE_URL"])
            default_host, default_port = url.hostname, url.port
        else:
            default_host, default_port = "0.0.0.0", 5000
        return {"host": default_host, "port": default_port}

    @cached_property
    def env(self):
        env = Environment(loader=FileSystemLoader(self.template_paths),
                          autoescape=True)
        env.filters["markdown"] = markdown.markdown
        try:
            git_data = git_info(here)
        except Exception as e:
            self.log.error("Failed to get git info: %s", e)
            git_data = {}
        else:
            git_data["msg"] = escape(git_data["msg"])

        if self.no_cache:
            # force Jinja2 to recompile template every time
            env.globals.update(cache_size=0)
        env.globals.update(
            nrhead=nrhead,
            nrfoot=nrfoot,
            git_data=git_data,
            jupyter_info=jupyter_info(),
            len=len,
        )

        return env

    @cached_property
    def fetch_kwargs(self):
        fetch_kwargs = dict(connect_timeout=10)
        if self.proxy_host:
            fetch_kwargs.update(proxy_host=self.proxy_host,
                                proxy_port=self.proxy_port)
            self.log.info("Using web proxy {proxy_host}:{proxy_port}."
                          "".format(**fetch_kwargs))

        if self.no_check_certificate:
            fetch_kwargs.update(validate_cert=False)
            self.log.info("Not validating SSL certificates")

        return fetch_kwargs

    @cached_property
    def formats(self):
        return self.configure_formats()

    # load frontpage sections
    @cached_property
    def frontpage_setup(self):
        with io.open(self.frontpage, "r") as f:
            frontpage_setup = json.load(f)
        # check if the JSON has a 'sections' field; otherwise assume it is just a list of sections,
        # and provide the defaults of the other fields
        if "sections" not in frontpage_setup:
            frontpage_setup = {
                "title": "nbviewer",
                "subtitle": "A simple way to share Jupyter notebooks",
                "show_input": True,
                "sections": frontpage_setup,
            }
        return frontpage_setup

    # Attribute inherited from traitlets.config.Application, automatically used to style logs
    # https://github.com/ipython/traitlets/blob/master/traitlets/config/application.py#L191
    _log_formatter_cls = LogFormatter
    # Need Tornado LogFormatter for color logs, keys 'color' and 'end_color' in log_format

    # Observed traitlet inherited again from traitlets.config.Application
    # https://github.com/ipython/traitlets/blob/master/traitlets/config/application.py#L177
    @default("log_level")
    def _log_level_default(self):
        return logging.INFO

    # Ditto the above: https://github.com/ipython/traitlets/blob/master/traitlets/config/application.py#L197
    @default("log_format")
    def _log_format_default(self):
        """override default log format to include time and color, plus to always display the log level, not just when it's high"""
        return "%(color)s[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s %(module)s:%(lineno)d]%(end_color)s %(message)s"

    # For consistency with JupyterHub logs
    @default("log_datefmt")
    def _log_datefmt_default(self):
        """Exclude date from default date format"""
        return "%Y-%m-%d %H:%M:%S"

    @cached_property
    def pool(self):
        if self.processes:
            pool = ProcessPoolExecutor(self.processes)
        else:
            pool = ThreadPoolExecutor(self.threads)
        return pool

    @cached_property
    def rate_limiter(self):
        rate_limiter = RateLimiter(limit=self.rate_limit,
                                   interval=self.rate_limit_interval,
                                   cache=self.cache)
        return rate_limiter

    @cached_property
    def static_paths(self):
        default_static_path = pjoin(here, "static")
        if self.static_path:
            self.log.info("Using custom static path {}".format(
                self.static_path))
            static_paths = [self.static_path, default_static_path]
        else:
            static_paths = [default_static_path]

        return static_paths

    @cached_property
    def template_paths(self):
        default_template_path = pjoin(here, "templates")
        if self.template_path:
            self.log.info("Using custom template path {}".format(
                self.template_path))
            template_paths = [self.template_path, default_template_path]
        else:
            template_paths = [default_template_path]

        return template_paths

    def configure_formats(self, formats=None):
        """
        Format-specific configuration.
        """
        if formats is None:
            formats = default_formats()

        # This would be better defined in a class
        self.config.HTMLExporter.template_file = "basic"
        self.config.SlidesExporter.template_file = "slides_reveal"

        self.config.TemplateExporter.template_path = [
            os.path.join(os.path.dirname(__file__), "templates", "nbconvert")
        ]

        for key, format in formats.items():
            exporter_cls = format.get("exporter", exporter_map[key])
            if self.processes:
                # can't pickle exporter instances, so pass the class itself
                formats[key]["exporter"] = exporter_cls
            else:
                formats[key]["exporter"] = exporter_cls(config=self.config,
                                                        log=self.log)

        return formats

    def init_tornado_application(self):
        # handle handlers
        handler_names = dict(
            create_handler=self.create_handler,
            custom404_handler=self.custom404_handler,
            faq_handler=self.faq_handler,
            gist_handler=self.gist_handler,
            github_blob_handler=self.github_blob_handler,
            github_tree_handler=self.github_tree_handler,
            github_user_handler=self.github_user_handler,
            index_handler=self.index_handler,
            local_handler=self.local_handler,
            url_handler=self.url_handler,
            user_gists_handler=self.user_gists_handler,
        )
        handler_kwargs = {
            "handler_names": handler_names,
            "handler_settings": self.handler_settings,
        }

        handlers = init_handlers(self.formats, self.providers, self._base_url,
                                 self.localfiles, **handler_kwargs)

        # NBConvert config
        self.config.NbconvertApp.fileext = "html"
        self.config.CSSHTMLHeaderTransformer.enabled = False

        # DEBUG env implies both autoreload and log-level
        if os.environ.get("DEBUG"):
            self.log.setLevel(logging.DEBUG)

        # input traitlets to settings
        settings = dict(
            # Allow FileFindHandler to load static directories from e.g. a Docker container
            allow_remote_access=True,
            base_url=self._base_url,
            binder_base_url=self.binder_base_url,
            cache=self.cache,
            cache_expiry_max=self.cache_expiry_max,
            cache_expiry_min=self.cache_expiry_min,
            client=self.client,
            config=self.config,
            content_security_policy=self.content_security_policy,
            default_format=self.default_format,
            fetch_kwargs=self.fetch_kwargs,
            formats=self.formats,
            frontpage_setup=self.frontpage_setup,
            google_analytics_id=os.getenv("GOOGLE_ANALYTICS_ID"),
            gzip=True,
            hub_api_token=os.getenv("JUPYTERHUB_API_TOKEN"),
            hub_api_url=os.getenv("JUPYTERHUB_API_URL"),
            hub_base_url=os.getenv("JUPYTERHUB_BASE_URL"),
            index=self.index,
            ipywidgets_base_url=self.ipywidgets_base_url,
            jinja2_env=self.env,
            jupyter_js_widgets_version=self.jupyter_js_widgets_version,
            jupyter_widgets_html_manager_version=self.jupyter_widgets_html_manager_version,
            localfile_any_user=self.localfile_any_user,
            localfile_follow_symlinks=self.localfile_follow_symlinks,
            localfile_path=os.path.abspath(self.localfiles),
            log=self.log,
            log_function=log_request,
            mathjax_url=self.mathjax_url,
            max_cache_uris=self.max_cache_uris,
            pool=self.pool,
            provider_rewrites=self.provider_rewrites,
            providers=self.providers,
            rate_limiter=self.rate_limiter,
            render_timeout=self.render_timeout,
            static_handler_class=StaticFileHandler,
            # FileFindHandler expects list of static paths, so self.static_path*s* is correct
            static_path=self.static_paths,
            static_url_prefix=self._static_url_prefix,
            statsd_host=self.statsd_host,
            statsd_port=self.statsd_port,
            statsd_prefix=self.statsd_prefix,
        )

        if self.localfiles:
            self.log.warning(
                "Serving local notebooks in %s, this can be a security risk",
                self.localfiles,
            )

        # create the app
        self.tornado_application = web.Application(handlers, **settings)

    def init_logging(self):

        # Note that we inherit a self.log attribute from traitlets.config.Application
        # https://github.com/ipython/traitlets/blob/master/traitlets/config/application.py#L209
        # as well as a log_level attribute
        # https://github.com/ipython/traitlets/blob/master/traitlets/config/application.py#L177

        # This prevents double log messages because tornado uses a root logger that
        # self.log is a child of. The logging module dispatches log messages to a log
        # and all of its ancestors until propagate is set to False.
        self.log.propagate = False

        tornado_log = logging.getLogger("tornado")
        # hook up tornado's loggers to our app handlers
        for log in (app_log, access_log, tornado_log, curl_log):
            # ensure all log statements identify the application they come from
            log.name = self.log.name
            log.parent = self.log
            log.propagate = True
            log.setLevel(self.log_level)

        # disable curl debug, which logs all headers, info for upstream requests, which is TOO MUCH
        curl_log.setLevel(max(self.log_level, logging.INFO))

    # Mostly copied from JupyterHub because if it isn't broken then don't fix it.
    def write_config_file(self):
        """Write our default config to a .py config file"""
        config_file_dir = os.path.dirname(os.path.abspath(self.config_file))
        if not os.path.isdir(config_file_dir):
            self.exit(
                "{} does not exist. The destination directory must exist before generating config file."
                .format(config_file_dir))
        if os.path.exists(self.config_file) and not self.answer_yes:
            answer = ""

            def ask():
                prompt = "Overwrite %s with default config? [y/N]" % self.config_file
                try:
                    return input(prompt).lower() or "n"
                except KeyboardInterrupt:
                    print("")  # empty line
                    return "n"

            answer = ask()
            while not answer.startswith(("y", "n")):
                print("Please answer 'yes' or 'no'")
                answer = ask()
            if answer.startswith("n"):
                self.exit("Not overwriting config file with default.")

        # Inherited method from traitlets.config.Application
        config_text = self.generate_config_file()
        if isinstance(config_text, bytes):
            config_text = config_text.decode("utf8")
        print("Writing default config to: %s" % self.config_file)
        with open(self.config_file, mode="w") as f:
            f.write(config_text)
        self.exit("Wrote default config file.")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # parse command line with catch_config_error from traitlets.config.Application
        super().initialize(*args, **kwargs)

        if self.generate_config:
            self.write_config_file()

        # Inherited method from traitlets.config.Application
        self.load_config_file(self.config_file)
        self.init_logging()
        self.init_tornado_application()
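
A minimal nbviewer_config.py sketch exercising a few of the traits above (the `c` object is supplied by the config loader; paths and values are placeholders):

c.NBViewer.port = 8080
c.NBViewer.cache_expiry_min = 5 * 60
c.NBViewer.cache_expiry_max = 60 * 60
c.NBViewer.localfiles = "/srv/notebooks"  # serves files under /localfile/*; a security risk
c.NBViewer.template_path = "/srv/nbviewer/templates"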
Example #13
class JupyterHub(Application):
    """An Application for starting a Multi-User Jupyter Notebook server."""
    name = 'jupyterhub'
    version = jupyterhub.__version__
    
    description = """Start a multi-user Jupyter Notebook server
    
    Spawns a configurable-http-proxy and multi-user Hub,
    which authenticates users and spawns single-user Notebook servers
    on behalf of users.
    """
    
    examples = """
    
    generate default config file:
    
        jupyterhub --generate-config -f /etc/jupyterhub/jupyterhub.py
    
    spawn the server on 10.0.1.2:443 with https:
    
        jupyterhub --ip 10.0.1.2 --port 443 --ssl-key my_ssl.key --ssl-cert my_ssl.cert
    """
    
    aliases = Dict(aliases)
    flags = Dict(flags)
    
    subcommands = {
        'token': (NewToken, "Generate an API token for a user")
    }
    
    classes = List([
        Spawner,
        LocalProcessSpawner,
        Authenticator,
        PAMAuthenticator,
    ])
    
    config_file = Unicode('jupyterhub_config.py', config=True,
        help="The config file to load",
    )
    generate_config = Bool(False, config=True,
        help="Generate default config file",
    )
    answer_yes = Bool(False, config=True,
        help="Answer yes to any questions (e.g. confirm overwrite)"
    )
    pid_file = Unicode('', config=True,
        help="""File to write PID
        Useful for daemonizing jupyterhub.
        """
    )
    cookie_max_age_days = Float(14, config=True,
        help="""Number of days for a login cookie to be valid.
        Default is two weeks.
        """
    )
    last_activity_interval = Integer(300, config=True,
        help="Interval (in seconds) at which to update last-activity timestamps."
    )
    proxy_check_interval = Integer(30, config=True,
        help="Interval (in seconds) at which to check if the proxy is running."
    )
    
    data_files_path = Unicode(DATA_FILES_PATH, config=True,
        help="The location of jupyterhub data files (e.g. /usr/local/share/jupyter/hub)"
    )

    template_paths = List(
        config=True,
        help="Paths to search for jinja templates.",
    )

    def _template_paths_default(self):
        return [os.path.join(self.data_files_path, 'templates')]

    ssl_key = Unicode('', config=True,
        help="""Path to SSL key file for the public facing interface of the proxy
        
        Use with ssl_cert
        """
    )
    ssl_cert = Unicode('', config=True,
        help="""Path to SSL certificate file for the public facing interface of the proxy
        
        Use with ssl_key
        """
    )
    ip = Unicode('', config=True,
        help="The public facing ip of the proxy"
    )
    port = Integer(8000, config=True,
        help="The public facing port of the proxy"
    )
    base_url = URLPrefix('/', config=True,
        help="The base URL of the entire application"
    )
    
    jinja_environment_options = Dict(config=True,
        help="Supply extra arguments that will be passed to Jinja environment."
    )
    
    proxy_cmd = Command('configurable-http-proxy', config=True,
        help="""The command to start the http proxy.
        
        Only override if configurable-http-proxy is not on your PATH
        """
    )
    debug_proxy = Bool(False, config=True, help="show debug output in configurable-http-proxy")
    proxy_auth_token = Unicode(config=True,
        help="""The Proxy Auth token.

        Loaded from the CONFIGPROXY_AUTH_TOKEN env variable by default.
        """
    )
    def _proxy_auth_token_default(self):
        token = os.environ.get('CONFIGPROXY_AUTH_TOKEN', None)
        if not token:
            self.log.warn('\n'.join([
                "",
                "Generating CONFIGPROXY_AUTH_TOKEN. Restarting the Hub will require restarting the proxy.",
                "Set CONFIGPROXY_AUTH_TOKEN env or JupyterHub.proxy_auth_token config to avoid this message.",
                "",
            ]))
            token = orm.new_token()
        return token
    
    proxy_api_ip = Unicode('localhost', config=True,
        help="The ip for the proxy API handlers"
    )
    proxy_api_port = Integer(config=True,
        help="The port for the proxy API handlers"
    )
    def _proxy_api_port_default(self):
        return self.port + 1
    
    hub_port = Integer(8081, config=True,
        help="The port for this process"
    )
    hub_ip = Unicode('localhost', config=True,
        help="The ip for this process"
    )
    
    hub_prefix = URLPrefix('/hub/', config=True,
        help="The prefix for the hub server. Must not be '/'"
    )
    def _hub_prefix_default(self):
        return url_path_join(self.base_url, '/hub/')
    
    def _hub_prefix_changed(self, name, old, new):
        if new == '/':
            raise TraitError("'/' is not a valid hub prefix")
        if not new.startswith(self.base_url):
            self.hub_prefix = url_path_join(self.base_url, new)
    
    cookie_secret = Bytes(config=True, env='JPY_COOKIE_SECRET',
        help="""The cookie secret to use to encrypt cookies.

        Loaded from the JPY_COOKIE_SECRET env variable by default.
        """
    )
    
    cookie_secret_file = Unicode('jupyterhub_cookie_secret', config=True,
        help="""File in which to store the cookie secret."""
    )
    
    authenticator_class = Type(PAMAuthenticator, Authenticator,
        config=True,
        help="""Class for authenticating users.
        
        This should be a class with the following form:
        
        - constructor takes one kwarg: `config`, the IPython config object.
        
        - is a tornado.gen.coroutine
        - returns username on success, None on failure
        - takes two arguments: (handler, data),
          where `handler` is the calling web.RequestHandler,
          and `data` is the POST form data from the login page.
        """
    )
    
    authenticator = Instance(Authenticator)
    def _authenticator_default(self):
        return self.authenticator_class(parent=self, db=self.db)

    # class for spawning single-user servers
    spawner_class = Type(LocalProcessSpawner, Spawner,
        config=True,
        help="""The class to use for spawning single-user servers.
        
        Should be a subclass of Spawner.
        """
    )
    
    db_url = Unicode('sqlite:///jupyterhub.sqlite', config=True,
        help="url for the database. e.g. `sqlite:///jupyterhub.sqlite`"
    )
    def _db_url_changed(self, name, old, new):
        if '://' not in new:
            # assume sqlite, if given as a plain filename
            self.db_url = 'sqlite:///%s' % new

    db_kwargs = Dict(config=True,
        help="""Include any kwargs to pass to the database connection.
        See sqlalchemy.create_engine for details.
        """
    )

    reset_db = Bool(False, config=True,
        help="Purge and reset the database."
    )
    debug_db = Bool(False, config=True,
        help="log all database transactions. This has A LOT of output"
    )
    session_factory = Any()
    
    admin_access = Bool(False, config=True,
        help="""Grant admin users permission to access single-user servers.
        
        Users should be properly informed if this is enabled.
        """
    )
    admin_users = Set(config=True,
        help="""DEPRECATED, use Authenticator.admin_users instead."""
    )
    
    tornado_settings = Dict(config=True)
    
    cleanup_servers = Bool(True, config=True,
        help="""Whether to shutdown single-user servers when the Hub shuts down.
        
        Disable if you want to be able to teardown the Hub while leaving the single-user servers running.
        
        If both this and cleanup_proxy are False, sending SIGINT to the Hub will
        only shutdown the Hub, leaving everything else running.
        
        The Hub should be able to resume from database state.
        """
    )

    cleanup_proxy = Bool(True, config=True,
        help="""Whether to shutdown the proxy when the Hub shuts down.
        
        Disable if you want to be able to teardown the Hub while leaving the proxy running.
        
        Only valid if the proxy was started by the Hub process.
        
        If both this and cleanup_servers are False, sending SIGINT to the Hub will
        only shutdown the Hub, leaving everything else running.
        
        The Hub should be able to resume from database state.
        """
    )
    
    handlers = List()
    
    _log_formatter_cls = CoroutineLogFormatter
    http_server = None
    proxy_process = None
    io_loop = None
    
    def _log_level_default(self):
        return logging.INFO
    
    def _log_datefmt_default(self):
        """Exclude date from default date format"""
        return "%Y-%m-%d %H:%M:%S"

    def _log_format_default(self):
        """override default log format to include time"""
        return "%(color)s[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s %(module)s:%(lineno)d]%(end_color)s %(message)s"

    extra_log_file = Unicode(
        "",
        config=True,
        help="Set a logging.FileHandler on this file."
    )
    extra_log_handlers = List(
        Instance(logging.Handler),
        config=True,
        help="Extra log handlers to set on JupyterHub logger",
    )

    def init_logging(self):
        # This prevents double log messages because tornado uses a root logger that
        # self.log is a child of. The logging module dispatches log messages to a log
        # and all of its ancestors until propagate is set to False.
        self.log.propagate = False

        if self.extra_log_file:
            self.extra_log_handlers.append(
                logging.FileHandler(self.extra_log_file)
            )

        _formatter = self._log_formatter_cls(
            fmt=self.log_format,
            datefmt=self.log_datefmt,
        )
        for handler in self.extra_log_handlers:
            if handler.formatter is None:
                handler.setFormatter(_formatter)
            self.log.addHandler(handler)

        # hook up tornado 3's loggers to our app handlers
        for log in (app_log, access_log, gen_log):
            # ensure all log statements identify the application they come from
            log.name = self.log.name
        logger = logging.getLogger('tornado')
        logger.propagate = True
        logger.parent = self.log
        logger.setLevel(self.log.level)

    def init_ports(self):
        if self.hub_port == self.port:
            raise TraitError("The hub and proxy cannot both listen on port %i" % self.port)
        if self.hub_port == self.proxy_api_port:
            raise TraitError("The hub and proxy API cannot both listen on port %i" % self.hub_port)
        if self.proxy_api_port == self.port:
            raise TraitError("The proxy's public and API ports cannot both be %i" % self.port)
    
    @staticmethod
    def add_url_prefix(prefix, handlers):
        """add a url prefix to handlers"""
        for i, tup in enumerate(handlers):
            lis = list(tup)
            lis[0] = url_path_join(prefix, tup[0])
            handlers[i] = tuple(lis)
        return handlers
    
    def init_handlers(self):
        h = []
        h.extend(handlers.default_handlers)
        h.extend(apihandlers.default_handlers)
        # load handlers from the authenticator
        h.extend(self.authenticator.get_handlers(self))

        self.handlers = self.add_url_prefix(self.hub_prefix, h)

        # some extra handlers, outside hub_prefix
        self.handlers.extend([
            (r"%s" % self.hub_prefix.rstrip('/'), web.RedirectHandler,
                {
                    "url": self.hub_prefix,
                    "permanent": False,
                }
            ),
            (r"(?!%s).*" % self.hub_prefix, handlers.PrefixRedirectHandler),
            (r'(.*)', handlers.Template404),
        ])
    
    def _check_db_path(self, path):
        """More informative log messages for failed filesystem access"""
        path = os.path.abspath(path)
        parent, fname = os.path.split(path)
        user = getuser()
        if not os.path.isdir(parent):
            self.log.error("Directory %s does not exist", parent)
        if os.path.exists(parent) and not os.access(parent, os.W_OK):
            self.log.error("%s cannot create files in %s", user, parent)
        if os.path.exists(path) and not os.access(path, os.W_OK):
            self.log.error("%s cannot edit %s", user, path)
    
    def init_secrets(self):
        trait_name = 'cookie_secret'
        trait = self.traits()[trait_name]
        env_name = trait.get_metadata('env')
        secret_file = os.path.abspath(
            os.path.expanduser(self.cookie_secret_file)
        )
        secret = self.cookie_secret
        secret_from = 'config'
        # load priority: 1. config, 2. env, 3. file
        if not secret and os.environ.get(env_name):
            secret_from = 'env'
            self.log.info("Loading %s from env[%s]", trait_name, env_name)
            secret = binascii.a2b_hex(os.environ[env_name])
        if not secret and os.path.exists(secret_file):
            secret_from = 'file'
            perm = os.stat(secret_file).st_mode
            if perm & 0o077:
                self.log.error("Bad permissions on %s", secret_file)
            else:
                self.log.info("Loading %s from %s", trait_name, secret_file)
                with open(secret_file) as f:
                    b64_secret = f.read()
                try:
                    secret = binascii.a2b_base64(b64_secret)
                except Exception as e:
                    self.log.error("%s does not contain b64 key: %s", secret_file, e)
        if not secret:
            secret_from = 'new'
            self.log.debug("Generating new %s", trait_name)
            secret = os.urandom(SECRET_BYTES)
        
        if secret_file and secret_from == 'new':
            # if we generated a new secret, store it in the secret_file
            self.log.info("Writing %s to %s", trait_name, secret_file)
            b64_secret = binascii.b2a_base64(secret).decode('ascii')
            with open(secret_file, 'w') as f:
                f.write(b64_secret)
            try:
                os.chmod(secret_file, 0o600)
            except OSError:
                self.log.warn("Failed to set permissions on %s", secret_file)
        # store the loaded trait value
        self.cookie_secret = secret
    
    # thread-local storage of db objects
    _local = Instance(threading.local, ())
    @property
    def db(self):
        if not hasattr(self._local, 'db'):
            self._local.db = scoped_session(self.session_factory)()
        return self._local.db
    
    @property
    def hub(self):
        if not getattr(self._local, 'hub', None):
            q = self.db.query(orm.Hub)
            assert q.count() <= 1
            self._local.hub = q.first()
        return self._local.hub
    
    @hub.setter
    def hub(self, hub):
        self._local.hub = hub
    
    @property
    def proxy(self):
        if not getattr(self._local, 'proxy', None):
            q = self.db.query(orm.Proxy)
            assert q.count() <= 1
            p = self._local.proxy = q.first()
            if p:
                p.auth_token = self.proxy_auth_token
        return self._local.proxy
    
    @proxy.setter
    def proxy(self, proxy):
        self._local.proxy = proxy
    
    def init_db(self):
        """Create the database connection"""
        self.log.debug("Connecting to db: %s", self.db_url)
        try:
            self.session_factory = orm.new_session_factory(
                self.db_url,
                reset=self.reset_db,
                echo=self.debug_db,
                **self.db_kwargs
            )
            # trigger constructing thread local db property
            _ = self.db
        except OperationalError as e:
            self.log.error("Failed to connect to db: %s", self.db_url)
            self.log.debug("Database error was:", exc_info=True)
            if self.db_url.startswith('sqlite:///'):
                self._check_db_path(self.db_url.split(':///', 1)[1])
            self.exit(1)
    
    def init_hub(self):
        """Load the Hub config into the database"""
        self.hub = self.db.query(orm.Hub).first()
        if self.hub is None:
            self.hub = orm.Hub(
                server=orm.Server(
                    ip=self.hub_ip,
                    port=self.hub_port,
                    base_url=self.hub_prefix,
                    cookie_name='jupyter-hub-token',
                )
            )
            self.db.add(self.hub)
        else:
            server = self.hub.server
            server.ip = self.hub_ip
            server.port = self.hub_port
            server.base_url = self.hub_prefix

        self.db.commit()
    
    @gen.coroutine
    def init_users(self):
        """Load users into and from the database"""
        db = self.db
        
        if self.admin_users and not self.authenticator.admin_users:
            self.log.warn(
                "\nJupyterHub.admin_users is deprecated."
                "\nUse Authenticator.admin_users instead."
            )
            self.authenticator.admin_users = self.admin_users
        admin_users = self.authenticator.admin_users
        
        if not admin_users:
            # add current user as admin if there aren't any others
            admins = db.query(orm.User).filter(orm.User.admin==True)
            if admins.first() is None:
                admin_users.add(getuser())
        
        new_users = []

        for name in admin_users:
            # ensure anyone specified as admin in config is admin in db
            user = orm.User.find(db, name)
            if user is None:
                user = orm.User(name=name, admin=True)
                new_users.append(user)
                db.add(user)
            else:
                user.admin = True

        # the admin_users config variable will never be used after this point.
        # only the database values will be referenced.

        whitelist = self.authenticator.whitelist

        if not whitelist:
            self.log.info("Not using whitelist. Any authenticated user will be allowed.")

        # add whitelisted users to the db
        for name in whitelist:
            user = orm.User.find(db, name)
            if user is None:
                user = orm.User(name=name)
                new_users.append(user)
                db.add(user)

        if whitelist:
            # fill the whitelist with any users loaded from the db,
            # so we are consistent in both directions.
            # This lets whitelist be used to set up initial list,
            # but changes to the whitelist can occur in the database,
            # and persist across sessions.
            for user in db.query(orm.User):
                whitelist.add(user.name)

        # The whitelist set and the users in the db are now the same.
        # From this point on, any user changes should be done simultaneously
        # to the whitelist set and user db, unless the whitelist is empty (all users allowed).

        db.commit()
        
        for user in new_users:
            yield gen.maybe_future(self.authenticator.add_user(user))
        db.commit()
    
    @gen.coroutine
    def init_spawners(self):
        db = self.db
        
        user_summaries = ['']
        def _user_summary(user):
            parts = ['{0: >8}'.format(user.name)]
            if user.admin:
                parts.append('admin')
            if user.server:
                parts.append('running at %s' % user.server)
            return ' '.join(parts)
            
        @gen.coroutine
        def user_stopped(user):
            status = yield user.spawner.poll()
            self.log.warn("User %s server stopped with exit code: %s",
                user.name, status,
            )
            yield self.proxy.delete_user(user)
            yield user.stop()
        
        for user in db.query(orm.User):
            if not user.state:
                # without spawner state, server isn't valid
                user.server = None
                user_summaries.append(_user_summary(user))
                continue
            self.log.debug("Loading state for %s from db", user.name)
            user.spawner = spawner = self.spawner_class(
                user=user, hub=self.hub, config=self.config, db=self.db,
            )
            status = yield spawner.poll()
            if status is None:
                self.log.info("%s still running", user.name)
                spawner.add_poll_callback(user_stopped, user)
                spawner.start_polling()
            else:
                # user not running. This is expected if server is None,
                # but indicates the user's server died while the Hub wasn't running
                # if user.server is defined.
                log = self.log.warn if user.server else self.log.debug
                log("%s not running.", user.name)
                user.server = None

            user_summaries.append(_user_summary(user))

        self.log.debug("Loaded users: %s", '\n'.join(user_summaries))
        db.commit()

    def init_proxy(self):
        """Load the Proxy config into the database"""
        self.proxy = self.db.query(orm.Proxy).first()
        if self.proxy is None:
            self.proxy = orm.Proxy(
                public_server=orm.Server(),
                api_server=orm.Server(),
            )
            self.db.add(self.proxy)
            self.db.commit()
        self.proxy.auth_token = self.proxy_auth_token # not persisted
        self.proxy.log = self.log
        self.proxy.public_server.ip = self.ip
        self.proxy.public_server.port = self.port
        self.proxy.api_server.ip = self.proxy_api_ip
        self.proxy.api_server.port = self.proxy_api_port
        self.proxy.api_server.base_url = '/api/routes/'
        self.db.commit()
    
    @gen.coroutine
    def start_proxy(self):
        """Actually start the configurable-http-proxy"""
        # check for proxy
        if self.proxy.public_server.is_up() or self.proxy.api_server.is_up():
            # check for *authenticated* access to the proxy (auth token can change)
            try:
                yield self.proxy.get_routes()
            except (HTTPError, OSError, socket.error) as e:
                if isinstance(e, HTTPError) and e.code == 403:
                    msg = "Did CONFIGPROXY_AUTH_TOKEN change?"
                else:
                    msg = "Is something else using %s?" % self.proxy.public_server.bind_url
                self.log.error("Proxy appears to be running at %s, but I can't access it (%s)\n%s",
                    self.proxy.public_server.bind_url, e, msg)
                self.exit(1)
                return
            else:
                self.log.info("Proxy already running at: %s", self.proxy.public_server.bind_url)
            self.proxy_process = None
            return

        env = os.environ.copy()
        env['CONFIGPROXY_AUTH_TOKEN'] = self.proxy.auth_token
        cmd = self.proxy_cmd + [
            '--ip', self.proxy.public_server.ip,
            '--port', str(self.proxy.public_server.port),
            '--api-ip', self.proxy.api_server.ip,
            '--api-port', str(self.proxy.api_server.port),
            '--default-target', self.hub.server.host,
        ]
        if self.debug_proxy:
            cmd.extend(['--log-level', 'debug'])
        if self.ssl_key:
            cmd.extend(['--ssl-key', self.ssl_key])
        if self.ssl_cert:
            cmd.extend(['--ssl-cert', self.ssl_cert])
        self.log.info("Starting proxy @ %s", self.proxy.public_server.bind_url)
        self.log.debug("Proxy cmd: %s", cmd)
        try:
            self.proxy_process = Popen(cmd, env=env)
        except FileNotFoundError:
            self.log.error(
                "Failed to find proxy %r\n"
                "The proxy can be installed with `npm install -g configurable-http-proxy`",
                self.proxy_cmd,
            )
            self.exit(1)
        def _check():
            status = self.proxy_process.poll()
            if status is not None:
                e = RuntimeError("Proxy failed to start with exit code %i" % status)
                # py2-compatible `raise e from None`
                e.__cause__ = None
                raise e
        
        for server in (self.proxy.public_server, self.proxy.api_server):
            for i in range(10):
                _check()
                try:
                    yield server.wait_up(1)
                except TimeoutError:
                    continue
                else:
                    break
            yield server.wait_up(1)
        self.log.debug("Proxy started and appears to be up")
    
    @gen.coroutine
    def check_proxy(self):
        status = self.proxy_process.poll()
        if status is None:
            return
        self.log.error("Proxy stopped with exit code %r", status)
        yield self.start_proxy()
        self.log.info("Setting up routes on new proxy")
        yield self.proxy.add_all_users()
        self.log.info("New proxy back up, and good to go")
    
    def init_tornado_settings(self):
        """Set up the tornado settings dict."""
        base_url = self.hub.server.base_url
        jinja_env = Environment(
            loader=FileSystemLoader(self.template_paths),
            **self.jinja_environment_options
        )
        
        login_url = self.authenticator.login_url(base_url)
        logout_url = self.authenticator.logout_url(base_url)
        
        # if running from git, disable caching of require.js
        # otherwise cache based on server start time
        parent = os.path.dirname(os.path.dirname(jupyterhub.__file__))
        if os.path.isdir(os.path.join(parent, '.git')):
            version_hash = ''
        else:
            version_hash = datetime.now().strftime("%Y%m%d%H%M%S")
        
        settings = dict(
            log_function=log_request,
            config=self.config,
            log=self.log,
            db=self.db,
            proxy=self.proxy,
            hub=self.hub,
            admin_users=self.authenticator.admin_users,
            admin_access=self.admin_access,
            authenticator=self.authenticator,
            spawner_class=self.spawner_class,
            base_url=self.base_url,
            cookie_secret=self.cookie_secret,
            cookie_max_age_days=self.cookie_max_age_days,
            login_url=login_url,
            logout_url=logout_url,
            static_path=os.path.join(self.data_files_path, 'static'),
            static_url_prefix=url_path_join(self.hub.server.base_url, 'static/'),
            static_handler_class=CacheControlStaticFilesHandler,
            template_path=self.template_paths,
            jinja2_env=jinja_env,
            version_hash=version_hash,
        )
        # allow configured settings to have priority
        settings.update(self.tornado_settings)
        self.tornado_settings = settings
    
    def init_tornado_application(self):
        """Instantiate the tornado Application object"""
        self.tornado_application = web.Application(self.handlers, **self.tornado_settings)
    
    def write_pid_file(self):
        pid = os.getpid()
        if self.pid_file:
            self.log.debug("Writing PID %i to %s", pid, self.pid_file)
            with open(self.pid_file, 'w') as f:
                f.write('%i' % pid)
    
    @gen.coroutine
    @catch_config_error
    def initialize(self, *args, **kwargs):
        super().initialize(*args, **kwargs)
        if self.generate_config or self.subapp:
            return
        self.load_config_file(self.config_file)
        self.init_logging()
        if 'JupyterHubApp' in self.config:
            self.log.warn("Use JupyterHub in config, not JupyterHubApp. Outdated config:\n%s",
                '\n'.join('JupyterHubApp.{key} = {value!r}'.format(key=key, value=value)
                    for key, value in self.config.JupyterHubApp.items()
                )
            )
            cfg = self.config.copy()
            cfg.JupyterHub.merge(cfg.JupyterHubApp)
            self.update_config(cfg)
        self.write_pid_file()
        self.init_ports()
        self.init_secrets()
        self.init_db()
        self.init_hub()
        self.init_proxy()
        yield self.init_users()
        yield self.init_spawners()
        self.init_handlers()
        self.init_tornado_settings()
        self.init_tornado_application()
    
    @gen.coroutine
    def cleanup(self):
        """Shutdown our various subprocesses and cleanup runtime files."""
        
        futures = []
        if self.cleanup_servers:
            self.log.info("Cleaning up single-user servers...")
            # request (async) process termination
            for user in self.db.query(orm.User):
                if user.spawner is not None:
                    futures.append(user.stop())
        else:
            self.log.info("Leaving single-user servers running")
        
        # clean up proxy while SUS are shutting down
        if self.cleanup_proxy:
            if self.proxy_process:
                self.log.info("Cleaning up proxy[%i]...", self.proxy_process.pid)
                if self.proxy_process.poll() is None:
                    try:
                        self.proxy_process.terminate()
                    except Exception as e:
                        self.log.error("Failed to terminate proxy process: %s", e)
            else:
                self.log.info("I didn't start the proxy, I can't clean it up")
        else:
            self.log.info("Leaving proxy running")
        
        
        # wait for the stop requests to finish:
        for f in futures:
            try:
                yield f
            except Exception as e:
                self.log.error("Failed to stop user: %s", e)
        
        self.db.commit()
        
        if self.pid_file and os.path.exists(self.pid_file):
            self.log.info("Cleaning up PID file %s", self.pid_file)
            os.remove(self.pid_file)
        
        # finally stop the loop once we are all cleaned up
        self.log.info("...done")
    
    def write_config_file(self):
        """Write our default config to a .py config file"""
        if os.path.exists(self.config_file) and not self.answer_yes:
            answer = ''
            def ask():
                prompt = "Overwrite %s with default config? [y/N]" % self.config_file
                try:
                    return input(prompt).lower() or 'n'
                except KeyboardInterrupt:
                    print('') # empty line
                    return 'n'
            answer = ask()
            while not answer.startswith(('y', 'n')):
                print("Please answer 'yes' or 'no'")
                answer = ask()
            if answer.startswith('n'):
                return
        
        config_text = self.generate_config_file()
        if isinstance(config_text, bytes):
            config_text = config_text.decode('utf8')
        print("Writing default config to: %s" % self.config_file)
        with open(self.config_file, mode='w') as f:
            f.write(config_text)
    
    @gen.coroutine
    def update_last_activity(self):
        """Update User.last_activity timestamps from the proxy"""
        routes = yield self.proxy.get_routes()
        for prefix, route in routes.items():
            if 'user' not in route:
                # not a user route, ignore it
                continue
            user = orm.User.find(self.db, route['user'])
            if user is None:
                self.log.warn("Found no user for route: %s", route)
                continue
            try:
                dt = datetime.strptime(route['last_activity'], ISO8601_ms)
            except Exception:
                dt = datetime.strptime(route['last_activity'], ISO8601_s)
            user.last_activity = max(user.last_activity, dt)

        self.db.commit()
        yield self.proxy.check_routes(routes)
    
    @gen.coroutine
    def start(self):
        """Start the whole thing"""
        self.io_loop = loop = IOLoop.current()
        
        if self.subapp:
            self.subapp.start()
            loop.stop()
            return
        
        if self.generate_config:
            self.write_config_file()
            loop.stop()
            return
        
        # start the webserver
        self.http_server = tornado.httpserver.HTTPServer(self.tornado_application, xheaders=True)
        try:
            self.http_server.listen(self.hub_port, address=self.hub_ip)
        except Exception:
            self.log.error("Failed to bind hub to %s", self.hub.server.bind_url)
            raise
        else:
            self.log.info("Hub API listening on %s", self.hub.server.bind_url)
        
        # start the proxy
        try:
            yield self.start_proxy()
        except Exception as e:
            self.log.critical("Failed to start proxy", exc_info=True)
            self.exit(1)
            return
        
        loop.add_callback(self.proxy.add_all_users)
        
        if self.proxy_process:
            # only check / restart the proxy if we started it in the first place.
            # this means a restarted Hub cannot restart a Proxy that its
            # predecessor started.
            pc = PeriodicCallback(self.check_proxy, 1e3 * self.proxy_check_interval)
            pc.start()
        
        if self.last_activity_interval:
            pc = PeriodicCallback(self.update_last_activity, 1e3 * self.last_activity_interval)
            pc.start()

        self.log.info("JupyterHub is now running at %s", self.proxy.public_server.url)
        # register cleanup on both TERM and INT
        atexit.register(self.atexit)
        self.init_signal()

    def init_signal(self):
        signal.signal(signal.SIGTERM, self.sigterm)
    
    def sigterm(self, signum, frame):
        self.log.critical("Received SIGTERM, shutting down")
        self.io_loop.stop()
        self.atexit()
    
    _atexit_ran = False
    def atexit(self):
        """atexit callback"""
        if self._atexit_ran:
            return
        self._atexit_ran = True
        # run the cleanup step (in a new loop, because the interrupted one is unclean)
        IOLoop.clear_current()
        loop = IOLoop()
        loop.make_current()
        loop.run_sync(self.cleanup)
        
    
    def stop(self):
        if not self.io_loop:
            return
        if self.http_server:
            if self.io_loop._running:
                self.io_loop.add_callback(self.http_server.stop)
            else:
                self.http_server.stop()
        self.io_loop.add_callback(self.io_loop.stop)
    
    @gen.coroutine
    def launch_instance_async(self, argv=None):
        try:
            yield self.initialize(argv)
            yield self.start()
        except Exception as e:
            self.log.exception("")
            self.exit(1)
    
    @classmethod
    def launch_instance(cls, argv=None):
        self = cls.instance()
        loop = IOLoop.current()
        loop.add_callback(self.launch_instance_async, argv)
        try:
            loop.start()
        except KeyboardInterrupt:
            print("\nInterrupted")
Example #14
class GitLabOAuthenticator(OAuthenticator):

    login_service = "GitLab"

    client_id_env = 'GITLAB_CLIENT_ID'
    client_secret_env = 'GITLAB_CLIENT_SECRET'
    login_handler = GitLabLoginHandler

    gitlab_group_whitelist = Set(
        config=True,
        help="Automatically whitelist members of selected groups",
    )

    @gen.coroutine
    def authenticate(self, handler, data=None):
        code = handler.get_argument("code")
        # TODO: Configure the curl_httpclient for tornado
        http_client = AsyncHTTPClient()

        # Exchange the OAuth code for a GitLab Access Token
        #
        # See: https://github.com/gitlabhq/gitlabhq/blob/master/doc/api/oauth2.md

        # GitLab specifies a POST request yet requires URL parameters
        params = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            code=code,
            grant_type="authorization_code",
            redirect_uri=self.get_callback_url(handler),
        )

        validate_server_cert = self.validate_server_cert

        url = url_concat("%s/oauth/token" % GITLAB_HOST, params)

        req = HTTPRequest(
            url,
            method="POST",
            headers={"Accept": "application/json"},
            validate_cert=validate_server_cert,
            body=''  # Body is required for a POST...
        )

        resp = yield http_client.fetch(req)
        resp_json = json.loads(resp.body.decode('utf8', 'replace'))

        access_token = resp_json['access_token']

        # Determine who the logged in user is
        req = HTTPRequest("%s/user" % GITLAB_API,
                          method="GET",
                          validate_cert=validate_server_cert,
                          headers=_api_headers(access_token))
        resp = yield http_client.fetch(req)
        resp_json = json.loads(resp.body.decode('utf8', 'replace'))

        username = resp_json["username"]
        user_id = resp_json["id"]
        is_admin = resp_json.get("is_admin", False)

        # Check if user is a member of any whitelisted organizations.
        # This check is performed here, as it requires `access_token`.
        if self.gitlab_group_whitelist:
            user_in_group = yield self._check_group_whitelist(
                username, user_id, is_admin, access_token)
            if not user_in_group:
                self.log.warning("%s not in group whitelist", username)
                return None
        return {
            'name': username,
            'auth_state': {
                'access_token': access_token,
                'gitlab_user': resp_json,
            }
        }

    @gen.coroutine
    def _check_group_whitelist(self, username, user_id, is_admin,
                               access_token):
        http_client = AsyncHTTPClient()
        headers = _api_headers(access_token)
        # Check if we are a member of each group in the whitelist
        for group in map(url_escape, self.gitlab_group_whitelist):
            url = "%s/groups/%s/members/%d" % (GITLAB_API, group, user_id)
            req = HTTPRequest(url, method="GET", headers=headers)
            resp = yield http_client.fetch(req, raise_error=False)
            if resp.code == 200:
                return True  # user _is_ in group
        return False
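
The _check_group_whitelist method above probes GitLab's group-members endpoint once per whitelisted group and treats an HTTP 200 as proof of membership. A stripped-down sketch of that status-code probe, with placeholder api_base and headers values rather than the real GitLab constants:

from tornado.httpclient import AsyncHTTPClient, HTTPRequest

async def is_group_member(api_base, group, user_id, headers):
    """Return True iff the members endpoint answers 200 for this user."""
    http_client = AsyncHTTPClient()
    req = HTTPRequest("%s/groups/%s/members/%d" % (api_base, group, user_id),
                      method="GET", headers=headers)
    # raise_error=False turns a 404 ("not a member") into a normal
    # response instead of an HTTPError
    resp = await http_client.fetch(req, raise_error=False)
    return resp.code == 200
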
Example #15
class BitbucketOAuthenticator(OAuthenticator):

    _deprecated_oauth_aliases = {
        "team_whitelist": ("allowed_teams", "0.12.0"),
        **OAuthenticator._deprecated_oauth_aliases,
    }

    login_service = "Bitbucket"
    client_id_env = 'BITBUCKET_CLIENT_ID'
    client_secret_env = 'BITBUCKET_CLIENT_SECRET'

    @default("authorize_url")
    def _authorize_url_default(self):
        return "https://bitbucket.org/site/oauth2/authorize"

    @default("token_url")
    def _token_url_default(self):
        return "https://bitbucket.org/site/oauth2/access_token"

    team_whitelist = Set(
        help="Deprecated, use `BitbucketOAuthenticator.allowed_teams`",
        config=True,
    )

    allowed_teams = Set(config=True,
                        help="Automatically allow members of selected teams")

    headers = {
        "Accept": "application/json",
        "User-Agent": "JupyterHub",
        "Authorization": "Bearer {}",
    }

    async def authenticate(self, handler, data=None):
        code = handler.get_argument("code")

        params = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            grant_type="authorization_code",
            code=code,
            redirect_uri=self.get_callback_url(handler),
        )

        url = url_concat("https://bitbucket.org/site/oauth2/access_token",
                         params)

        bb_header = {
            "Content-Type": "application/x-www-form-urlencoded;charset=utf-8"
        }
        req = HTTPRequest(
            url,
            method="POST",
            auth_username=self.client_id,
            auth_password=self.client_secret,
            body=urllib.parse.urlencode(params).encode('utf-8'),
            headers=bb_header,
        )

        resp_json = await self.fetch(req)

        access_token = resp_json['access_token']

        # Determine who the logged in user is
        req = HTTPRequest(
            "https://api.bitbucket.org/2.0/user",
            method="GET",
            headers=_api_headers(access_token),
        )
        resp_json = await self.fetch(req)

        username = resp_json["username"]

        # Check if user is a member of any allowed teams.
        # This check is performed here, as the check requires `access_token`.
        if self.allowed_teams:
            user_in_team = await self._check_membership_allowed_teams(
                username, access_token)
            if not user_in_team:
                self.log.warning("%s not in team allowed list of users",
                                 username)
                return None

        return {
            'name': username,
            'auth_state': {
                'access_token': access_token,
                'bitbucket_user': resp_json
            },
        }

    async def _check_membership_allowed_teams(self, username, access_token):

        headers = _api_headers(access_token)
        # We verify the team membership by calling teams endpoint.
        next_page = url_concat("https://api.bitbucket.org/2.0/teams",
                               {'role': 'member'})
        while next_page:
            req = HTTPRequest(next_page, method="GET", headers=headers)
            resp_json = await self.fetch(req)
            next_page = resp_json.get('next', None)

            user_teams = {entry["username"] for entry in resp_json["values"]}
            # check if any of the organizations seen thus far are in the allowed list
            if len(self.allowed_teams & user_teams) > 0:
                return True
        return False
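
The _check_membership_allowed_teams method above pages through the teams endpoint by following each response's `next` link until an allowed team shows up. The same cursor-following loop, sketched against a generic JSON API (the fetch_json helper is an assumption, not part of oauthenticator):

import json
from urllib.request import urlopen

def fetch_json(url):
    """Hypothetical helper: GET a URL and decode the JSON body."""
    with urlopen(url) as resp:
        return json.loads(resp.read().decode('utf8'))

def user_in_allowed_teams(allowed_teams, first_page_url):
    """Walk paginated results until a team in allowed_teams appears."""
    next_page = first_page_url
    while next_page:
        page = fetch_json(next_page)
        next_page = page.get('next')  # None on the last page ends the loop
        teams = {entry['username'] for entry in page.get('values', [])}
        if allowed_teams & teams:
            return True
    return False
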
Example #16
class Widget(LoggingConfigurable):
    #-------------------------------------------------------------------------
    # Class attributes
    #-------------------------------------------------------------------------
    _widget_construction_callback = None
    widgets = {}
    widget_types = {}

    @staticmethod
    def on_widget_constructed(callback):
        """Registers a callback to be called when a widget is constructed.

        The callback must have the following signature:
        callback(widget)"""
        Widget._widget_construction_callback = callback

    @staticmethod
    def _call_widget_constructed(widget):
        """Static method, called when a widget is constructed."""
        if Widget._widget_construction_callback is not None and callable(
                Widget._widget_construction_callback):
            Widget._widget_construction_callback(widget)

    @staticmethod
    def handle_comm_opened(comm, msg):
        """Static method, called when a widget is constructed."""
        widget_class = import_item(str(msg['content']['data']['widget_class']))
        widget = widget_class(comm=comm)

    #-------------------------------------------------------------------------
    # Traits
    #-------------------------------------------------------------------------
    _model_module = Unicode(None,
                            allow_none=True,
                            help="""A requirejs module name
        in which to find _model_name. If empty, look in the global registry."""
                            )
    _model_name = Unicode('WidgetModel',
                          help="""Name of the backbone model 
        registered in the front-end to create and sync this widget with.""")
    _view_module = Unicode(
        help="""A requirejs module in which to find _view_name.
        If empty, look in the global registry.""",
        sync=True)
    _view_name = Unicode(None,
                         allow_none=True,
                         help="""Default view registered in the front-end
        to use to represent the widget.""",
                         sync=True)
    comm = Instance('ipykernel.comm.Comm', allow_none=True)

    msg_throttle = Int(3,
                       sync=True,
                       help="""Maximum number of msgs the 
        front-end can send before receiving an idle msg from the back-end.""")

    version = Int(0, sync=True, help="""Widget's version""")
    keys = List()

    def _keys_default(self):
        return [name for name in self.traits(sync=True)]

    _property_lock = Dict()
    _send_state_lock = Int(0)
    _states_to_send = Set()
    _display_callbacks = Instance(CallbackDispatcher, ())
    _msg_callbacks = Instance(CallbackDispatcher, ())

    #-------------------------------------------------------------------------
    # (Con/de)structor
    #-------------------------------------------------------------------------
    def __init__(self, **kwargs):
        """Public constructor"""
        self._model_id = kwargs.pop('model_id', None)
        super(Widget, self).__init__(**kwargs)

        Widget._call_widget_constructed(self)
        self.open()

    def __del__(self):
        """Object disposal"""
        self.close()

    #-------------------------------------------------------------------------
    # Properties
    #-------------------------------------------------------------------------

    def open(self):
        """Open a comm to the frontend if one isn't already open."""
        if self.comm is None:
            args = dict(target_name='ipython.widget',
                        data={
                            'model_name': self._model_name,
                            'model_module': self._model_module
                        })
            if self._model_id is not None:
                args['comm_id'] = self._model_id
            self.comm = Comm(**args)

    def _comm_changed(self, name, new):
        """Called when the comm is changed."""
        if new is None:
            return
        self._model_id = self.model_id

        self.comm.on_msg(self._handle_msg)
        Widget.widgets[self.model_id] = self

        # first update
        self.send_state()

    @property
    def model_id(self):
        """Gets the model id of this widget.

        If a Comm doesn't exist yet, a Comm will be created automagically."""
        return self.comm.comm_id

    #-------------------------------------------------------------------------
    # Methods
    #-------------------------------------------------------------------------

    def close(self):
        """Close method.

        Closes the underlying comm.
        When the comm is closed, all of the widget views are automatically
        removed from the front-end."""
        if self.comm is not None:
            Widget.widgets.pop(self.model_id, None)
            self.comm.close()
            self.comm = None

    def send_state(self, key=None):
        """Sends the widget state, or a piece of it, to the front-end.

        Parameters
        ----------
        key : unicode, or iterable (optional)
            A single property's name or iterable of property names to sync with the front-end.
        """
        state, buffer_keys, buffers = self.get_state(key=key)
        msg = {"method": "update", "state": state}
        if buffer_keys:
            msg['buffers'] = buffer_keys
        self._send(msg, buffers=buffers)

    def get_state(self, key=None):
        """Gets the widget state, or a piece of it.

        Parameters
        ----------
        key : unicode or iterable (optional)
            A single property's name or iterable of property names to get.

        Returns
        -------
        state : dict of states
        buffer_keys : list of strings
            the values that are stored in buffers
        buffers : list of binary memoryviews
            values to transmit in binary
        """
        if key is None:
            keys = self.keys
        elif isinstance(key, string_types):
            keys = [key]
        elif isinstance(key, collections.Iterable):
            keys = key
        else:
            raise ValueError(
                "key must be a string, an iterable of keys, or None")
        state = {}
        buffers = []
        buffer_keys = []
        for k in keys:
            f = self.trait_metadata(k, 'to_json', self._trait_to_json)
            value = getattr(self, k)
            serialized = f(value)
            if isinstance(serialized, memoryview):
                buffers.append(serialized)
                buffer_keys.append(k)
            else:
                state[k] = serialized
        return state, buffer_keys, buffers

    def set_state(self, sync_data):
        """Called when a state is received from the front-end."""
        # The order of these context managers is important. Properties must
        # be locked when the hold_trait_notifications context manager is
        # released and notifications are fired.
        with self._lock_property(**sync_data), self.hold_trait_notifications():
            for name in sync_data:
                if name in self.keys:
                    from_json = self.trait_metadata(name, 'from_json',
                                                    self._trait_from_json)
                    setattr(self, name, from_json(sync_data[name]))

    def send(self, content, buffers=None):
        """Sends a custom msg to the widget model in the front-end.

        Parameters
        ----------
        content : dict
            Content of the message to send.
        buffers : list of binary buffers
            Binary buffers to send with message
        """
        self._send({"method": "custom", "content": content}, buffers=buffers)

    def on_msg(self, callback, remove=False):
        """(Un)Register a custom msg receive callback.

        Parameters
        ----------
        callback: callable
            callback will be passed three arguments when a message arrives::
            
                callback(widget, content, buffers)
            
        remove: bool
            True if the callback should be unregistered."""
        self._msg_callbacks.register_callback(callback, remove=remove)

    def on_displayed(self, callback, remove=False):
        """(Un)Register a widget displayed callback.

        Parameters
        ----------
        callback: method handler
            Must have a signature of::
            
                callback(widget, **kwargs)
            
            kwargs from display are passed through without modification.
        remove: bool
            True if the callback should be unregistered."""
        self._display_callbacks.register_callback(callback, remove=remove)

    def add_trait(self, traitname, trait):
        """Dynamically add a trait attribute to the Widget."""
        super(Widget, self).add_trait(traitname, trait)
        if trait.get_metadata('sync'):
            self.keys.append(traitname)
            self.send_state(traitname)

    #-------------------------------------------------------------------------
    # Support methods
    #-------------------------------------------------------------------------
    @contextmanager
    def _lock_property(self, **properties):
        """Lock a property-value pair.

        The value should be the JSON state of the property.

        NOTE: This, in addition to the single lock for all state changes, is
        flawed.  In the future we may want to look into buffering state changes 
        back to the front-end."""
        self._property_lock = properties
        try:
            yield
        finally:
            self._property_lock = {}

    @contextmanager
    def hold_sync(self):
        """Hold syncing any state until the context manager is released"""
        # We increment a value so that this can be nested.  Syncing will happen when
        # all levels have been released.
        self._send_state_lock += 1
        try:
            yield
        finally:
            self._send_state_lock -= 1
            if self._send_state_lock == 0:
                self.send_state(self._states_to_send)
                self._states_to_send.clear()

    def _should_send_property(self, key, value):
        """Check the property lock (property_lock)"""
        to_json = self.trait_metadata(key, 'to_json', self._trait_to_json)
        if (key in self._property_lock
                and to_json(value) == self._property_lock[key]):
            return False
        elif self._send_state_lock > 0:
            self._states_to_send.add(key)
            return False
        else:
            return True

    # Event handlers
    @_show_traceback
    def _handle_msg(self, msg):
        """Called when a msg is received from the front-end"""
        data = msg['content']['data']
        method = data['method']

        # Handle backbone sync methods CREATE, PATCH, and UPDATE all in one.
        if method == 'backbone':
            if 'sync_data' in data:
                # get binary buffers too
                sync_data = data['sync_data']
                for i, k in enumerate(data.get('buffer_keys', [])):
                    sync_data[k] = msg['buffers'][i]
                self.set_state(sync_data)  # handles all methods

        # Handle a state request.
        elif method == 'request_state':
            self.send_state()

        # Handle a custom msg from the front-end.
        elif method == 'custom':
            if 'content' in data:
                self._handle_custom_msg(data['content'], msg['buffers'])

        # Catch remainder.
        else:
            self.log.error(
                'Unknown front-end to back-end widget msg with method "%s"' %
                method)

    def _handle_custom_msg(self, content, buffers):
        """Called when a custom msg is received."""
        self._msg_callbacks(self, content, buffers)

    def _notify_trait(self, name, old_value, new_value):
        """Called when a property has been changed."""
        # Trigger default traitlet callback machinery.  This allows any user
        # registered validation to be processed prior to allowing the widget
        # machinery to handle the state.
        LoggingConfigurable._notify_trait(self, name, old_value, new_value)

        # Send the state after the user registered callbacks for trait changes
        # have all fired (allows for user to validate values).
        if self.comm is not None and name in self.keys:
            # Make sure this isn't information that the front-end just sent us.
            if self._should_send_property(name, new_value):
                # Send new state to front-end
                self.send_state(key=name)

    def _handle_displayed(self, **kwargs):
        """Called when a view has been displayed for this widget instance"""
        self._display_callbacks(self, **kwargs)

    def _trait_to_json(self, x):
        """Convert a trait value to json."""
        return x

    def _trait_from_json(self, x):
        """Convert json values to objects."""
        return x

    def _ipython_display_(self, **kwargs):
        """Called when `IPython.display.display` is called on the widget."""
        # Show view.
        if self._view_name is not None:
            self._send({"method": "display"})
            self._handle_displayed(**kwargs)

    def _send(self, msg, buffers=None):
        """Sends a message to the model in the front-end."""
        self.comm.send(data=msg, buffers=buffers)
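
The _lock_property / _should_send_property pair above is what stops a widget from echoing a state value straight back to the front-end that just sent it, while hold_sync batches outgoing updates. A stripped-down sketch of the echo-suppression idea on its own, with all names illustrative:

from contextlib import contextmanager

class EchoSuppressor:
    """Toy model of the widget property lock."""
    def __init__(self):
        self._lock = {}

    @contextmanager
    def lock(self, **incoming):
        # remember what the peer just sent while we apply it locally
        self._lock = incoming
        try:
            yield
        finally:
            self._lock = {}

    def should_send(self, key, value):
        # suppress only an exact echo of the value the peer sent
        return not (key in self._lock and self._lock[key] == value)

s = EchoSuppressor()
with s.lock(x=1):
    assert not s.should_send('x', 1)  # echo of the peer's value: suppressed
    assert s.should_send('x', 2)      # genuinely new value: still synced
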
Example #17
class Widget(LoggingHasTraits):
    #-------------------------------------------------------------------------
    # Class attributes
    #-------------------------------------------------------------------------
    _widget_construction_callback = None

    # widgets is a dictionary of all active widget objects
    widgets = {}

    # widget_types is a registry of widgets by module, version, and name:
    widget_types = WidgetRegistry()

    @staticmethod
    def on_widget_constructed(callback):
        """Registers a callback to be called when a widget is constructed.

        The callback must have the following signature:
        callback(widget)"""
        Widget._widget_construction_callback = callback

    @staticmethod
    def _call_widget_constructed(widget):
        """Static method, called when a widget is constructed."""
        if Widget._widget_construction_callback is not None and callable(
                Widget._widget_construction_callback):
            Widget._widget_construction_callback(widget)

    @staticmethod
    def handle_comm_opened(comm, msg):
        """Static method, called when a widget is constructed."""
        version = msg.get('metadata', {}).get('version', '')
        if version.split('.')[0] != PROTOCOL_VERSION_MAJOR:
            raise ValueError(
                "Incompatible widget protocol versions: received version %r, expected version %r"
                % (version, __protocol_version__))
        data = msg['content']['data']
        state = data['state']

        # Find the widget class to instantiate in the registered widgets
        widget_class = Widget.widget_types.get(state['_model_module'],
                                               state['_model_module_version'],
                                               state['_model_name'],
                                               state['_view_module'],
                                               state['_view_module_version'],
                                               state['_view_name'])
        widget = widget_class(comm=comm)
        if 'buffer_paths' in data:
            _put_buffers(state, data['buffer_paths'], msg['buffers'])
        widget.set_state(state)

    @staticmethod
    def get_manager_state(drop_defaults=False, widgets=None):
        """Returns the full state for a widget manager for embedding

        :param drop_defaults: when True, it will not include default value
        :param widgets: list with widgets to include in the state (or all widgets when None)
        :return:
        """
        state = {}
        if widgets is None:
            widgets = Widget.widgets.values()
        for widget in widgets:
            state[widget.model_id] = widget._get_embed_state(
                drop_defaults=drop_defaults)
        return {'version_major': 2, 'version_minor': 0, 'state': state}

    def _get_embed_state(self, drop_defaults=False):
        state = {
            'model_name': self._model_name,
            'model_module': self._model_module,
            'model_module_version': self._model_module_version
        }
        model_state, buffer_paths, buffers = _remove_buffers(
            self.get_state(drop_defaults=drop_defaults))
        state['state'] = model_state
        if len(buffers) > 0:
            state['buffers'] = [{
                'encoding': 'base64',
                'path': p,
                'data': standard_b64encode(d).decode('ascii')
            } for p, d in zip(buffer_paths, buffers)]
        return state

    def get_view_spec(self):
        return dict(version_major=2, version_minor=0, model_id=self._model_id)

    #-------------------------------------------------------------------------
    # Traits
    #-------------------------------------------------------------------------
    _model_name = Unicode('WidgetModel',
                          help="Name of the model.",
                          read_only=True).tag(sync=True)
    _model_module = Unicode('@jupyter-widgets/base',
                            help="The namespace for the model.",
                            read_only=True).tag(sync=True)
    _model_module_version = Unicode(
        __jupyter_widgets_base_version__,
        help="A semver requirement for namespace version containing the model.",
        read_only=True).tag(sync=True)
    _view_name = Unicode(None, allow_none=True,
                         help="Name of the view.").tag(sync=True)
    _view_module = Unicode(None,
                           allow_none=True,
                           help="The namespace for the view.").tag(sync=True)
    _view_module_version = Unicode(
        '',
        help=
        "A semver requirement for the namespace version containing the view."
    ).tag(sync=True)

    _view_count = Int(
        None,
        allow_none=True,
        help=
        "EXPERIMENTAL: The number of views of the model displayed in the frontend. This attribute is experimental and may change or be removed in the future. None signifies that views will not be tracked. Set this to 0 to start tracking view creation/deletion."
    ).tag(sync=True)
    comm = Instance('ipykernel.comm.Comm', allow_none=True)

    keys = List(help="The traits which are synced.")

    @default('keys')
    def _default_keys(self):
        return [name for name in self.traits(sync=True)]

    _property_lock = Dict()
    _holding_sync = False
    _states_to_send = Set()
    _display_callbacks = Instance(CallbackDispatcher, ())
    _msg_callbacks = Instance(CallbackDispatcher, ())

    #-------------------------------------------------------------------------
    # (Con/de)structor
    #-------------------------------------------------------------------------
    def __init__(self, **kwargs):
        """Public constructor"""
        self._model_id = kwargs.pop('model_id', None)
        super(Widget, self).__init__(**kwargs)

        Widget._call_widget_constructed(self)
        self.open()

    def __del__(self):
        """Object disposal"""
        self.close()

    #-------------------------------------------------------------------------
    # Properties
    #-------------------------------------------------------------------------

    def open(self):
        """Open a comm to the frontend if one isn't already open."""
        if self.comm is None:
            state, buffer_paths, buffers = _remove_buffers(self.get_state())

            args = dict(target_name='jupyter.widget',
                        data={
                            'state': state,
                            'buffer_paths': buffer_paths
                        },
                        buffers=buffers,
                        metadata={'version': __protocol_version__})
            if self._model_id is not None:
                args['comm_id'] = self._model_id

            self.comm = Comm(**args)

    @observe('comm')
    def _comm_changed(self, change):
        """Called when the comm is changed."""
        if change['new'] is None:
            return
        self._model_id = self.model_id

        self.comm.on_msg(self._handle_msg)
        Widget.widgets[self.model_id] = self

    @property
    def model_id(self):
        """Gets the model id of this widget.

        If a Comm doesn't exist yet, a Comm will be created automagically."""
        return self.comm.comm_id

    #-------------------------------------------------------------------------
    # Methods
    #-------------------------------------------------------------------------

    def close(self):
        """Close method.

        Closes the underlying comm.
        When the comm is closed, all of the widget views are automatically
        removed from the front-end."""
        if self.comm is not None:
            Widget.widgets.pop(self.model_id, None)
            self.comm.close()
            self.comm = None
            self._ipython_display_ = None

    def send_state(self, key=None):
        """Sends the widget state, or a piece of it, to the front-end.

        Parameters
        ----------
        key : unicode, or iterable (optional)
            A single property's name or iterable of property names to sync with the front-end.
        """
        state = self.get_state(key=key)
        state, buffer_paths, buffers = _remove_buffers(state)
        msg = {
            'method': 'update',
            'state': state,
            'buffer_paths': buffer_paths
        }
        self._send(msg, buffers=buffers)

    def get_state(self, key=None, drop_defaults=False):
        """Gets the widget state, or a piece of it.

        Parameters
        ----------
        key : unicode or iterable (optional)
            A single property's name or iterable of property names to get.

        Returns
        -------
        state : dict of states
        """
        if key is None:
            keys = self.keys
        elif isinstance(key, string_types):
            keys = [key]
        elif isinstance(key, collections.Iterable):
            keys = key
        else:
            raise ValueError(
                "key must be a string, an iterable of keys, or None")
        state = {}
        traits = self.traits()
        for k in keys:
            to_json = self.trait_metadata(k, 'to_json', self._trait_to_json)
            value = to_json(getattr(self, k), self)
            if not PY3 and isinstance(traits[k], Bytes) and isinstance(
                    value, bytes):
                value = memoryview(value)
            if not drop_defaults or not self._compare(value,
                                                      traits[k].default_value):
                state[k] = value
        return state

    def _is_numpy(self, x):
        return x.__class__.__name__ == 'ndarray' and x.__class__.__module__ == 'numpy'

    def _compare(self, a, b):
        if self._is_numpy(a) or self._is_numpy(b):
            import numpy as np
            return np.array_equal(a, b)
        else:
            return a == b

    def set_state(self, sync_data):
        """Called when a state is received from the front-end."""
        # The order of these context managers is important. Properties must
        # be locked when the hold_trait_notifications context manager is
        # released and notifications are fired.
        with self._lock_property(**sync_data), self.hold_trait_notifications():
            for name in sync_data:
                if name in self.keys:
                    from_json = self.trait_metadata(name, 'from_json',
                                                    self._trait_from_json)
                    self.set_trait(name, from_json(sync_data[name], self))

    def send(self, content, buffers=None):
        """Sends a custom msg to the widget model in the front-end.

        Parameters
        ----------
        content : dict
            Content of the message to send.
        buffers : list of binary buffers
            Binary buffers to send with message
        """
        self._send({"method": "custom", "content": content}, buffers=buffers)

    def on_msg(self, callback, remove=False):
        """(Un)Register a custom msg receive callback.

        Parameters
        ----------
        callback: callable
            callback will be passed three arguments when a message arrives::

                callback(widget, content, buffers)

        remove: bool
            True if the callback should be unregistered."""
        self._msg_callbacks.register_callback(callback, remove=remove)

    def on_displayed(self, callback, remove=False):
        """(Un)Register a widget displayed callback.

        Parameters
        ----------
        callback: method handler
            Must have a signature of::

                callback(widget, **kwargs)

            kwargs from display are passed through without modification.
        remove: bool
            True if the callback should be unregistered."""
        self.comm.comm_log.write("widget::on_displayed !!!!!!!!!!\n")
        self.comm.comm_log.flush()
        self._display_callbacks.register_callback(callback, remove=remove)

    def add_traits(self, **traits):
        """Dynamically add trait attributes to the Widget."""
        super(Widget, self).add_traits(**traits)
        for name, trait in traits.items():
            if trait.get_metadata('sync'):
                self.keys.append(name)
                self.send_state(name)

    def notify_change(self, change):
        """Called when a property has changed."""
        # Send the state to the frontend before the user-registered callbacks
        # are called.
        name = change['name']
        if self.comm is not None and self.comm.kernel is not None:
            # Make sure this isn't information that the front-end just sent us.
            if name in self.keys and self._should_send_property(
                    name, change['new']):
                # Send new state to front-end
                self.send_state(key=name)
        super(Widget, self).notify_change(change)

    def __repr__(self):
        return self._gen_repr_from_keys(self._repr_keys())

    #-------------------------------------------------------------------------
    # Support methods
    #-------------------------------------------------------------------------
    @contextmanager
    def _lock_property(self, **properties):
        """Lock a property-value pair.

        The value should be the JSON state of the property.

        NOTE: This, in addition to the single lock for all state changes, is
        flawed.  In the future we may want to look into buffering state changes
        back to the front-end."""
        self._property_lock = properties
        try:
            yield
        finally:
            self._property_lock = {}

    @contextmanager
    def hold_sync(self):
        """Hold syncing any state until the outermost context manager exits"""
        if self._holding_sync:
            yield
        else:
            try:
                self._holding_sync = True
                yield
            finally:
                self._holding_sync = False
                self.send_state(self._states_to_send)
                self._states_to_send.clear()

    def _should_send_property(self, key, value):
        """Check the property lock (property_lock)"""
        to_json = self.trait_metadata(key, 'to_json', self._trait_to_json)
        # A roundtrip conversion through json in the comparison takes care of
        # idiosyncrasies of how python data structures map to json, for example
        # tuples get converted to lists.
        if (key in self._property_lock and jsonloads(
                jsondumps(to_json(value, self))) == self._property_lock[key]):
            return False
        elif self._holding_sync:
            self._states_to_send.add(key)
            return False
        else:
            return True

    # Event handlers
    @_show_traceback
    def _handle_msg(self, msg):
        """Called when a msg is received from the front-end"""
        data = msg['content']['data']
        method = data['method']

        if method == 'update':
            if 'state' in data:
                state = data['state']
                if 'buffer_paths' in data:
                    _put_buffers(state, data['buffer_paths'], msg['buffers'])
                self.set_state(state)

        # Handle a state request.
        elif method == 'request_state':
            self.send_state()

        # Handle a custom msg from the front-end.
        elif method == 'custom':
            if 'content' in data:
                self._handle_custom_msg(data['content'], msg['buffers'])

        # Catch remainder.
        else:
            self.log.error(
                'Unknown front-end to back-end widget msg with method "%s"' %
                method)

    def _handle_custom_msg(self, content, buffers):
        """Called when a custom msg is received."""
        self._msg_callbacks(self, content, buffers)

    def _handle_displayed(self, **kwargs):
        """Called when a view has been displayed for this widget instance"""
        self.comm.comm_log.write("widget::_handle_displayed !!!!!!!!!!\n")
        self.comm.comm_log.flush()
        self._display_callbacks(self, **kwargs)

    @staticmethod
    def _trait_to_json(x, self):
        """Convert a trait value to json."""
        return x

    @staticmethod
    def _trait_from_json(x, self):
        """Convert json values to objects."""
        return x

    def _ipython_display_(self, **kwargs):
        """Called when `IPython.display.display` is called on the widget."""

        self.comm.comm_log.write("widget::_ipython_display !!!!!!!!!!\n")
        self.comm.comm_log.flush()
        if self._view_name is not None:

            # The 'application/vnd.jupyter.widget-view+json' mimetype has not been registered yet.
            # See the registration process and naming convention at
            # http://tools.ietf.org/html/rfc6838
            # and the currently registered mimetypes at
            # http://www.iana.org/assignments/media-types/media-types.xhtml.
            data = {
                'text/plain': "A Jupyter Widget",
                'application/vnd.jupyter.widget-view+json': {
                    'version_major': 2,
                    'version_minor': 0,
                    'model_id': self._model_id
                }
            }
            self.comm.comm_log.write(
                "Calling display(then _handle_displayed) with data -> %s\n" %
                data)
            self.comm.comm_log.flush()
            display(data, raw=True)

            self._handle_displayed(**kwargs)

    def _send(self, msg, buffers=None):
        """Sends a message to the model in the front-end."""
        if self.comm is not None and self.comm.kernel is not None:
            self.comm.send(data=msg, buffers=buffers)

    def _repr_keys(self):
        traits = self.traits()
        for key in sorted(self.keys):
            # Exclude traits that start with an underscore
            if key[0] == '_':
                continue
            # Exclude traits that are equal to their default value
            value = getattr(self, key)
            trait = traits[key]
            if self._compare(value, trait.default_value):
                continue
            elif (isinstance(trait, (Container, Dict))
                  and trait.default_value == Undefined and len(value) == 0):
                # Empty container, and dynamic default will be empty
                continue
            yield key

    def _gen_repr_from_keys(self, keys):
        class_name = self.__class__.__name__
        signature = ', '.join('%s=%r' % (key, getattr(self, key))
                              for key in keys)
        return '%s(%s)' % (class_name, signature)
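
# Usage sketch (not part of the example above): hold_sync() batches trait
# updates so the front-end receives a single state message instead of one per
# assignment. Assumes ipywidgets is installed and a live kernel/comm exists
# for the sync to reach a front-end; IntSlider is the stock control built on
# this Widget base class.
from ipywidgets import IntSlider

slider = IntSlider()
with slider.hold_sync():
    slider.max = 200   # queued in _states_to_send, not sent yet
    slider.value = 42  # queued as well
# on exit, send_state(_states_to_send) flushes both keys in one message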
Exemple #18
class EnterpriseGatewayApp(JupyterApp):
    """Application that provisions Jupyter kernels and proxies HTTP/Websocket
    traffic to the kernels.

    - reads command line and environment variable settings
    - initializes managers and routes
    - creates a Tornado HTTP server
    - starts the Tornado event loop
    """
    name = 'jupyter-enterprise-gateway'
    version = __version__
    description = """
        Jupyter Enterprise Gateway

        Provisions remote Jupyter kernels and proxies HTTP/Websocket traffic to them.
    """

    # Also include when generating help options
    classes = [FileKernelSessionManager, RemoteMappingKernelManager]

    # Enable some command line shortcuts
    aliases = aliases

    # Server IP / PORT binding
    port_env = 'EG_PORT'
    port_default_value = 8888
    port = Integer(port_default_value,
                   config=True,
                   help='Port on which to listen (EG_PORT env var)')

    @default('port')
    def port_default(self):
        return int(
            os.getenv(self.port_env,
                      os.getenv('KG_PORT', self.port_default_value)))

    port_retries_env = 'EG_PORT_RETRIES'
    port_retries_default_value = 50
    port_retries = Integer(
        port_retries_default_value,
        config=True,
        help="""Number of ports to try if the specified port is not available
                           (EG_PORT_RETRIES env var)""")

    @default('port_retries')
    def port_retries_default(self):
        return int(
            os.getenv(
                self.port_retries_env,
                os.getenv('KG_PORT_RETRIES', self.port_retries_default_value)))

    ip_env = 'EG_IP'
    ip_default_value = '127.0.0.1'
    ip = Unicode(ip_default_value,
                 config=True,
                 help='IP address on which to listen (EG_IP env var)')

    @default('ip')
    def ip_default(self):
        return os.getenv(self.ip_env, os.getenv('KG_IP',
                                                self.ip_default_value))

    # Base URL
    base_url_env = 'EG_BASE_URL'
    base_url_default_value = '/'
    base_url = Unicode(
        base_url_default_value,
        config=True,
        help=
        'The base path for mounting all API resources (EG_BASE_URL env var)')

    @default('base_url')
    def base_url_default(self):
        return os.getenv(self.base_url_env,
                         os.getenv('KG_BASE_URL', self.base_url_default_value))

    # Token authorization
    auth_token_env = 'EG_AUTH_TOKEN'
    auth_token = Unicode(
        config=True,
        help=
        'Authorization token required for all requests (EG_AUTH_TOKEN env var)'
    )

    @default('auth_token')
    def _auth_token_default(self):
        return os.getenv(self.auth_token_env, os.getenv('KG_AUTH_TOKEN', ''))

    # Begin CORS headers
    allow_credentials_env = 'EG_ALLOW_CREDENTIALS'
    allow_credentials = Unicode(
        config=True,
        help=
        'Sets the Access-Control-Allow-Credentials header. (EG_ALLOW_CREDENTIALS env var)'
    )

    @default('allow_credentials')
    def allow_credentials_default(self):
        return os.getenv(self.allow_credentials_env,
                         os.getenv('KG_ALLOW_CREDENTIALS', ''))

    allow_headers_env = 'EG_ALLOW_HEADERS'
    allow_headers = Unicode(
        config=True,
        help=
        'Sets the Access-Control-Allow-Headers header. (EG_ALLOW_HEADERS env var)'
    )

    @default('allow_headers')
    def allow_headers_default(self):
        return os.getenv(self.allow_headers_env,
                         os.getenv('KG_ALLOW_HEADERS', ''))

    allow_methods_env = 'EG_ALLOW_METHODS'
    allow_methods = Unicode(
        config=True,
        help=
        'Sets the Access-Control-Allow-Methods header. (EG_ALLOW_METHODS env var)'
    )

    @default('allow_methods')
    def allow_methods_default(self):
        return os.getenv(self.allow_methods_env,
                         os.getenv('KG_ALLOW_METHODS', ''))

    allow_origin_env = 'EG_ALLOW_ORIGIN'
    allow_origin = Unicode(
        config=True,
        help=
        'Sets the Access-Control-Allow-Origin header. (EG_ALLOW_ORIGIN env var)'
    )

    @default('allow_origin')
    def allow_origin_default(self):
        return os.getenv(self.allow_origin_env,
                         os.getenv('KG_ALLOW_ORIGIN', ''))

    expose_headers_env = 'EG_EXPOSE_HEADERS'
    expose_headers = Unicode(
        config=True,
        help=
        'Sets the Access-Control-Expose-Headers header. (EG_EXPOSE_HEADERS env var)'
    )

    @default('expose_headers')
    def expose_headers_default(self):
        return os.getenv(self.expose_headers_env,
                         os.getenv('KG_EXPOSE_HEADERS', ''))

    trust_xheaders_env = 'EG_TRUST_XHEADERS'
    trust_xheaders = CBool(
        False,
        config=True,
        help="""Use x-* header values for overriding the remote-ip, useful when
                           application is behing a proxy. (EG_TRUST_XHEADERS env var)"""
    )

    @default('trust_xheaders')
    def trust_xheaders_default(self):
        return strtobool(
            os.getenv(self.trust_xheaders_env,
                      os.getenv('KG_TRUST_XHEADERS', 'False')))

    certfile_env = 'EG_CERTFILE'
    certfile = Unicode(
        None,
        config=True,
        allow_none=True,
        help=
        'The full path to an SSL/TLS certificate file. (EG_CERTFILE env var)')

    @default('certfile')
    def certfile_default(self):
        return os.getenv(self.certfile_env, os.getenv('KG_CERTFILE'))

    keyfile_env = 'EG_KEYFILE'
    keyfile = Unicode(
        None,
        config=True,
        allow_none=True,
        help=
        'The full path to a private key file for usage with SSL/TLS. (EG_KEYFILE env var)'
    )

    @default('keyfile')
    def keyfile_default(self):
        return os.getenv(self.keyfile_env, os.getenv('KG_KEYFILE'))

    client_ca_env = 'EG_CLIENT_CA'
    client_ca = Unicode(
        None,
        config=True,
        allow_none=True,
        help="""The full path to a certificate authority certificate for SSL/TLS
                        client authentication. (EG_CLIENT_CA env var)""")

    @default('client_ca')
    def client_ca_default(self):
        return os.getenv(self.client_ca_env, os.getenv('KG_CLIENT_CA'))

    max_age_env = 'EG_MAX_AGE'
    max_age = Unicode(
        config=True,
        help='Sets the Access-Control-Max-Age header. (EG_MAX_AGE env var)')

    @default('max_age')
    def max_age_default(self):
        return os.getenv(self.max_age_env, os.getenv('KG_MAX_AGE', ''))

    # End CORS headers

    max_kernels_env = 'EG_MAX_KERNELS'
    max_kernels = Integer(
        None,
        config=True,
        allow_none=True,
        help=
        """Limits the number of kernel instances allowed to run by this gateway.
                          Unbounded by default. (EG_MAX_KERNELS env var)""")

    @default('max_kernels')
    def max_kernels_default(self):
        val = os.getenv(self.max_kernels_env, os.getenv('KG_MAX_KERNELS'))
        return val if val is None else int(val)

    default_kernel_name_env = 'EG_DEFAULT_KERNEL_NAME'
    default_kernel_name = Unicode(
        config=True,
        help=
        'Default kernel name when spawning a kernel (EG_DEFAULT_KERNEL_NAME env var)'
    )

    @default('default_kernel_name')
    def default_kernel_name_default(self):
        # defaults to Jupyter's default kernel name on empty string
        return os.getenv(self.default_kernel_name_env,
                         os.getenv('KG_DEFAULT_KERNEL_NAME', ''))

    list_kernels_env = 'EG_LIST_KERNELS'
    list_kernels = Bool(
        config=True,
        help=
        """Permits listing of the running kernels using API endpoints /api/kernels
                        and /api/sessions. (EG_LIST_KERNELS env var) Note: Jupyter Notebook
                        allows this by default but Jupyter Enterprise Gateway does not."""
    )

    @default('list_kernels')
    def list_kernels_default(self):
        return os.getenv(self.list_kernels_env,
                         os.getenv('KG_LIST_KERNELS',
                                   'False')).lower() == 'true'

    env_whitelist_env = 'EG_ENV_WHITELIST'
    env_whitelist = List(
        config=True,
        help="""Environment variables allowed to be set when a client requests a
                         new kernel. (EG_ENV_WHITELIST env var)""")

    @default('env_whitelist')
    def env_whitelist_default(self):
        return os.getenv(self.env_whitelist_env,
                         os.getenv('KG_ENV_WHITELIST', '')).split(',')

    env_process_whitelist_env = 'EG_ENV_PROCESS_WHITELIST'
    env_process_whitelist = List(
        config=True,
        help="""Environment variables allowed to be inherited
                                 from the spawning process by the kernel. (EG_ENV_PROCESS_WHITELIST env var)"""
    )

    @default('env_process_whitelist')
    def env_process_whitelist_default(self):
        return os.getenv(self.env_process_whitelist_env,
                         os.getenv('KG_ENV_PROCESS_WHITELIST', '')).split(',')

    # Remote hosts
    remote_hosts_env = 'EG_REMOTE_HOSTS'
    remote_hosts_default_value = 'localhost'
    remote_hosts = List(
        default_value=[remote_hosts_default_value],
        config=True,
        help=
        """Bracketed comma-separated list of hosts on which DistributedProcessProxy
                        kernels will be launched e.g., ['host1','host2']. (EG_REMOTE_HOSTS env var
                        - non-bracketed, just comma-separated)""")

    @default('remote_hosts')
    def remote_hosts_default(self):
        return os.getenv(self.remote_hosts_env,
                         self.remote_hosts_default_value).split(',')

    # Yarn endpoint
    yarn_endpoint_env = 'EG_YARN_ENDPOINT'
    yarn_endpoint = Unicode(
        None,
        config=True,
        allow_none=True,
        help=
        """The http url specifying the YARN Resource Manager. Note: If this value is NOT set,
                            the YARN library will use the files within the local HADOOP_CONFIG_DIR to determine the
                            active resource manager. (EG_YARN_ENDPOINT env var)"""
    )

    @default('yarn_endpoint')
    def yarn_endpoint_default(self):
        return os.getenv(self.yarn_endpoint_env)

    # Alt Yarn endpoint
    alt_yarn_endpoint_env = 'EG_ALT_YARN_ENDPOINT'
    alt_yarn_endpoint = Unicode(
        None,
        config=True,
        allow_none=True,
        help=
        """The http url specifying the alternate YARN Resource Manager.  This value should
                                be set when YARN Resource Managers are configured for high availability.  Note: If both
                                YARN endpoints are NOT set, the YARN library will use the files within the local
                                HADOOP_CONFIG_DIR to determine the active resource manager.
                                (EG_ALT_YARN_ENDPOINT env var)""")

    @default('alt_yarn_endpoint')
    def alt_yarn_endpoint_default(self):
        return os.getenv(self.alt_yarn_endpoint_env)

    yarn_endpoint_security_enabled_env = 'EG_YARN_ENDPOINT_SECURITY_ENABLED'
    yarn_endpoint_security_enabled_default_value = False
    yarn_endpoint_security_enabled = Bool(
        yarn_endpoint_security_enabled_default_value,
        config=True,
        help="""Is YARN Kerberos/SPNEGO Security enabled (True/False).
                                          (EG_YARN_ENDPOINT_SECURITY_ENABLED env var)"""
    )

    @default('yarn_endpoint_security_enabled')
    def yarn_endpoint_security_enabled_default(self):
        # bool() of any non-empty string (including 'False') is True, so
        # parse the env var with strtobool as trust_xheaders does
        return strtobool(
            os.getenv(self.yarn_endpoint_security_enabled_env,
                      str(self.yarn_endpoint_security_enabled_default_value)))

    # Conductor endpoint
    conductor_endpoint_env = 'EG_CONDUCTOR_ENDPOINT'
    conductor_endpoint_default_value = None
    conductor_endpoint = Unicode(
        conductor_endpoint_default_value,
        config=True,
        help="""The http url for accessing the Conductor REST API.
                                 (EG_CONDUCTOR_ENDPOINT env var)""")

    @default('conductor_endpoint')
    def conductor_endpoint_default(self):
        return os.getenv(self.conductor_endpoint_env,
                         self.conductor_endpoint_default_value)

    _log_formatter_cls = LogFormatter  # traitlet default is LevelFormatter

    @default('log_format')
    def _default_log_format(self):
        """override default log format to include milliseconds"""
        return u"%(color)s[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s]%(end_color)s %(message)s"

    # Impersonation enabled
    impersonation_enabled_env = 'EG_IMPERSONATION_ENABLED'
    impersonation_enabled = Bool(
        False,
        config=True,
        help=
        """Indicates whether impersonation will be performed during kernel launch.
                                 (EG_IMPERSONATION_ENABLED env var)""")

    @default('impersonation_enabled')
    def impersonation_enabled_default(self):
        return bool(
            os.getenv(self.impersonation_enabled_env, 'false').lower() ==
            'true')

    # Unauthorized users
    unauthorized_users_env = 'EG_UNAUTHORIZED_USERS'
    unauthorized_users_default_value = 'root'
    unauthorized_users = Set(
        default_value={unauthorized_users_default_value},
        config=True,
        help=
        """Comma-separated list of user names (e.g., ['root','admin']) against which
                             KERNEL_USERNAME will be compared.  Any match (case-sensitive) will prevent the
                             kernel's launch and result in an HTTP 403 (Forbidden) error.
                             (EG_UNAUTHORIZED_USERS env var - non-bracketed, just comma-separated)"""
    )

    @default('unauthorized_users')
    def unauthorized_users_default(self):
        return os.getenv(self.unauthorized_users_env,
                         self.unauthorized_users_default_value).split(',')

    # Authorized users
    authorized_users_env = 'EG_AUTHORIZED_USERS'
    authorized_users = Set(
        config=True,
        help=
        """Comma-separated list of user names (e.g., ['bob','alice']) against which
                           KERNEL_USERNAME will be compared.  Any match (case-sensitive) will allow the kernel's
                           launch, otherwise an HTTP 403 (Forbidden) error will be raised.  The set of unauthorized
                           users takes precedence. This option should be used carefully as it can dramatically limit
                           who can launch kernels.  (EG_AUTHORIZED_USERS env var - non-bracketed,
                           just comma-separated)""")

    @default('authorized_users')
    def authorized_users_default(self):
        au_env = os.getenv(self.authorized_users_env)
        return au_env.split(',') if au_env is not None else []

    # Port range
    port_range_env = 'EG_PORT_RANGE'
    port_range_default_value = "0..0"
    port_range = Unicode(
        port_range_default_value,
        config=True,
        help=
        """Specifies the lower and upper port numbers from which ports are created.
                         The bounded values are separated by '..' (e.g., 33245..34245 specifies a range of 1000 ports
                         to be randomly selected). A range of zero (e.g., 33245..33245 or 0..0) disables port-range
                         enforcement.  (EG_PORT_RANGE env var)""")

    @default('port_range')
    def port_range_default(self):
        return os.getenv(self.port_range_env, self.port_range_default_value)

    # Max Kernels per User
    max_kernels_per_user_env = 'EG_MAX_KERNELS_PER_USER'
    max_kernels_per_user_default_value = -1
    max_kernels_per_user = Integer(
        max_kernels_per_user_default_value,
        config=True,
        help="""Specifies the maximum number of kernels a user can have active
                                   simultaneously.  A value of -1 disables enforcement.
                                   (EG_MAX_KERNELS_PER_USER env var)""")

    @default('max_kernels_per_user')
    def max_kernels_per_user_default(self):
        return int(
            os.getenv(self.max_kernels_per_user_env,
                      self.max_kernels_per_user_default_value))

    ws_ping_interval_env = 'EG_WS_PING_INTERVAL_SECS'
    ws_ping_interval_default_value = 30
    ws_ping_interval = Integer(
        ws_ping_interval_default_value,
        config=True,
        help=
        """Specifies the ping interval(in seconds) that should be used by zmq port
                                     associated withspawned kernels.Set this variable to 0 to disable ping mechanism.
                                    (EG_WS_PING_INTERVAL_SECS env var)""")

    @default('ws_ping_interval')
    def ws_ping_interval_default(self):
        return int(
            os.getenv(self.ws_ping_interval_env,
                      self.ws_ping_interval_default_value))

    kernel_spec_manager = Instance(KernelSpecManager, allow_none=True)

    kernel_spec_manager_class = Type(default_value=KernelSpecManager,
                                     config=True,
                                     help="""
        The kernel spec manager class to use. Must be a subclass
        of `jupyter_client.kernelspec.KernelSpecManager`.
        """)

    kernel_manager_class = Type(klass=MappingKernelManager,
                                default_value=RemoteMappingKernelManager,
                                config=True,
                                help="""
        The kernel manager class to use. Must be a subclass
        of `notebook.services.kernels.MappingKernelManager`.
        """)

    kernel_session_manager_class = Type(klass=KernelSessionManager,
                                        default_value=FileKernelSessionManager,
                                        config=True,
                                        help="""
        The kernel session manager class to use. Must be a subclass
        of `enterprise_gateway.services.sessions.KernelSessionManager`.
        """)

    def initialize(self, argv=None):
        """Initializes the base class, configurable manager instances, the
        Tornado web app, and the tornado HTTP server.

        Parameters
        ----------
        argv
            Command line arguments
        """
        super(EnterpriseGatewayApp, self).initialize(argv)
        self.init_configurables()
        self.init_webapp()
        self.init_http_server()

    def init_configurables(self):
        """Initializes all configurable objects including a kernel manager, kernel
        spec manager, session manager, and personality.
        """
        # Only pass a default kernel name when one is provided. Otherwise,
        # adopt whatever default the kernel manager wants to use.
        kwargs = {}
        if self.default_kernel_name:
            kwargs['default_kernel_name'] = self.default_kernel_name

        # Instantiate the configurable class directly; a hard-coded
        # KernelSpecManager instance here would be overwritten immediately.
        self.kernel_spec_manager = self.kernel_spec_manager_class(parent=self)

        self.kernel_manager = self.kernel_manager_class(
            parent=self,
            log=self.log,
            connection_dir=self.runtime_dir,
            kernel_spec_manager=self.kernel_spec_manager,
            **kwargs)

        self.session_manager = SessionManager(
            log=self.log, kernel_manager=self.kernel_manager)

        self.kernel_session_manager = self.kernel_session_manager_class(
            parent=self,
            log=self.log,
            kernel_manager=self.kernel_manager,
            config=self.config,  # required to get command-line options visible
            **kwargs)

        # Attempt to start persisted sessions
        self.kernel_session_manager.start_sessions()

        self.contents_manager = None  # Gateways don't use contents manager

    def _create_request_handlers(self):
        """Create default Jupyter handlers and redefine them off of the
        base_url path. Assumes init_configurables() has already been called.
        """
        handlers = []

        # append tuples for the standard kernel gateway endpoints
        for handler in (default_api_handlers + default_kernel_handlers +
                        default_kernelspec_handlers +
                        default_session_handlers + default_base_handlers):
            # Create a new handler pattern rooted at the base_url
            pattern = url_path_join('/', self.base_url, handler[0])
            # Some handlers take args, so retain those in addition to the
            # handler class ref
            new_handler = tuple([pattern] + list(handler[1:]))
            handlers.append(new_handler)
        return handlers

    def init_webapp(self):
        """Initializes Tornado web application with uri handlers.

        Adds the various managers and web-front configuration values to the
        Tornado settings for reference by the handlers.
        """
        # Enable the same pretty logging the notebook uses
        enable_pretty_logging()

        # Configure the tornado logging level too
        logging.getLogger().setLevel(self.log_level)

        handlers = self._create_request_handlers()

        self.web_app = web.Application(
            handlers=handlers,
            kernel_manager=self.kernel_manager,
            session_manager=self.session_manager,
            contents_manager=self.contents_manager,
            kernel_spec_manager=self.kernel_spec_manager,
            eg_auth_token=self.auth_token,
            eg_allow_credentials=self.allow_credentials,
            eg_allow_headers=self.allow_headers,
            eg_allow_methods=self.allow_methods,
            eg_allow_origin=self.allow_origin,
            eg_expose_headers=self.expose_headers,
            eg_max_age=self.max_age,
            eg_max_kernels=self.max_kernels,
            eg_env_process_whitelist=self.env_process_whitelist,
            eg_env_whitelist=self.env_whitelist,
            eg_list_kernels=self.list_kernels,
            # Also set the allow_origin setting used by notebook so that the
            # check_origin method used everywhere respects the value
            allow_origin=self.allow_origin,
            # Always allow remote access (has been limited to localhost >= notebook 5.6)
            allow_remote_access=True,
            # expose ws_ping_interval so it can be tuned to toggle the ping
            # mechanism for zmq web-sockets or adjust ping intervals/timeouts
            ws_ping_interval=self.ws_ping_interval * 1000)

    def _build_ssl_options(self):
        """Build a dictionary of SSL options for the tornado HTTP server.

        Taken directly from jupyter/notebook code.
        """
        ssl_options = {}
        if self.certfile:
            ssl_options['certfile'] = self.certfile
        if self.keyfile:
            ssl_options['keyfile'] = self.keyfile
        if self.client_ca:
            ssl_options['ca_certs'] = self.client_ca
        if not ssl_options:
            # None indicates no SSL config
            ssl_options = None
        else:
            # SSL may be missing, so only import it if it's to be used
            import ssl
            # PROTOCOL_TLS selects the highest ssl/tls protocol version that both the client and
            # server support. When PROTOCOL_TLS is not available use PROTOCOL_SSLv23.
            # PROTOCOL_TLS is new in version 2.7.13, 3.5.3 and 3.6
            ssl_options.setdefault(
                'ssl_version', getattr(ssl, 'PROTOCOL_TLS',
                                       ssl.PROTOCOL_SSLv23))
            if ssl_options.get('ca_certs', False):
                ssl_options.setdefault('cert_reqs', ssl.CERT_REQUIRED)

        return ssl_options
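
    # e.g., with certfile and keyfile set this yields
    #   {'certfile': ..., 'keyfile': ..., 'ssl_version': ssl.PROTOCOL_TLS}
    # and None when no SSL options were configured.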

    def init_http_server(self):
        """Initializes a HTTP server for the Tornado web application on the
        configured interface and port.

        Tries to find an open port, using the same logic as the Jupyter
        Notebook server, if the configured one is not available.
        """
        ssl_options = self._build_ssl_options()
        self.http_server = httpserver.HTTPServer(self.web_app,
                                                 xheaders=self.trust_xheaders,
                                                 ssl_options=ssl_options)

        for port in random_ports(self.port, self.port_retries + 1):
            try:
                self.http_server.listen(port, self.ip)
            except socket.error as e:
                if e.errno == errno.EADDRINUSE:
                    self.log.info(
                        'The port %i is already in use, trying another port.' %
                        port)
                    continue
                elif e.errno in (errno.EACCES,
                                 getattr(errno, 'WSAEACCES', errno.EACCES)):
                    self.log.warning("Permission to listen on port %i denied" %
                                     port)
                    continue
                else:
                    raise
            else:
                self.port = port
                break
        else:
            self.log.critical(
                'ERROR: the gateway server could not be started because '
                'no available port could be found.')
            self.exit(1)

    def start(self):
        """Starts an IO loop for the application. """

        super(EnterpriseGatewayApp, self).start()

        self.log.info(
            'Jupyter Enterprise Gateway {} is available at http{}://{}:{}'.
            format(EnterpriseGatewayApp.version, 's' if self.keyfile else '',
                   self.ip, self.port))
        # If impersonation is enabled, issue a warning message if the gateway user is not in unauthorized_users.
        if self.impersonation_enabled:
            gateway_user = getpass.getuser()
            if gateway_user.lower() not in self.unauthorized_users:
                self.log.warning(
                    "Impersonation is enabled and gateway user '{}' is NOT specified in the set of "
                    "unauthorized users!  Kernels may execute as that user with elevated privileges."
                    .format(gateway_user))

        self.io_loop = ioloop.IOLoop.current()

        if sys.platform != 'win32':
            signal.signal(signal.SIGHUP, signal.SIG_IGN)

        signal.signal(signal.SIGTERM, self._signal_stop)

        try:
            self.io_loop.start()
        except KeyboardInterrupt:
            self.log.info("Interrupted...")
            # Ignore further interrupts (ctrl-c)
            signal.signal(signal.SIGINT, signal.SIG_IGN)
        finally:
            self.shutdown()

    def shutdown(self):
        """Shuts down all running kernels."""
        kids = self.kernel_manager.list_kernel_ids()
        for kid in kids:
            self.kernel_manager.shutdown_kernel(kid, now=True)

    def stop(self):
        """
        Stops the HTTP server and IO loop associated with the application.
        """
        def _stop():
            self.http_server.stop()
            self.io_loop.stop()

        self.io_loop.add_callback(_stop)

    def _signal_stop(self, sig, frame):
        self.log.info("Received signal to terminate Enterprise Gateway.")
        self.io_loop.add_callback_from_signal(self.io_loop.stop)
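
# Minimal sketch of the configuration pattern used throughout the class above:
# a traitlet whose default falls back to an environment variable. DemoApp and
# MY_APP_PORT are illustrative names, not part of the example.
import os

from traitlets import Integer, default
from traitlets.config import Application


class DemoApp(Application):
    port_env = 'MY_APP_PORT'
    port = Integer(8888, config=True,
                   help='Port on which to listen (MY_APP_PORT env var)')

    @default('port')
    def _port_default(self):
        # config-file and command-line values still win; this only supplies
        # the default when nothing else sets the trait
        return int(os.getenv(self.port_env, 8888))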
Exemple #19
class Session(Configurable):
    """Object for handling serialization and sending of messages.

    The Session object handles building messages and sending them
    with ZMQ sockets or ZMQStream objects.  Objects can communicate with each
    other over the network via Session objects, and only need to work with the
    dict-based IPython message spec. The Session will handle
    serialization/deserialization, security, and metadata.

    Sessions support configurable serialization via packer/unpacker traits,
    and signing with HMAC digests via the key/keyfile traits.

    Parameters
    ----------

    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        importstrings for methods to serialize message parts.  If just
        'json' or 'pickle', predefined JSON and pickle packers will be used.
        Otherwise, the entire importstring must be used.

        The functions must accept at least valid JSON input, and output *bytes*.

        For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        You can also set the pack/unpack callables for serialization directly.
    session : bytes
        the ID of this Session object.  The default is to generate a new UUID.
    username : unicode
        username added to message headers.  The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature.  If unset, messages
        will not be signed or checked.
    keyfile : filepath
        The file containing a key.  If this is set, `key` will be initialized
        to the contents of the file.

    """

    debug = Bool(False, config=True, help="""Debug output in the Session""")

    check_pid = Bool(True, config=True,
        help="""Whether to check PID to protect against calls after fork.

        This check can be disabled if fork-safety is handled elsewhere.
        """)

    packer = DottedObjectName('json',config=True,
            help="""The name of the packer for serializing messages.
            Should be one of 'json', 'pickle', or an import name
            for a custom callable serializer.""")

    @observe('packer')
    def _packer_changed(self, change):
        new = change['new']
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.unpacker = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.unpacker = new
        else:
            self.pack = import_item(str(new))

    unpacker = DottedObjectName('json', config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""")

    @observe('unpacker')
    def _unpacker_changed(self, change):
        new = change['new']
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.packer = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.packer = new
        else:
            self.unpack = import_item(str(new))

    session = CUnicode(u'', config=True,
        help="""The UUID identifying this session.""")
    def _session_default(self):
        u = new_id()
        self.bsession = u.encode('ascii')
        return u

    @observe('session')
    def _session_changed(self, change):
        self.bsession = self.session.encode('ascii')

    # bsession is the session as bytes
    bsession = CBytes(b'')

    username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
        help="""Username for the Session. Default is your system username.""",
        config=True)

    metadata = Dict({}, config=True,
        help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")

    # if 0, no adapting to do.
    adapt_version = Integer(0)

    # message signature related traits:

    key = CBytes(config=True,
        help="""execution key, for signing messages.""")
    def _key_default(self):
        return new_id_bytes()

    @observe('key')
    def _key_changed(self, change):
        self._new_auth()

    signature_scheme = Unicode('hmac-sha256', config=True,
        help="""The digest scheme used to construct the message signatures.
        Must have the form 'hmac-HASH'.""")

    @observe('signature_scheme')
    def _signature_scheme_changed(self, change):
        new = change['new']
        if not new.startswith('hmac-'):
            raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
        hash_name = new.split('-', 1)[1]
        try:
            self.digest_mod = getattr(hashlib, hash_name)
        except AttributeError:
            raise TraitError("hashlib has no such attribute: %s" % hash_name)
        self._new_auth()

    digest_mod = Any()
    def _digest_mod_default(self):
        return hashlib.sha256

    auth = Instance(hmac.HMAC, allow_none=True)

    def _new_auth(self):
        if self.key:
            self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
        else:
            self.auth = None
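    # with an empty key, auth is None: sign() returns b'' and deserialize()
    # skips signature verification entirely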

    digest_history = Set()
    digest_history_size = Integer(2**16, config=True,
        help="""The maximum number of digests to remember.

        The digest history will be culled when it exceeds this value.
        """
    )

    keyfile = Unicode('', config=True,
        help="""path to file containing execution key.""")

    @observe('keyfile')
    def _keyfile_changed(self, change):
        with open(change['new'], 'rb') as f:
            self.key = f.read().strip()

    # for protecting against sends from forks
    pid = Integer()

    # serialization traits:

    pack = Any(default_packer) # the actual packer function

    @observe('pack')
    def _pack_changed(self, change):
        new = change['new']
        if not callable(new):
            raise TypeError("packer must be callable, not %s"%type(new))

    unpack = Any(default_unpacker) # the actual unpacker function

    @observe('unpack')
    def _unpack_changed(self, change):
        # the unpacker's output is not validated - only that it is callable
        new = change['new']
        if not callable(new):
            raise TypeError("unpacker must be callable, not %s"%type(new))

    # thresholds:
    copy_threshold = Integer(2**16, config=True,
        help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
    buffer_threshold = Integer(MAX_BYTES, config=True,
        help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
    item_threshold = Integer(MAX_ITEMS, config=True,
        help="""The maximum number of items for a container to be introspected for custom serialization.
        Containers larger than this are pickled outright.
        """
    )


    def __init__(self, **kwargs):
        """create a Session object

        Parameters
        ----------

        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts.  If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : unicode (must be ascii)
            the ID of this Session object.  The default is to generate a new
            UUID.
        bsession : bytes
            The session as bytes
        username : unicode
            username added to message headers.  The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature.  If unset, messages
            will not be signed or checked.
        signature_scheme : str
            The message digest scheme. Currently must be of the form 'hmac-HASH',
            where 'HASH' is a hashing function available in Python's hashlib.
            The default is 'hmac-sha256'.
            This is ignored if 'key' is empty.
        keyfile : filepath
            The file containing a key.  If this is set, `key` will be
            initialized to the contents of the file.
        """
        super(Session, self).__init__(**kwargs)
        self._check_packers()
        self.none = self.pack({})
        # ensure self._session_default() if necessary, so bsession is defined:
        self.session
        self.pid = os.getpid()
        self._new_auth()
        if not self.key:
            get_logger().warning("Message signing is disabled.  This is insecure and not recommended!")

    def clone(self):
        """Create a copy of this Session

        Useful when connecting multiple times to a given kernel.
        This prevents a shared digest_history warning about duplicate digests
        due to multiple connections to IOPub in the same process.

        .. versionadded:: 5.1
        """
        # make a copy
        new_session = type(self)()
        for name in self.traits():
            setattr(new_session, name, getattr(self, name))
        # fork digest_history
        new_session.digest_history = set()
        new_session.digest_history.update(self.digest_history)
        return new_session

    @property
    def msg_id(self):
        """always return new uuid"""
        return new_id()

    def _check_packers(self):
        """check packers for datetime support."""
        pack = self.pack
        unpack = self.unpack

        # check simple serialization
        msg = dict(a=[1,'hi'])
        try:
            packed = pack(msg)
        except Exception as e:
            msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
            if self.packer == 'json':
                jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            else:
                jsonmsg = ""
            raise ValueError(
                msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
            )

        # ensure packed message is bytes
        if not isinstance(packed, bytes):
            raise ValueError("message packed to %r, but bytes are required"%type(packed))

        # check that unpack is pack's inverse
        try:
            unpacked = unpack(packed)
            assert unpacked == msg
        except Exception as e:
            msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
            if self.packer == 'json':
                jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            else:
                jsonmsg = ""
            raise ValueError(
                msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
            )

        # check datetime support
        msg = dict(t=utcnow())
        try:
            unpacked = unpack(pack(msg))
            if isinstance(unpacked['t'], datetime):
                # the packer serialized the datetime and the unpacker revived
                # it; raise so the handler below installs squashing wrappers
                raise ValueError("Shouldn't deserialize to datetime")
        except Exception:
            # the packer can't round-trip datetimes as-is, so squash them to
            # ISO8601 strings before packing
            self.pack = lambda o: pack(squash_dates(o))
            self.unpack = lambda s: unpack(s)

    def msg_header(self, msg_type):
        return msg_header(self.msg_id, msg_type, self.username, self.session)

    def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
        """Return the nested message dict.

        This format is different from what is sent over the wire. The
        serialize/deserialize methods convert this nested message dict to the
        wire format, which is a list of message parts.
        """
        msg = {}
        header = self.msg_header(msg_type) if header is None else header
        msg['header'] = header
        msg['msg_id'] = header['msg_id']
        msg['msg_type'] = header['msg_type']
        msg['parent_header'] = {} if parent is None else extract_header(parent)
        msg['content'] = {} if content is None else content
        msg['metadata'] = self.metadata.copy()
        if metadata is not None:
            msg['metadata'].update(metadata)
        return msg

    def sign(self, msg_list):
        """Sign a message with HMAC digest. If no auth, return b''.

        Parameters
        ----------
        msg_list : list
            The [p_header,p_parent,p_content] part of the message list.
        """
        if self.auth is None:
            return b''
        h = self.auth.copy()
        for m in msg_list:
            h.update(m)
        return str_to_bytes(h.hexdigest())

    def serialize(self, msg, ident=None):
        """Serialize the message components to bytes.

        This is roughly the inverse of deserialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg : dict or Message
            The next message dict as returned by the self.msg method.

        Returns
        -------
        msg_list : list
            The list of bytes objects to be sent with the format::

                [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
                 p_metadata, p_content, buffer1, buffer2, ...]

            In this list, the ``p_*`` entities are the packed or serialized
            versions, so if JSON is used, these are utf8 encoded JSON strings.
        """
        content = msg.get('content', {})
        if content is None:
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, bytes):
            # content is already packed, as in a relayed message
            pass
        elif isinstance(content, unicode_type):
            # should be bytes, but JSON often spits out unicode
            content = content.encode('utf8')
        else:
            raise TypeError("Content incorrect type: %s"%type(content))

        real_message = [self.pack(msg['header']),
                        self.pack(msg['parent_header']),
                        self.pack(msg['metadata']),
                        content,
        ]

        to_send = []

        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)

        signature = self.sign(real_message)
        to_send.append(signature)

        to_send.extend(real_message)

        return to_send

    def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
             buffers=None, track=False, header=None, metadata=None):
        """Build and send a message via stream or socket.

        The message format used by this function internally is as follows:

        [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
         buffer1,buffer2,...]

        The serialize/deserialize methods convert the nested message dict into this
        format.

        Parameters
        ----------

        stream : zmq.Socket or ZMQStream
            The socket-like object used to send the data.
        msg_or_type : str or Message/dict
            Normally, msg_or_type will be a msg_type unless a message is being
            sent more than once. If a header is supplied, this can be set to
            None and the msg_type will be pulled from the header.

        content : dict or None
            The content of the message (ignored if msg_or_type is a message).
        header : dict or None
            The header dict for the message (ignored if msg_or_type is a message).
        parent : Message or dict or None
            The parent or parent header describing the parent of this message
            (ignored if msg_or_type is a message).
        ident : bytes or list of bytes
            The zmq.IDENTITY routing path.
        metadata : dict or None
            The metadata describing the message
        buffers : list or None
            The already-serialized buffers to be appended to the message.
        track : bool
            Whether to track.  Only for use with Sockets, because ZMQStream
            objects cannot track messages.


        Returns
        -------
        msg : dict
            The constructed message.
        """
        if not isinstance(stream, zmq.Socket):
            # ZMQStreams and dummy sockets do not support tracking.
            track = False

        if isinstance(msg_or_type, (Message, dict)):
            # We got a Message or message dict, not a msg_type so don't
            # build a new Message.
            msg = msg_or_type
            buffers = buffers or msg.get('buffers', [])
        else:
            msg = self.msg(msg_or_type, content=content, parent=parent,
                           header=header, metadata=metadata)
        if self.check_pid and os.getpid() != self.pid:
            get_logger().warning("WARNING: attempted to send message from fork\n%s",
                msg
            )
            return
        buffers = [] if buffers is None else buffers
        for idx, buf in enumerate(buffers):
            if isinstance(buf, memoryview):
                view = buf
            else:
                try:
                    # check to see if buf supports the buffer protocol.
                    view = memoryview(buf)
                except TypeError:
                    raise TypeError("Buffer objects must support the buffer protocol.")
            # memoryview.contiguous is new in 3.3,
            # just skip the check on Python 2
            if hasattr(view, 'contiguous') and not view.contiguous:
                # zmq requires memoryviews to be contiguous
                raise ValueError("Buffer %i (%r) is not contiguous" % (idx, buf))

        if self.adapt_version:
            msg = adapt(msg, self.adapt_version)
        to_send = self.serialize(msg, ident)
        to_send.extend(buffers)
        longest = max([ len(s) for s in to_send ])
        copy = (longest < self.copy_threshold)

        if buffers and track and not copy:
            # only really track when we are doing zero-copy buffers
            tracker = stream.send_multipart(to_send, copy=False, track=True)
        else:
            # use dummy tracker, which will be done immediately
            tracker = DONE
            stream.send_multipart(to_send, copy=copy)

        if self.debug:
            pprint.pprint(msg)
            pprint.pprint(to_send)
            pprint.pprint(buffers)

        msg['tracker'] = tracker

        return msg

    def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
        """Send a raw message via ident path.

        This method is used to send an already serialized message.

        Parameters
        ----------
        stream : ZMQStream or Socket
            The ZMQ stream or socket to use for sending the message.
        msg_list : list
            The serialized list of messages to send. This only includes the
            [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
            the message.
        ident : ident or list
            A single ident or a list of idents to use in sending.
        """
        to_send = []
        if isinstance(ident, bytes):
            ident = [ident]
        if ident is not None:
            to_send.extend(ident)

        to_send.append(DELIM)
        # Don't include buffers in signature (per spec).
        to_send.append(self.sign(msg_list[0:4]))
        to_send.extend(msg_list)
        stream.send_multipart(to_send, flags, copy=copy)

    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
        """Receive and unpack a message.

        Parameters
        ----------
        socket : ZMQStream or Socket
            The socket or stream to use in receiving.

        Returns
        -------
        [idents], msg
            [idents] is a list of idents and msg is a nested message dict of
            same format as self.msg returns.
        """
        if isinstance(socket, ZMQStream):
            socket = socket.socket
        try:
            msg_list = socket.recv_multipart(mode, copy=copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_multipart won't return None.
                return None, None
            else:
                raise
        # split multipart message into identity list and message dict
        # invalid large messages can cause very expensive string comparisons
        idents, msg_list = self.feed_identities(msg_list, copy)
        try:
            return idents, self.deserialize(msg_list, content=content, copy=copy)
        except Exception as e:
            # TODO: handle it
            raise e

    def feed_identities(self, msg_list, copy=True):
        """Split the identities from the rest of the message.

        Feed until DELIM is reached, then return the prefix as idents and
        remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
        but that would be silly.

        Parameters
        ----------
        msg_list : a list of Message or bytes objects
            The message to be split.
        copy : bool
            flag determining whether the arguments are bytes or Messages

        Returns
        -------
        (idents, msg_list) : two lists
            idents will always be a list of bytes, each of which is a ZMQ
            identity. msg_list will be a list of bytes or zmq.Messages of the
            form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
            should be unpackable/unserializable via self.deserialize at this
            point.
        """
        if copy:
            idx = msg_list.index(DELIM)
            return msg_list[:idx], msg_list[idx+1:]
        else:
            failed = True
            for idx,m in enumerate(msg_list):
                if m.bytes == DELIM:
                    failed = False
                    break
            if failed:
                raise ValueError("DELIM not in msg_list")
            idents, msg_list = msg_list[:idx], msg_list[idx+1:]
            return [m.bytes for m in idents], msg_list
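    # For example (copy=True), the multipart message
    #   [b'routing-id', DELIM, b'<hmac>', b'<p_header>', ...]
    # splits into ([b'routing-id'], [b'<hmac>', b'<p_header>', ...]).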

    def _add_digest(self, signature):
        """add a digest to history to protect against replay attacks"""
        if self.digest_history_size == 0:
            # no history, never add digests
            return

        self.digest_history.add(signature)
        if len(self.digest_history) > self.digest_history_size:
            # threshold reached, cull 10%
            self._cull_digest_history()

    def _cull_digest_history(self):
        """cull the digest history

        Removes a randomly selected 10% of the digest history
        """
        current = len(self.digest_history)
        n_to_cull = max(int(current // 10), current - self.digest_history_size)
        if n_to_cull >= current:
            self.digest_history = set()
            return
        # random.sample() no longer accepts sets (removed in Python 3.11),
        # so materialize the history as a tuple first
        to_cull = random.sample(tuple(self.digest_history), n_to_cull)
        self.digest_history.difference_update(to_cull)

    def deserialize(self, msg_list, content=True, copy=True):
        """Unserialize a msg_list to a nested message dict.

        This is roughly the inverse of serialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg_list : list of bytes or Message objects
            The list of message parts of the form [HMAC,p_header,p_parent,
            p_metadata,p_content,buffer1,buffer2,...].
        content : bool (True)
            Whether to unpack the content dict (True), or leave it packed
            (False).
        copy : bool (True)
            Whether msg_list contains bytes (True) or the non-copying Message
            objects in each place (False).

        Returns
        -------
        msg : dict
            The nested message dict with top-level keys [header, parent_header,
            content, buffers].  The buffers are returned as memoryviews.
        """
        minlen = 5
        message = {}
        if not copy:
            # pyzmq didn't copy the first parts of the message, so we'll do it
            for i in range(minlen):
                msg_list[i] = msg_list[i].bytes
        if self.auth is not None:
            signature = msg_list[0]
            if not signature:
                raise ValueError("Unsigned Message")
            if signature in self.digest_history:
                raise ValueError("Duplicate Signature: %r" % signature)
            if content:
                # Only store signature if we are unpacking content, don't store if just peeking.
                self._add_digest(signature)
            check = self.sign(msg_list[1:5])
            if not compare_digest(signature, check):
                raise ValueError("Invalid Signature: %r" % signature)
        if not len(msg_list) >= minlen:
            raise TypeError("malformed message, must have at least %i elements"%minlen)
        header = self.unpack(msg_list[1])
        message['header'] = extract_dates(header)
        message['msg_id'] = header['msg_id']
        message['msg_type'] = header['msg_type']
        message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
        message['metadata'] = self.unpack(msg_list[3])
        if content:
            message['content'] = self.unpack(msg_list[4])
        else:
            message['content'] = msg_list[4]
        buffers = [memoryview(b) for b in msg_list[5:]]
        if buffers and buffers[0].shape is None:
            # force copy to workaround pyzmq #646
            buffers = [memoryview(b.bytes) for b in msg_list[5:]]
        message['buffers'] = buffers
        if self.debug:
            pprint.pprint(message)
        # adapt to the current version
        return adapt(message)

    def unserialize(self, *args, **kwargs):
        warnings.warn(
            "Session.unserialize is deprecated. Use Session.deserialize.",
            DeprecationWarning,
        )
        return self.deserialize(*args, **kwargs)
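
# Minimal sketch (standard library only) of the hmac-sha256 signing scheme
# that Session.sign implements: the signature covers the packed header,
# parent header, metadata and content parts, in that order, and travels on
# the wire as the ASCII hexdigest right after DELIM. Key and parts below are
# illustrative values.
import hashlib
import hmac
import json

key = b'illustrative-secret-key'
parts = [json.dumps(p).encode('utf8')
         for p in ({'msg_id': '1', 'msg_type': 'execute_request'},  # header
                   {},                                              # parent
                   {},                                              # metadata
                   {'code': 'print(1)'})]                           # content

h = hmac.HMAC(key, digestmod=hashlib.sha256)
for part in parts:
    h.update(part)
signature = h.hexdigest().encode('ascii')  # goes after DELIM on the wire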
Exemple #20
class Kernel(SingletonConfigurable):

    #---------------------------------------------------------------------------
    # Kernel interface
    #---------------------------------------------------------------------------

    # attribute to override with a GUI
    eventloop = Any(None)

    @observe('eventloop')
    def _update_eventloop(self, change):
        """schedule call to eventloop from IOLoop"""
        loop = ioloop.IOLoop.current()
        if change.new is not None:
            loop.add_callback(self.enter_eventloop)

    session = Instance(Session, allow_none=True)
    profile_dir = Instance('IPython.core.profiledir.ProfileDir',
                           allow_none=True)
    shell_stream = Instance(ZMQStream, allow_none=True)

    shell_streams = List(
        help="""Deprecated shell_streams alias. Use shell_stream

        .. versionchanged:: 6.0
            shell_streams is deprecated. Use shell_stream.
        """)

    @default("shell_streams")
    def _shell_streams_default(self):
        warnings.warn(
            "Kernel.shell_streams is deprecated in yapkernel 6.0. Use Kernel.shell_stream",
            DeprecationWarning,
            stacklevel=2,
        )
        if self.shell_stream is not None:
            return [self.shell_stream]
        else:
            return []

    @observe("shell_streams")
    def _shell_streams_changed(self, change):
        warnings.warn(
            "Kernel.shell_streams is deprecated in yapkernel 6.0. Use Kernel.shell_stream",
            DeprecationWarning,
            stacklevel=2,
        )
        if len(change.new) > 1:
            warnings.warn(
                "Kernel only supports one shell stream. Additional streams will be ignored.",
                RuntimeWarning,
                stacklevel=2,
            )
        if change.new:
            self.shell_stream = change.new[0]

    control_stream = Instance(ZMQStream, allow_none=True)

    debug_shell_socket = Any()

    control_thread = Any()
    iopub_socket = Any()
    iopub_thread = Any()
    stdin_socket = Any()
    log = Instance(logging.Logger, allow_none=True)

    # identities:
    int_id = Integer(-1)
    ident = Unicode()

    @default('ident')
    def _default_ident(self):
        return str(uuid.uuid4())

    # This should be overridden by wrapper kernels that implement any real
    # language.
    language_info = {
        'name': 'Prolog (YAP)',
        'mimetype': 'text/x-prolog',
        'file_extension': '.yap',
    }

    # any links that should go in the help menu
    help_links = List()

    # Private interface

    _darwin_app_nap = Bool(
        True,
        help="""Whether to use appnope for compatibility with OS X App Nap.

        Only affects OS X >= 10.9.
        """).tag(config=True)

    # track associations with current request
    _allow_stdin = Bool(False)
    _parents = Dict({"shell": {}, "control": {}})
    _parent_ident = Dict({'shell': b'', 'control': b''})

    @property
    def _parent_header(self):
        warnings.warn(
            "Kernel._parent_header is deprecated in yapkernel 6. Use .get_parent()",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.get_parent(channel="shell")

    # Time to sleep after flushing the stdout/err buffers in each execute
    # cycle.  While this introduces a hard limit on the minimal latency of the
    # execute cycle, it helps prevent output synchronization problems for
    # clients.
    # Units are in seconds.  The minimum zmq latency on local host is probably
    # ~150 microseconds, set this to 500us for now.  We may need to increase it
    # a little if it's not enough after more interactive testing.
    _execute_sleep = Float(0.0005).tag(config=True)

    # Frequency of the kernel's event loop.
    # Units are in seconds, kernel subclasses for GUI toolkits may need to
    # adapt to milliseconds.
    _poll_interval = Float(0.01).tag(config=True)

    stop_on_error_timeout = Float(
        0.0,
        config=True,
        help="""time (in seconds) to wait for messages to arrive
        when aborting queued requests after an error.

        Requests that arrive within this window after an error
        will be cancelled.

        Increase in the event of unusually slow network
        causing significant delays,
        which can manifest as e.g. "Run all" in a notebook
        aborting some, but not all, messages after an error.
        """)

    # If the shutdown was requested over the network, we leave here the
    # necessary reply message so it can be sent by our registered atexit
    # handler.  This ensures that the reply is only sent to clients truly at
    # the end of our shutdown process (which happens after the underlying
    # IPython shell's own shutdown).
    _shutdown_message = None

    # This is a dict of port number that the kernel is listening on. It is set
    # by record_ports and used by connect_request.
    _recorded_ports = Dict()

    # set of aborted msg_ids
    aborted = Set()

    # Track execution count here. For IPython, we override this to use the
    # execution count we store in the shell.
    execution_count = 0

    msg_types = [
        'execute_request',
        'complete_request',
        'inspect_request',
        'history_request',
        'comm_info_request',
        'kernel_info_request',
        'connect_request',
        'shutdown_request',
        'is_complete_request',
        'interrupt_request',
        # deprecated:
        'apply_request',
    ]
    # add deprecated ipyparallel control messages
    control_msg_types = msg_types + [
        'clear_request', 'abort_request', 'debug_request'
    ]

    def __init__(self, **kwargs):
        super(Kernel, self).__init__(**kwargs)
        # Build dict of handlers for message types
        self.shell_handlers = {}
        for msg_type in self.msg_types:
            self.shell_handlers[msg_type] = getattr(self, msg_type)

        self.control_handlers = {}
        for msg_type in self.control_msg_types:
            self.control_handlers[msg_type] = getattr(self, msg_type)

        self.control_queue = Queue()

    def dispatch_control(self, msg):
        self.control_queue.put_nowait(msg)

    async def poll_control_queue(self):
        while True:
            msg = await self.control_queue.get()
            # handle tracers from _flush_control_queue
            if isinstance(msg, (concurrent.futures.Future, asyncio.Future)):
                msg.set_result(None)
                continue
            await self.process_control(msg)

    async def _flush_control_queue(self):
        """Flush the control queue, wait for processing of any pending messages"""
        if self.control_thread:
            control_loop = self.control_thread.io_loop
            # concurrent.futures.Futures are threadsafe
            # and can be used to await across threads
            tracer_future = concurrent.futures.Future()
            awaitable_future = asyncio.wrap_future(tracer_future)
        else:
            control_loop = self.io_loop
            tracer_future = awaitable_future = asyncio.Future()

        def _flush():
            # control_stream.flush puts messages on the queue
            self.control_stream.flush()
            # put Future on the queue after all of those,
            # so we can wait for all queued messages to be processed
            self.control_queue.put(tracer_future)

        control_loop.add_callback(_flush)
        # await the tracer here: returning the Future from an ``async def``
        # would hand it back un-awaited, so callers would not actually wait
        return await awaitable_future
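
    # Tracer sketch (illustrative): _flush enqueues the tracer Future *behind*
    # everything flushed off the control stream, so
    #
    #     await self._flush_control_queue()
    #
    # resolves only after all previously queued control messages have been
    # handled by poll_control_queue.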

    async def process_control(self, msg):
        """dispatch control requests"""
        idents, msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.deserialize(msg, content=True, copy=False)
        except Exception:
            self.log.error("Invalid Control Message", exc_info=True)
            return

        self.log.debug("Control received: %s", msg)

        # Set the parent message for side effects.
        self.set_parent(idents, msg, channel='control')
        self._publish_status('busy', 'control')

        header = msg['header']
        msg_type = header['msg_type']

        handler = self.control_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r", msg_type)
        else:
            try:
                result = handler(self.control_stream, idents, msg)
                if inspect.isawaitable(result):
                    await result
            except Exception:
                self.log.error("Exception in control handler:", exc_info=True)

        sys.stdout.flush()
        sys.stderr.flush()
        self._publish_status('idle', 'control')
        # flush to ensure reply is sent
        self.control_stream.flush(zmq.POLLOUT)

    def should_handle(self, stream, msg, idents):
        """Check whether a shell-channel message should be handled

        Allows subclasses to prevent handling of certain messages (e.g. aborted requests).
        """
        msg_id = msg['header']['msg_id']
        if msg_id in self.aborted:
            # is it safe to assume a msg_id will not be resubmitted?
            self.aborted.remove(msg_id)
            self._send_abort_reply(stream, msg, idents)
            return False
        return True

    async def dispatch_shell(self, msg):
        """dispatch shell requests"""

        # flush control queue before handling shell requests
        await self._flush_control_queue()

        idents, msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.deserialize(msg, content=True, copy=False)
        except Exception:
            self.log.error("Invalid Message", exc_info=True)
            return

        # Set the parent message for side effects.
        self.set_parent(idents, msg, channel='shell')
        self._publish_status('busy', 'shell')

        msg_type = msg['header']['msg_type']

        # Only abort execute requests
        if self._aborting and msg_type == 'execute_request':
            self._send_abort_reply(self.shell_stream, msg, idents)
            self._publish_status('idle', 'shell')
            # flush to ensure reply is sent before
            # handling the next request
            self.shell_stream.flush(zmq.POLLOUT)
            return

        # Print some info about this message and leave a '--->' marker, so it's
        # easier to trace visually the message chain when debugging.  Each
        # handler prints its message at the end.
        self.log.debug('\n*** MESSAGE TYPE:%s***', msg_type)
        self.log.debug('   Content: %s\n   --->\n   ', msg['content'])

        if not self.should_handle(self.shell_stream, msg, idents):
            return

        handler = self.shell_handlers.get(msg_type, None)
        if handler is None:
            self.log.warning("Unknown message type: %r", msg_type)
        else:
            self.log.debug("%s: %s", msg_type, msg)
            try:
                self.pre_handler_hook()
            except Exception:
                self.log.debug("Unable to signal in pre_handler_hook:",
                               exc_info=True)
            try:
                result = handler(self.shell_stream, idents, msg)
                if inspect.isawaitable(result):
                    await result
            except Exception:
                self.log.error("Exception in message handler:", exc_info=True)
            except KeyboardInterrupt:
                # Ctrl-c shouldn't crash the kernel here.
                self.log.error("KeyboardInterrupt caught in kernel.")
            finally:
                try:
                    self.post_handler_hook()
                except Exception:
                    self.log.debug("Unable to signal in post_handler_hook:",
                                   exc_info=True)

        sys.stdout.flush()
        sys.stderr.flush()
        self._publish_status('idle', 'shell')
        # flush to ensure reply is sent before
        # handling the next request
        self.shell_stream.flush(zmq.POLLOUT)

    def pre_handler_hook(self):
        """Hook to execute before calling message handler"""
        # ensure default_int_handler during handler call
        self.saved_sigint_handler = signal(SIGINT, default_int_handler)

    def post_handler_hook(self):
        """Hook to execute after calling message handler"""
        signal(SIGINT, self.saved_sigint_handler)

    def enter_eventloop(self):
        """enter eventloop"""
        self.log.info("Entering eventloop %s", self.eventloop)
        # record handle, so we can check when this changes
        eventloop = self.eventloop
        if eventloop is None:
            self.log.info("Exiting as there is no eventloop")
            return

        def advance_eventloop():
            # check if eventloop changed:
            if self.eventloop is not eventloop:
                self.log.info("exiting eventloop %s", eventloop)
                return
            if self.msg_queue.qsize():
                self.log.debug("Delaying eventloop due to waiting messages")
                # still messages to process, make the eventloop wait
                schedule_next()
                return
            self.log.debug("Advancing eventloop %s", eventloop)
            try:
                eventloop(self)
            except KeyboardInterrupt:
                # Ctrl-C shouldn't crash the kernel
                self.log.error("KeyboardInterrupt caught in kernel")
            if self.eventloop is eventloop:
                # schedule advance again
                schedule_next()

        def schedule_next():
            """Schedule the next advance of the eventloop"""
            # flush the eventloop every so often,
            # giving us a chance to handle messages in the meantime
            self.log.debug("Scheduling eventloop advance")
            self.io_loop.call_later(0.001, advance_eventloop)

        # begin polling the eventloop
        schedule_next()

    async def do_one_iteration(self):
        """Process a single shell message

        Any pending control messages will be flushed as well

        .. versionchanged:: 5
            This is now a coroutine
        """
        # flush messages off of shell stream into the message queue
        self.shell_stream.flush()
        # process at most one shell message per iteration
        await self.process_one(wait=False)

    async def process_one(self, wait=True):
        """Process one request

        Returns None if no message was handled.
        """
        if wait:
            t, dispatch, args = await self.msg_queue.get()
        else:
            try:
                t, dispatch, args = self.msg_queue.get_nowait()
            except asyncio.QueueEmpty:
                return None
        await dispatch(*args)

    async def dispatch_queue(self):
        """Coroutine to preserve order of message handling

        Ensures that only one message is processing at a time,
        even when the handler is async
        """

        while True:
            try:
                await self.process_one()
            except Exception:
                self.log.exception("Error in message handler")

    _message_counter = Any(help="Monotonic counter of messages")

    @default('_message_counter')
    def _message_counter_default(self):
        return itertools.count()

    def schedule_dispatch(self, dispatch, *args):
        """schedule a message for dispatch"""
        idx = next(self._message_counter)

        self.msg_queue.put_nowait((
            idx,
            dispatch,
            args,
        ))
        # ensure the eventloop wakes up
        self.io_loop.add_callback(lambda: None)

    def start(self):
        """register dispatchers for streams"""
        self.io_loop = ioloop.IOLoop.current()
        self.msg_queue = Queue()
        self.io_loop.add_callback(self.dispatch_queue)

        self.control_stream.on_recv(self.dispatch_control, copy=False)

        if self.control_thread:
            control_loop = self.control_thread.io_loop
        else:
            control_loop = self.io_loop

        asyncio.run_coroutine_threadsafe(self.poll_control_queue(),
                                         control_loop.asyncio_loop)

        self.shell_stream.on_recv(
            partial(
                self.schedule_dispatch,
                self.dispatch_shell,
            ),
            copy=False,
        )

        # publish starting status
        self._publish_status('starting', 'shell')

    def record_ports(self, ports):
        """Record the ports that this kernel is using.

        The creator of the Kernel instance must call this method if they
        want the :meth:`connect_request` method to return the port numbers.
        """
        self._recorded_ports = ports

    #---------------------------------------------------------------------------
    # Kernel request handlers
    #---------------------------------------------------------------------------

    def _publish_execute_input(self, code, parent, execution_count):
        """Publish the code request on the iopub stream."""

        self.session.send(self.iopub_socket,
                          'execute_input', {
                              'code': code,
                              'execution_count': execution_count
                          },
                          parent=parent,
                          ident=self._topic('execute_input'))

    def _publish_status(self, status, channel, parent=None):
        """send status (busy/idle) on IOPub"""
        self.session.send(
            self.iopub_socket,
            "status",
            {"execution_state": status},
            parent=parent or self.get_parent(channel),
            ident=self._topic("status"),
        )

    def _publish_debug_event(self, event):
        self.session.send(
            self.iopub_socket,
            "debug_event",
            event,
            parent=self.get_parent("control"),
            ident=self._topic("debug_event"),
        )

    def set_parent(self, ident, parent, channel='shell'):
        """Set the current parent request

        Side effects (IOPub messages) and replies are associated with
        the request that caused them via the parent_header.

        The parent identity is used to route input_request messages
        on the stdin channel.
        """
        self._parent_ident[channel] = ident
        self._parents[channel] = parent

    def get_parent(self, channel="shell"):
        """Get the parent request associated with a channel.

        .. versionadded:: 6

        Parameters
        ----------
        channel : str
            the name of the channel ('shell' or 'control')

        Returns
        -------
        message : dict
            the parent message for the most recent request on the channel.
        """
        return self._parents.get(channel, {})

    def send_response(self,
                      stream,
                      msg_or_type,
                      content=None,
                      ident=None,
                      buffers=None,
                      track=False,
                      header=None,
                      metadata=None,
                      channel='shell'):
        """Send a response to the message we're currently processing.

        This accepts all the parameters of :meth:`jupyter_client.session.Session.send`
        except ``parent``.

        This relies on :meth:`set_parent` having been called for the current
        message.
        """
        return self.session.send(
            stream,
            msg_or_type,
            content,
            self.get_parent(channel),
            ident,
            buffers,
            track,
            header,
            metadata,
        )
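
    # Usage sketch (hypothetical handler code, not from the original listing):
    #
    #     self.send_response(self.shell_stream, 'stream',
    #                        {'name': 'stdout', 'text': 'hello\n'})
    #
    # The parent header comes from get_parent(channel), so the reply is tied
    # to the request most recently registered via set_parent().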

    def init_metadata(self, parent):
        """Initialize metadata.

        Run at the beginning of execution requests.
        """
        # FIXME: `started` is part of ipyparallel
        # Remove for yapkernel 5.0
        return {
            'started': now(),
        }

    def finish_metadata(self, parent, metadata, reply_content):
        """Finish populating metadata.

        Run after completing an execution request.
        """
        return metadata

    async def execute_request(self, stream, ident, parent):
        """handle an execute_request"""

        try:
            content = parent['content']
            code = content['code']
            silent = content['silent']
            store_history = content.get('store_history', not silent)
            user_expressions = content.get('user_expressions', {})
            allow_stdin = content.get('allow_stdin', False)
        except Exception:
            self.log.error("Got bad msg: ")
            self.log.error("%s", parent)
            return

        stop_on_error = content.get('stop_on_error', True)

        metadata = self.init_metadata(parent)

        # Re-broadcast our input for the benefit of listening clients, and
        # start computing output
        if not silent:
            self.execution_count += 1
            self._publish_execute_input(code, parent, self.execution_count)

        reply_content = self.do_execute(
            code,
            silent,
            store_history,
            user_expressions,
            allow_stdin,
        )
        if inspect.isawaitable(reply_content):
            reply_content = await reply_content

        # Flush output before sending the reply.
        sys.stdout.flush()
        sys.stderr.flush()
        # FIXME: on rare occasions, the flush doesn't seem to make it to the
        # clients... This seems to mitigate the problem, but we definitely need
        # to better understand what's going on.
        if self._execute_sleep:
            time.sleep(self._execute_sleep)

        # Send the reply.
        reply_content = json_clean(reply_content)
        metadata = self.finish_metadata(parent, metadata, reply_content)

        reply_msg = self.session.send(stream,
                                      'execute_reply',
                                      reply_content,
                                      parent,
                                      metadata=metadata,
                                      ident=ident)

        self.log.debug("%s", reply_msg)

        if (not silent and stop_on_error
                and reply_msg['content']['status'] == 'error'):
            await self._abort_queues()

    def do_execute(self,
                   code,
                   silent,
                   store_history=True,
                   user_expressions=None,
                   allow_stdin=False):
        """Execute user code. Must be overridden by subclasses.
        """
        raise NotImplementedError

    async def complete_request(self, stream, ident, parent):
        content = parent['content']
        code = content['code']
        cursor_pos = content['cursor_pos']

        matches = self.do_complete(code, cursor_pos)
        if inspect.isawaitable(matches):
            matches = await matches

        matches = json_clean(matches)
        self.session.send(stream, "complete_reply", matches, parent, ident)

    def do_complete(self, code, cursor_pos):
        """Override in subclasses to find completions.
        """
        return {
            'matches': [],
            'cursor_end': cursor_pos,
            'cursor_start': cursor_pos,
            'metadata': {},
            'status': 'ok'
        }

    async def inspect_request(self, stream, ident, parent):
        content = parent['content']

        reply_content = self.do_inspect(
            content['code'],
            content['cursor_pos'],
            content.get('detail_level', 0),
        )
        if inspect.isawaitable(reply_content):
            reply_content = await reply_content

        # Before we send this object over, we scrub it for JSON usage
        reply_content = json_clean(reply_content)
        msg = self.session.send(stream, 'inspect_reply', reply_content, parent,
                                ident)
        self.log.debug("%s", msg)

    def do_inspect(self, code, cursor_pos, detail_level=0):
        """Override in subclasses to allow introspection.
        """
        return {'status': 'ok', 'data': {}, 'metadata': {}, 'found': False}

    async def history_request(self, stream, ident, parent):
        content = parent['content']

        reply_content = self.do_history(**content)
        if inspect.isawaitable(reply_content):
            reply_content = await reply_content

        reply_content = json_clean(reply_content)
        msg = self.session.send(stream, 'history_reply', reply_content, parent,
                                ident)
        self.log.debug("%s", msg)

    def do_history(self,
                   hist_access_type,
                   output,
                   raw,
                   session=None,
                   start=None,
                   stop=None,
                   n=None,
                   pattern=None,
                   unique=False):
        """Override in subclasses to access history.
        """
        return {'status': 'ok', 'history': []}

    async def connect_request(self, stream, ident, parent):
        if self._recorded_ports is not None:
            content = self._recorded_ports.copy()
        else:
            content = {}
        content['status'] = 'ok'
        msg = self.session.send(stream, 'connect_reply', content, parent,
                                ident)
        self.log.debug("%s", msg)

    @property
    def kernel_info(self):
        return {
            'protocol_version': kernel_protocol_version,
            'implementation': self.implementation,
            'implementation_version': self.implementation_version,
            'language_info': self.language_info,
            'banner': self.banner,
            'help_links': self.help_links,
        }

    async def kernel_info_request(self, stream, ident, parent):
        content = {'status': 'ok'}
        content.update(self.kernel_info)
        msg = self.session.send(stream, 'kernel_info_reply', content, parent,
                                ident)
        self.log.debug("%s", msg)

    async def comm_info_request(self, stream, ident, parent):
        content = parent['content']
        target_name = content.get('target_name', None)

        # Should this be moved to ipkernel?
        if hasattr(self, 'comm_manager'):
            comms = {
                k: dict(target_name=v.target_name)
                for (k, v) in self.comm_manager.comms.items()
                if v.target_name == target_name or target_name is None
            }
        else:
            comms = {}
        reply_content = dict(comms=comms, status='ok')
        msg = self.session.send(stream, 'comm_info_reply', reply_content,
                                parent, ident)
        self.log.debug("%s", msg)

    async def interrupt_request(self, stream, ident, parent):
        pid = os.getpid()

        if os.name == "nt":
            self.log.error("Interrupt message not supported on Windows")

        else:
            # os.getpgid is POSIX-only, so look it up off the Windows path
            pgid = os.getpgid(pid)
            # Prefer process-group over process
            if pgid and hasattr(os, "killpg"):
                try:
                    os.killpg(pgid, SIGINT)
                    return
                except OSError:
                    pass
            try:
                os.kill(pid, SIGINT)
            except OSError:
                pass

        content = parent['content']
        self.session.send(stream,
                          'interrupt_reply',
                          content,
                          parent,
                          ident=ident)
        return

    async def shutdown_request(self, stream, ident, parent):
        content = self.do_shutdown(parent['content']['restart'])
        if inspect.isawaitable(content):
            content = await content
        self.session.send(stream,
                          'shutdown_reply',
                          content,
                          parent,
                          ident=ident)
        # same content, but different msg_id for broadcasting on IOPub
        self._shutdown_message = self.session.msg('shutdown_reply', content,
                                                  parent)

        self._at_shutdown()

        self.log.debug('Stopping control ioloop')
        control_io_loop = self.control_stream.io_loop
        control_io_loop.add_callback(control_io_loop.stop)

        self.log.debug('Stopping shell ioloop')
        shell_io_loop = self.shell_stream.io_loop
        shell_io_loop.add_callback(shell_io_loop.stop)

    def do_shutdown(self, restart):
        """Override in subclasses to do things when the frontend shuts down the
        kernel.
        """
        return {'status': 'ok', 'restart': restart}

    async def is_complete_request(self, stream, ident, parent):
        content = parent['content']
        code = content['code']

        reply_content = self.do_is_complete(code)
        if inspect.isawaitable(reply_content):
            reply_content = await reply_content
        reply_content = json_clean(reply_content)
        reply_msg = self.session.send(stream, 'is_complete_reply',
                                      reply_content, parent, ident)
        self.log.debug("%s", reply_msg)

    def do_is_complete(self, code):
        """Override in subclasses to find completions.
        """
        return {'status': 'unknown'}

    async def debug_request(self, stream, ident, parent):
        content = parent['content']

        reply_content = self.do_debug_request(content)
        if inspect.isawaitable(reply_content):
            reply_content = await reply_content
        reply_content = json_clean(reply_content)
        reply_msg = self.session.send(stream, 'debug_reply', reply_content,
                                      parent, ident)
        self.log.debug("%s", reply_msg)

    async def do_debug_request(self, msg):
        raise NotImplementedError

    #---------------------------------------------------------------------------
    # Engine methods (DEPRECATED)
    #---------------------------------------------------------------------------

    async def apply_request(self, stream, ident, parent):
        self.log.warning(
            "apply_request is deprecated in kernel_base, moving to ipyparallel."
        )
        try:
            content = parent['content']
            bufs = parent['buffers']
            msg_id = parent['header']['msg_id']
        except Exception:
            self.log.error("Got bad msg: %s", parent, exc_info=True)
            return

        md = self.init_metadata(parent)

        reply_content, result_buf = self.do_apply(content, bufs, msg_id, md)

        # flush i/o
        sys.stdout.flush()
        sys.stderr.flush()

        md = self.finish_metadata(parent, md, reply_content)

        self.session.send(stream,
                          'apply_reply',
                          reply_content,
                          parent=parent,
                          ident=ident,
                          buffers=result_buf,
                          metadata=md)

    def do_apply(self, content, bufs, msg_id, reply_metadata):
        """DEPRECATED"""
        raise NotImplementedError

    #---------------------------------------------------------------------------
    # Control messages (DEPRECATED)
    #---------------------------------------------------------------------------

    async def abort_request(self, stream, ident, parent):
        """abort a specific msg by id"""
        self.log.warning(
            "abort_request is deprecated in kernel_base. It is only part of IPython parallel"
        )
        msg_ids = parent['content'].get('msg_ids', None)
        if isinstance(msg_ids, str):
            msg_ids = [msg_ids]
        if not msg_ids:
            # _abort_queues is a coroutine, so it must be awaited
            await self._abort_queues()
        for mid in msg_ids:
            self.aborted.add(str(mid))

        content = dict(status='ok')
        reply_msg = self.session.send(stream,
                                      'abort_reply',
                                      content=content,
                                      parent=parent,
                                      ident=ident)
        self.log.debug("%s", reply_msg)

    async def clear_request(self, stream, idents, parent):
        """Clear our namespace."""
        self.log.warning(
            "clear_request is deprecated in kernel_base. It is only part of IPython parallel"
        )
        content = self.do_clear()
        self.session.send(stream,
                          'clear_reply',
                          ident=idents,
                          parent=parent,
                          content=content)

    def do_clear(self):
        """DEPRECATED since 4.0.3"""
        raise NotImplementedError

    #---------------------------------------------------------------------------
    # Protected interface
    #---------------------------------------------------------------------------

    def _topic(self, topic):
        """prefixed topic for IOPub messages"""
        base = "kernel.%s" % self.ident

        return ("%s.%s" % (base, topic)).encode()

    _aborting = Bool(False)

    async def _abort_queues(self):
        self.shell_stream.flush()
        self._aborting = True

        def stop_aborting():
            self.log.info("Finishing abort")
            self._aborting = False

        asyncio.get_running_loop().call_later(self.stop_on_error_timeout,
                                              stop_aborting)

    def _send_abort_reply(self, stream, msg, idents):
        """Send a reply to an aborted request"""
        self.log.info(
            f"Aborting {msg['header']['msg_id']}: {msg['header']['msg_type']}")
        reply_type = msg["header"]["msg_type"].rsplit("_", 1)[0] + "_reply"
        status = {"status": "aborted"}
        md = self.init_metadata(msg)
        md = self.finish_metadata(msg, md, status)
        md.update(status)

        self.session.send(
            stream,
            reply_type,
            metadata=md,
            content=status,
            parent=msg,
            ident=idents,
        )

    def _no_raw_input(self):
        """Raise StdinNotImplementedError if active frontend doesn't support
        stdin."""
        raise StdinNotImplementedError("raw_input was called, but this "
                                       "frontend does not support stdin.")

    def getpass(self, prompt='', stream=None):
        """Forward getpass to frontends

        Raises
        ------
        StdinNotImplementedError if active frontend doesn't support stdin.
        """
        if not self._allow_stdin:
            raise StdinNotImplementedError(
                "getpass was called, but this frontend does not support input requests."
            )
        if stream is not None:
            import warnings

            warnings.warn(
                "The `stream` parameter of `getpass.getpass` will have no effect when using yapkernel",
                UserWarning,
                stacklevel=2,
            )
        return self._input_request(
            prompt,
            self._parent_ident["shell"],
            self.get_parent("shell"),
            password=True,
        )

    def raw_input(self, prompt=''):
        """Forward raw_input to frontends

        Raises
        ------
        StdinNotImplementedError if active frontend doesn't support stdin.
        """
        if not self._allow_stdin:
            raise StdinNotImplementedError(
                "raw_input was called, but this frontend does not support input requests."
            )
        return self._input_request(
            str(prompt),
            self._parent_ident["shell"],
            self.get_parent("shell"),
            password=False,
        )

    def _input_request(self, prompt, ident, parent, password=False):
        # Flush output before making the request.
        sys.stderr.flush()
        sys.stdout.flush()

        # flush the stdin socket, to purge stale replies
        while True:
            try:
                self.stdin_socket.recv_multipart(zmq.NOBLOCK)
            except zmq.ZMQError as e:
                if e.errno == zmq.EAGAIN:
                    break
                else:
                    raise

        # Send the input request.
        content = json_clean(dict(prompt=prompt, password=password))
        self.session.send(self.stdin_socket,
                          'input_request',
                          content,
                          parent,
                          ident=ident)

        # Await a response.
        while True:
            try:
                # Use polling with select() so KeyboardInterrupts can get
                # through; doing a blocking recv() means stdin reads are
                # uninterruptible on Windows. We need a timeout because
                # zmq.select() is also uninterruptible, but at least this
                # way reads get noticed immediately and KeyboardInterrupts
                # get noticed fairly quickly by human response time standards.
                rlist, _, xlist = zmq.select([self.stdin_socket], [],
                                             [self.stdin_socket], 0.01)
                if rlist or xlist:
                    ident, reply = self.session.recv(self.stdin_socket)
                    if (ident, reply) != (None, None):
                        break
            except KeyboardInterrupt:
                # re-raise KeyboardInterrupt, to truncate traceback
                raise KeyboardInterrupt("Interrupted by user") from None
            except Exception:
                self.log.warning("Invalid Message:", exc_info=True)

        try:
            value = reply["content"]["value"]
        except Exception:
            self.log.error("Bad input_reply: %s", parent)
            value = ''
        if value == '\x04':
            # EOF
            raise EOFError
        return value

    def _at_shutdown(self):
        """Actions taken at shutdown by the kernel, called by python's atexit.
        """
        if self._shutdown_message is not None:
            self.session.send(self.iopub_socket,
                              self._shutdown_message,
                              ident=self._topic('shutdown'))
            self.log.debug("%s", self._shutdown_message)
        self.control_stream.flush(zmq.POLLOUT)
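
A minimal concrete subclass sketch (illustrative, not part of the original listing): the base class dispatches execute_request and delegates to do_execute, so a working kernel only needs that handler plus the identification attributes read by kernel_info. EchoKernel and its attribute values are hypothetical.

class EchoKernel(Kernel):
    # Hypothetical echo kernel: streams the submitted code straight back on IOPub.
    implementation = 'echo'
    implementation_version = '0.1'
    banner = 'Echo kernel'

    def do_execute(self, code, silent, store_history=True,
                   user_expressions=None, allow_stdin=False):
        if not silent:
            # tie the output to the current request via the parent header
            self.session.send(self.iopub_socket, 'stream',
                              {'name': 'stdout', 'text': code},
                              parent=self.get_parent('shell'),
                              ident=self._topic('stream'))
        return {'status': 'ok', 'execution_count': self.execution_count,
                'payload': [], 'user_expressions': {}}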
Exemple #21
0
class GlobusOAuthenticator(OAuthenticator):
    """The Globus OAuthenticator handles both authorization and passing
    transfer tokens to the spawner."""

    login_service = 'Globus'
    logout_handler = GlobusLogoutHandler

    @default("userdata_url")
    def _userdata_url_default(self):
        return "https://auth.globus.org/v2/oauth2/userinfo"

    @default("authorize_url")
    def _authorize_url_default(self):
        return "https://auth.globus.org/v2/oauth2/authorize"

    @default("revocation_url")
    def _revocation_url_default(self):
        return "https://auth.globus.org/v2/oauth2/token/revoke"

    revocation_url = Unicode(help="Globus URL to revoke live tokens.").tag(config=True)

    @default("token_url")
    def _token_url_default(self):
        return "https://auth.globus.org/v2/oauth2/token"

    globus_groups_url = Unicode(help="Globus URL to get list of user's Groups.").tag(
        config=True
    )

    @default("globus_groups_url")
    def _globus_groups_url_default(self):
        return "https://groups.api.globus.org/v2/groups/my_groups"

    identity_provider = Unicode(
        help="""Restrict which institution a user
    can use to log in (GlobusID, University of Hogwarts, etc.). This should
    be set in the app at developers.globus.org, but this acts as an additional
    check to prevent unnecessary account creation."""
    ).tag(config=True)

    def _identity_provider_default(self):
        return os.getenv('IDENTITY_PROVIDER', '')

    username_from_email = Bool(
        help="""Create username from email address, not preferred username. If
        an identity provider is specified, email address must be from the same
        domain. Email scope will be set automatically."""
    ).tag(config=True)

    @default("username_from_email")
    def _username_from_email_default(self):
        return False

    exclude_tokens = List(
        help="""Exclude tokens from being passed into user environments
        when they start notebooks, Terminals, etc."""
    ).tag(config=True)

    def _exclude_tokens_default(self):
        return ['auth.globus.org', 'groups.api.globus.org']

    def _scope_default(self):
        scopes = [
            'openid',
            'profile',
            'urn:globus:auth:scope:transfer.api.globus.org:all',
        ]
        if self.allowed_globus_groups or self.admin_globus_groups:
            scopes.append(
                'urn:globus:auth:scope:groups.api.globus.org:view_my_groups_and_memberships'
            )
        if self.username_from_email:
            scopes.append('email')
        return scopes

    globus_local_endpoint = Unicode(
        help="""If Jupyterhub is also a Globus
    endpoint, its endpoint id can be specified here."""
    ).tag(config=True)

    def _globus_local_endpoint_default(self):
        return os.getenv('GLOBUS_LOCAL_ENDPOINT', '')

    revoke_tokens_on_logout = Bool(
        help="""Revoke tokens so they cannot be used again. Single-user servers
        MUST be restarted after logout in order to get a fresh working set of
        tokens."""
    ).tag(config=True)

    def _revoke_tokens_on_logout_default(self):
        return False

    allowed_globus_groups = Set(
        help="""Allow members of defined Globus Groups to access JupyterHub. Users in an
        admin Globus Group are also automatically allowed. Groups are specified with their UUIDs. Setting this will
        add the Globus Groups scope."""
    ).tag(config=True)

    admin_globus_groups = Set(
        help="""Set members of defined Globus Groups as JupyterHub admin users.
        These users are automatically allowed to login to JupyterHub. Groups are specified with
        their UUIDs. Setting this will add the Globus Groups scope."""
    ).tag(config=True)

    async def pre_spawn_start(self, user, spawner):
        """Add tokens to the spawner whenever the spawner starts a notebook.
        This will allow users to create a transfer client:
        globus-sdk-python.readthedocs.io/en/stable/tutorial/#tutorial-step4
        """
        spawner.environment['GLOBUS_LOCAL_ENDPOINT'] = self.globus_local_endpoint
        state = await user.get_auth_state()
        if state:
            globus_data = base64.b64encode(pickle.dumps(state))
            spawner.environment['GLOBUS_DATA'] = globus_data.decode('utf-8')

    async def authenticate(self, handler, data=None):
        """
        Authenticate with globus.org. Usernames (and therefore Jupyterhub
        accounts) will correspond to a Globus User ID, so foouser@globusid.org
        will have the 'foouser' account in Jupyterhub.
        """
        # Complete login and exchange the code for tokens.
        params = dict(
            redirect_uri=self.get_callback_url(handler),
            code=handler.get_argument("code"),
            grant_type='authorization_code',
        )
        req = HTTPRequest(
            self.token_url,
            method="POST",
            headers=self.get_client_credential_headers(),
            body=urllib.parse.urlencode(params),
        )
        token_json = await self.fetch(req)

        # Fetch user info at Globus's oauth2/userinfo/ HTTP endpoint to get the username
        user_headers = self.get_default_headers()
        user_headers['Authorization'] = 'Bearer {}'.format(token_json['access_token'])
        req = HTTPRequest(self.userdata_url, method='GET', headers=user_headers)
        user_resp = await self.fetch(req)
        username = self.get_username(user_resp)

        # Each token should have these attributes. Resource server is optional,
        # and likely won't be present.
        token_attrs = [
            'expires_in',
            'resource_server',
            'scope',
            'token_type',
            'refresh_token',
            'access_token',
        ]
        # The Auth Token is a bit special, it comes back at the top level with the
        # id token. The id token has some useful information in it, but nothing that
        # can't be retrieved with an Auth token.
        # Repackage the Auth token into a dict that looks like the other tokens
        auth_token_dict = {
            attr_name: token_json.get(attr_name) for attr_name in token_attrs
        }
        # Make sure only the essentials make it into tokens. Other items, such as 'state' are
        # not needed after authentication and can be discarded.
        other_tokens = [
            {attr_name: token_dict.get(attr_name) for attr_name in token_attrs}
            for token_dict in token_json['other_tokens']
        ]
        tokens = other_tokens + [auth_token_dict]
        # historically, tokens have been organized by resource server for convenience.
        # If multiple scopes are requested from the same resource server, they will be
        # combined into a single token from Globus Auth.
        by_resource_server = {
            token_dict['resource_server']: token_dict
            for token_dict in tokens
            if token_dict['resource_server'] not in self.exclude_tokens
        }
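        # Illustrative shape (hypothetical values, not from the original
        # listing):
        # by_resource_server == {
        #     'transfer.api.globus.org': {
        #         'access_token': '...', 'refresh_token': None,
        #         'expires_in': 172800, 'token_type': 'Bearer',
        #         'scope': '...', 'resource_server': 'transfer.api.globus.org',
        #     },
        # }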

        user_info = {
            'name': username,
            'auth_state': {
                'client_id': self.client_id,
                'tokens': by_resource_server,
            },
        }
        use_globus_groups = False
        user_allowed = False
        if self.allowed_globus_groups or self.admin_globus_groups:
            # If any of these configurations are set, user must be in the allowed or admin Globus Group
            use_globus_groups = True
            user_group_ids = set()
            # Get Groups access token, may not be in dict headed to auth state
            for token_dict in tokens:
                if token_dict['resource_server'] == 'groups.api.globus.org':
                    groups_token = token_dict['access_token']
            # Get list of user's Groups
            groups_headers = self.get_default_headers()
            groups_headers['Authorization'] = 'Bearer {}'.format(groups_token)
            req = HTTPRequest(
                self.globus_groups_url, method='GET', headers=groups_headers
            )
            groups_resp = await self.fetch(req)
            # Build set of Group IDs
            for group in groups_resp:
                user_group_ids.add(group['id'])
            if user_group_ids & self.allowed_globus_groups:
                user_allowed = True
            if self.admin_globus_groups:
                # Admin users are being managed via Globus Groups
                # Default to False
                user_info['admin'] = False
                if user_group_ids & self.admin_globus_groups:
                    # User is an admin, admins allowed by default
                    user_allowed = user_info['admin'] = True

        if user_allowed or not use_globus_groups:
            return user_info
        else:
            self.log.warning('{} not in an allowed Globus Group'.format(username))
            return None

    def get_username(self, user_data):
        # It's possible for identity provider domains to be namespaced
        # https://docs.globus.org/api/auth/specification/#identity_provider_namespaces # noqa
        username_field = 'preferred_username'
        if self.username_from_email:
            username_field = 'email'
        username, domain = user_data.get(username_field).split('@', 1)
        if self.identity_provider and domain != self.identity_provider:
            raise HTTPError(
                403,
                'This site is restricted to {} accounts. Please link your {}'
                ' account at {}.'.format(
                    self.identity_provider,
                    self.identity_provider,
                    'globus.org/app/account',
                ),
            )
        return username

    def get_default_headers(self):
        return {"Accept": "application/json", "User-Agent": "JupyterHub"}

    def get_client_credential_headers(self):
        headers = self.get_default_headers()
        b64key = base64.b64encode(
            bytes("{}:{}".format(self.client_id, self.client_secret), "utf8")
        )
        headers["Authorization"] = "Basic {}".format(b64key.decode("utf8"))
        return headers

    async def revoke_service_tokens(self, services):
        """Revoke live Globus access and refresh tokens. Revoking inert or
        non-existent tokens does nothing. Services are defined by dicts
        returned by tokens.by_resource_server, for example:
        services = { 'transfer.api.globus.org': {'access_token': 'token'}, ...
            <Additional services>...
        }
        """
        access_tokens = [
            token_dict.get('access_token') for token_dict in services.values()
        ]
        refresh_tokens = [
            token_dict.get('refresh_token') for token_dict in services.values()
        ]
        all_tokens = [tok for tok in access_tokens + refresh_tokens if tok is not None]

        for token in all_tokens:
            req = HTTPRequest(
                self.revocation_url,
                method="POST",
                headers=self.get_client_credential_headers(),
                body=urllib.parse.urlencode({'token': token}),
            )
            await self.fetch(req)
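
A hedged configuration sketch (hypothetical values, not part of the original listing; it assumes the standard OAuthenticator traits for the client credentials and callback URL):

# jupyterhub_config.py
c.JupyterHub.authenticator_class = GlobusOAuthenticator
c.GlobusOAuthenticator.oauth_callback_url = 'https://hub.example.org/hub/oauth_callback'
c.GlobusOAuthenticator.client_id = '<globus-app-client-id>'
c.GlobusOAuthenticator.client_secret = '<globus-app-client-secret>'
# Restricting access to Globus Group members (UUIDs) also adds the
# Groups scope; see _scope_default above.
c.GlobusOAuthenticator.allowed_globus_groups = {'00000000-0000-0000-0000-000000000000'}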
Exemple #22
0
class BinderHub(Application):
    """An Application for starting a builder."""
    @default('log_level')
    def _log_level(self):
        return logging.INFO

    aliases = {
        'log-level': 'Application.log_level',
        'f': 'BinderHub.config_file',
        'config': 'BinderHub.config_file',
        'port': 'BinderHub.port',
    }

    flags = {
        'debug': ({
            'BinderHub': {
                'debug': True
            }
        }, "Enable debug HTTP serving & debug logging")
    }

    config_file = Unicode('binderhub_config.py',
                          help="""
        Config file to load.

        If a relative path is provided, it is taken relative to the current directory
        """,
                          config=True)

    google_analytics_code = Unicode(None,
                                    allow_none=True,
                                    help="""
        The Google Analytics code to use on the main page.

        Note that we'll respect Do Not Track settings, despite the fact that GA does not.
        We will not load the GA scripts on browsers with DNT enabled.
        """,
                                    config=True)

    google_analytics_domain = Unicode('auto',
                                      help="""
        The Google Analytics domain to use on the main page.

        By default this is set to 'auto', which sets it up for the current domain
        and all subdomains. This can be set to a more restrictive domain here for better privacy
        """,
                                      config=True)

    about_message = Unicode('',
                            help="""
        Additional message to display on the about page.

        Will be directly inserted into the about page's source so you can use
        raw HTML.
        """,
                            config=True)

    banner_message = Unicode('',
                             help="""
        Message to display in a banner on all pages.

        The value will be inserted "as is" into an HTML <div> element
        with grey background, located at the top of the BinderHub pages. Raw
        HTML is supported.
        """,
                             config=True)

    extra_footer_scripts = Dict({},
                                help="""
        Extra bits of JavaScript that should be loaded in footer of each page.

        Only the values are set up as scripts. Keys are used only
        for sorting.

        Omit the <script> tag. This should be primarily used for
        analytics code.
        """,
                                config=True)

    base_url = Unicode('/',
                       help="The base URL of the entire application",
                       config=True)

    @validate('base_url')
    def _valid_base_url(self, proposal):
        if not proposal.value.startswith('/'):
            proposal.value = '/' + proposal.value
        if not proposal.value.endswith('/'):
            proposal.value = proposal.value + '/'
        return proposal.value

    badge_base_url = Union(trait_types=[Unicode(), Callable()],
                           help="""
        Base URL to use when generating launch badges.
        Can also be a function that is passed the current handler and returns
        the badge base URL, or "" for the default.

        For example, you could get the badge_base_url from a custom HTTP
        header, the Referer header, or from a request parameter
        """,
                           config=True)

    @default('badge_base_url')
    def _badge_base_url_default(self):
        return ''

    @validate('badge_base_url')
    def _valid_badge_base_url(self, proposal):
        if callable(proposal.value):
            return proposal.value
        # add a trailing slash only when a value is set
        if proposal.value and not proposal.value.endswith('/'):
            proposal.value = proposal.value + '/'
        return proposal.value

    cors_allow_origin = Unicode("",
                                help="""
        Origins that can access the BinderHub API.

        Sets the Access-Control-Allow-Origin header in the spawned
        notebooks. Set to '*' to allow any origin to access spawned
        notebook servers.

        See also BinderSpawner.cors_allow_origin in the binderhub spawner
        mixin for setting this property on the spawned notebooks.
        """,
                                config=True)

    auth_enabled = Bool(False,
                        help="""If JupyterHub authentication enabled,
        require user to login (don't create temporary users during launch) and
        start the new server for the logged in user.""",
                        config=True)

    port = Integer(8585,
                   help="""
        Port for the builder to listen on.
        """,
                   config=True)

    appendix = Unicode(
        help="""
        Appendix to pass to repo2docker

        A multi-line string of Docker directives to run.
        Since the build context cannot be affected,
        ADD will typically not be useful.

        This should be a Python string template.
        It will be formatted with at least the following names available:

        - binder_url: the shareable URL for the current image
          (e.g. for sharing links to the current Binder)
        - repo_url: the repository URL used to build the image
        """,
        config=True,
    )
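
    # Illustrative appendix value (hypothetical Dockerfile directives; the
    # binder_url and repo_url template names come from the help text above):
    #
    #     c.BinderHub.appendix = """
    #     LABEL binder.url="{binder_url}"
    #     LABEL repo.url="{repo_url}"
    #     """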

    sticky_builds = Bool(
        False,
        help="""
        Attempt to assign builds for the same repository to the same node.

        In order to speed up re-builds of a repository all its builds will
        be assigned to the same node in the cluster.

        Note: This feature only works if you also enable docker-in-docker support.
        """,
        config=True,
    )

    use_registry = Bool(True,
                        help="""
        Set to true to push images to a registry & check for images in registry.

        Set to false to use only local docker images. Useful when running
        in a single node.
        """,
                        config=True)

    build_class = Type(Build,
                       help="""
        The class used to build repo2docker images.

        Must inherit from binderhub.build.Build
        """,
                       config=True)

    registry_class = Type(DockerRegistry,
                          help="""
        The class used to Query a Docker registry.

        Must inherit from binderhub.registry.DockerRegistry
        """,
                          config=True)

    per_repo_quota = Integer(
        0,
        help="""
        Maximum number of concurrent users running from a given repo.

        Limits the amount of Binder that can be consumed by a single repo.

        0 (default) means no quotas.
        """,
        config=True,
    )

    pod_quota = Integer(
        None,
        help="""
        The number of concurrent pods this hub has been designed to support.

        This quota is used as an indication for how much above or below the
        design capacity a hub is running.

        Attempts to launch new pods once the quota has been reached will fail.

        The default (None) corresponds to no quota, 0 means the hub can't accept pods
        (e.g. because it is in maintenance mode), and any positive integer
        sets the quota.
        """,
        allow_none=True,
        config=True,
    )

    per_repo_quota_higher = Integer(
        0,
        help="""
        Maximum number of concurrent users running from a higher-quota repo.

        Limits the amount of Binder that can be consumed by a single repo. This
        quota is a second limit for repos with special status. See the
        `high_quota_specs` parameter of RepoProvider classes for usage.

        0 (default) means no quotas.
        """,
        config=True,
    )

    log_tail_lines = Integer(
        100,
        help="""
        Limit number of log lines to show when connecting to an already running build.
        """,
        config=True,
    )

    push_secret = Unicode('binder-build-docker-config',
                          allow_none=True,
                          help="""
        A kubernetes secret object that provides credentials for pushing built images.
        """,
                          config=True)

    image_prefix = Unicode("",
                           help="""
        Prefix for all built docker images.

        If you are pushing to gcr.io, this would start with:
            gcr.io/<your-project-name>/

        Set according to whatever registry you are pushing to.

        Defaults to "", which is probably not what you want :)
        """,
                           config=True)

    build_memory_request = ByteSpecification(
        0,
        help="""
        Amount of memory to request when scheduling a build

        0 reserves no memory.

        This is used as the request for the pod that is spawned to do the building,
        even though the pod itself will not be using that much memory
        since the docker building is happening outside the pod.
        However, it makes kubernetes aware of the resources being used,
        and lets it schedule more intelligently.
        """,
        config=True,
    )
    build_memory_limit = ByteSpecification(
        0,
        help="""
        Max amount of memory allocated for each image build process.

        0 sets no limit.

        This is applied to the docker build itself via repo2docker,
        though it is also applied to our pod that submits the build,
        even though that pod will rarely consume much memory.
        Still, it makes it easier to see the resource limits in place via kubernetes.
        """,
        config=True,
    )

    debug = Bool(False,
                 help="""
        Turn on debugging.
        """,
                 config=True)

    build_docker_host = Unicode("/var/run/docker.sock",
                                config=True,
                                help="""
        The docker URL repo2docker should use to build the images.

        Currently, only paths are supported, and they are expected to be available on
        all the hosts.
        """)

    @validate('build_docker_host')
    def docker_build_host_validate(self, proposal):
        parts = urlparse(proposal.value)
        if parts.scheme != 'unix' or parts.netloc != '':
            raise TraitError(
                "Only unix domain sockets on same node are supported for build_docker_host"
            )
        return proposal.value

    build_docker_config = Dict(None,
                               allow_none=True,
                               help="""
        A dict which will be merged into the .docker/config.json of the build container (repo2docker).
        Here you could, for example, pass proxy settings as described at:
        https://docs.docker.com/network/proxy/#configure-the-docker-client

        Note: if you provide your own push_secret, these values won't
        have an effect, as the push_secret will overwrite
        .docker/config.json.
        In that case, make sure that you include your config in your push_secret.
        """,
                               config=True)

    hub_api_token = Unicode(
        help="""API token for talking to the JupyterHub API""",
        config=True,
    )

    @default('hub_api_token')
    def _default_hub_token(self):
        return os.environ.get('JUPYTERHUB_API_TOKEN', '')

    hub_url = Unicode(
        help="""
        The base URL of the JupyterHub instance where users will run.

        e.g. https://hub.mybinder.org/
        """,
        config=True,
    )

    hub_url_local = Unicode(
        help="""
        The base URL of the JupyterHub instance for local/internal traffic

        Set this if local/internal network connections from the BinderHub process
        should access JupyterHub using a different URL than public/external traffic.
        Defaults to hub_url.
        """,
        config=True,
    )

    @default('hub_url_local')
    def _default_hub_url_local(self):
        return self.hub_url

    @validate('hub_url', 'hub_url_local')
    def _add_slash(self, proposal):
        """trait validator to ensure hub_url ends with a trailing slash"""
        if proposal.value is not None and not proposal.value.endswith('/'):
            return proposal.value + '/'
        return proposal.value

    build_namespace = Unicode(help="""
        Kubernetes namespace to spawn build pods in.

        Note that the push_secret must refer to a secret in this namespace.
        """,
                              config=True)

    @default('build_namespace')
    def _default_build_namespace(self):
        return os.environ.get('BUILD_NAMESPACE', 'default')

    build_image = Unicode('quay.io/jupyterhub/repo2docker:2021.08.0',
                          help="""
        The repo2docker image to be used for doing builds
        """,
                          config=True)

    build_node_selector = Dict({},
                               config=True,
                               help="""
        Select the node where build pod runs on.
        """)

    repo_providers = Dict(
        {
            'gh': GitHubRepoProvider,
            'gist': GistRepoProvider,
            'git': GitRepoProvider,
            'gl': GitLabRepoProvider,
            'zenodo': ZenodoProvider,
            'figshare': FigshareProvider,
            'hydroshare': HydroshareProvider,
            'dataverse': DataverseProvider,
        },
        config=True,
        help="""
        Dict of repo providers to register and try, keyed by provider prefix.
        """)

    @validate('repo_providers')
    def _validate_repo_providers(self, proposal):
        """trait validator to ensure there is at least one repo provider"""
        if not proposal.value:
            raise TraitError("Please provide at least one repo provider")

        if any([
                not issubclass(provider, RepoProvider)
                for provider in proposal.value.values()
        ]):
            raise TraitError(
                "Repository providers should inherit from 'binderhub.RepoProvider'"
            )

        return proposal.value

    concurrent_build_limit = Integer(
        32, config=True, help="""The number of concurrent builds to allow.""")
    executor_threads = Integer(
        5,
        config=True,
        help="""The number of threads to use for blocking calls

        Should generally be a small number because we don't
        care about high concurrency here, just not blocking the webserver.
        This executor is not used for long-running tasks (e.g. builds).
        """,
    )
    build_cleanup_interval = Integer(
        60,
        config=True,
        help=
        """Interval (in seconds) for how often stopped build pods will be deleted."""
    )
    build_max_age = Integer(3600 * 4,
                            config=True,
                            help="""Maximum age of builds

        Builds that are still running longer than this
        will be killed.
        """)

    build_token_check_origin = Bool(
        True,
        config=True,
        help="""Whether to validate build token origin.

        False disables the origin check.
        """)

    build_token_expires_seconds = Integer(
        300,
        config=True,
        help="""Expiry (in seconds) of build tokens

        These are generally only used to authenticate a single request
        from a page, so should be short-lived.
        """,
    )

    build_token_secret = Union(
        [Unicode(), Bytes()],
        config=True,
        help="""Secret used to sign build tokens

        Lightweight validation of same-origin requests
        """,
    )

    @validate("build_token_secret")
    def _validate_build_token_secret(self, proposal):
        if isinstance(proposal.value, str):
            # allow hex string for text-only input formats
            return a2b_hex(proposal.value)
        return proposal.value

    @default("build_token_secret")
    def _default_build_token_secret(self):
        if os.environ.get("BINDERHUB_BUILD_TOKEN_SECRET"):
            return a2b_hex(os.environ["BINDERHUB_BUILD_TOKEN_SECRET"])
        app_log.warning(
            "Generating random build token secret."
            " Set BinderHub.build_token_secret to avoid this warning.")
        return secrets.token_bytes(32)
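
    # A matching secret can be generated out-of-band, for example with
    #   python -c "import secrets; print(secrets.token_hex(32))"
    # and supplied via BINDERHUB_BUILD_TOKEN_SECRET (hex-encoded, as the
    # a2b_hex calls above expect).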

    # FIXME: Come up with a better name for it?
    builder_required = Bool(True,
                            config=True,
                            help="""
        Whether BinderHub requires a working build infrastructure (kubernetes cluster + docker) to start.

        Set to False to run without one, which is useful for pure HTML/CSS/JS local development.
        """)

    ban_networks = Dict(
        config=True,
        help="""
        Dict of networks from which requests should be rejected with 403

        Keys are CIDR notation (e.g. '1.2.3.4/32'),
        values are a label used in log / error messages.
        CIDR strings will be parsed with `ipaddress.ip_network()`.
        """,
    )

    @validate("ban_networks")
    def _cast_ban_networks(self, proposal):
        """Cast CIDR strings to IPv[4|6]Network objects"""
        networks = {}
        for cidr, message in proposal.value.items():
            networks[ipaddress.ip_network(cidr)] = message

        return networks

    ban_networks_min_prefix_len = Integer(
        1,
        help="The shortest prefix in ban_networks",
    )

    @observe("ban_networks")
    def _update_prefix_len(self, change):
        if not change.new:
            min_len = 1
        else:
            min_len = min(net.prefixlen for net in change.new)
        self.ban_networks_min_prefix_len = min_len or 1
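
    # Example (hypothetical values): config ban_networks = {"10.0.0.0/8": "internal"}
    # is cast by _cast_ban_networks to {IPv4Network('10.0.0.0/8'): "internal"},
    # and this observer then records a minimum prefix length of 8 (presumably
    # used by request handlers to bound the network-match lookup).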

    tornado_settings = Dict(config=True,
                            help="""
        Additional settings to pass through to tornado.

        Can include things like additional headers, etc.
        """)

    template_variables = Dict(
        config=True,
        help="Extra variables to supply to jinja templates when rendering.",
    )

    template_path = Unicode(
        help=
        "Path to search for custom jinja templates, before using the default templates.",
        config=True,
    )

    @default('template_path')
    def _template_path_default(self):
        return os.path.join(HERE, 'templates')

    extra_static_path = Unicode(
        help='Path to search for extra static files.',
        config=True,
    )

    extra_static_url_prefix = Unicode(
        '/extra_static/',
        help='Url prefix to serve extra static files.',
        config=True,
    )

    normalized_origin = Unicode(
        '',
        config=True,
        help=
        'Origin to use when emitting events. Defaults to hostname of request when empty'
    )

    allowed_metrics_ips = Set(
        help=
        'Set of IPs or networks allowed to GET /metrics. Defaults to all.',
        config=True)

    @staticmethod
    def add_url_prefix(prefix, handlers):
        """add a url prefix to handlers"""
        for i, tup in enumerate(handlers):
            lis = list(tup)
            lis[0] = url_path_join(prefix, tup[0])
            handlers[i] = tuple(lis)
        return handlers

    def init_pycurl(self):
        try:
            AsyncHTTPClient.configure(
                "tornado.curl_httpclient.CurlAsyncHTTPClient")
        except ImportError as e:
            self.log.debug(
                "Could not load pycurl: %s\npycurl is recommended if you have a large number of users.",
                e)
        # set max verbosity of curl_httpclient at INFO
        # because debug-logging from curl_httpclient
        # includes every full request and response
        if self.log_level < logging.INFO:
            curl_log = logging.getLogger('tornado.curl_httpclient')
            curl_log.setLevel(logging.INFO)

    def initialize(self, *args, **kwargs):
        """Load configuration settings."""
        super().initialize(*args, **kwargs)
        self.load_config_file(self.config_file)
        # hook up tornado logging
        if self.debug:
            self.log_level = logging.DEBUG
        tornado.options.options.logging = logging.getLevelName(self.log_level)
        tornado.log.enable_pretty_logging()
        self.log = tornado.log.app_log

        self.init_pycurl()

        # initialize kubernetes config
        if self.builder_required:
            try:
                kubernetes.config.load_incluster_config()
            except kubernetes.config.ConfigException:
                kubernetes.config.load_kube_config()
            self.tornado_settings[
                "kubernetes_client"] = self.kube_client = kubernetes.client.CoreV1Api(
                )

        # times 2 for log + build threads
        self.build_pool = ThreadPoolExecutor(self.concurrent_build_limit * 2)
        # default executor for asyncifying blocking calls (e.g. to kubernetes, docker).
        # this should not be used for long-running requests
        self.executor = ThreadPoolExecutor(self.executor_threads)

        jinja_options = dict(autoescape=True)
        template_paths = [self.template_path]
        base_template_path = self._template_path_default()
        if base_template_path not in template_paths:
            # add base templates to the end, so they are looked up last, after custom templates
            template_paths.append(base_template_path)
        loader = ChoiceLoader([
            # first load base templates with prefix
            PrefixLoader({'templates': FileSystemLoader([base_template_path])},
                         '/'),
            # load all templates
            FileSystemLoader(template_paths)
        ])
        jinja_env = Environment(loader=loader, **jinja_options)
        if self.use_registry:
            registry = self.registry_class(parent=self)
        else:
            registry = None

        self.launcher = Launcher(
            parent=self,
            hub_url=self.hub_url,
            hub_url_local=self.hub_url_local,
            hub_api_token=self.hub_api_token,
            create_user=not self.auth_enabled,
        )

        self.event_log = EventLog(parent=self)

        for schema_file in glob(os.path.join(HERE, 'event-schemas', '*.json')):
            with open(schema_file) as f:
                self.event_log.register_schema(json.load(f))

        self.tornado_settings.update({
            "log_function":
            log_request,
            "push_secret":
            self.push_secret,
            "image_prefix":
            self.image_prefix,
            "debug":
            self.debug,
            "hub_url":
            self.hub_url,
            "launcher":
            self.launcher,
            "appendix":
            self.appendix,
            "ban_networks":
            self.ban_networks,
            "ban_networks_min_prefix_len":
            self.ban_networks_min_prefix_len,
            "build_namespace":
            self.build_namespace,
            "build_image":
            self.build_image,
            "build_node_selector":
            self.build_node_selector,
            "build_pool":
            self.build_pool,
            "build_token_check_origin":
            self.build_token_check_origin,
            "build_token_secret":
            self.build_token_secret,
            "build_token_expires_seconds":
            self.build_token_expires_seconds,
            "sticky_builds":
            self.sticky_builds,
            "log_tail_lines":
            self.log_tail_lines,
            "pod_quota":
            self.pod_quota,
            "per_repo_quota":
            self.per_repo_quota,
            "per_repo_quota_higher":
            self.per_repo_quota_higher,
            "repo_providers":
            self.repo_providers,
            "rate_limiter":
            RateLimiter(parent=self),
            "use_registry":
            self.use_registry,
            "build_class":
            self.build_class,
            "registry":
            registry,
            "traitlets_config":
            self.config,
            "google_analytics_code":
            self.google_analytics_code,
            "google_analytics_domain":
            self.google_analytics_domain,
            "about_message":
            self.about_message,
            "banner_message":
            self.banner_message,
            "extra_footer_scripts":
            self.extra_footer_scripts,
            "jinja2_env":
            jinja_env,
            "build_memory_limit":
            self.build_memory_limit,
            "build_memory_request":
            self.build_memory_request,
            "build_docker_host":
            self.build_docker_host,
            "build_docker_config":
            self.build_docker_config,
            "base_url":
            self.base_url,
            "badge_base_url":
            self.badge_base_url,
            "static_path":
            os.path.join(HERE, "static"),
            "static_url_prefix":
            url_path_join(self.base_url, "static/"),
            "template_variables":
            self.template_variables,
            "executor":
            self.executor,
            "auth_enabled":
            self.auth_enabled,
            "event_log":
            self.event_log,
            "normalized_origin":
            self.normalized_origin,
            "allowed_metrics_ips":
            set(map(ipaddress.ip_network, self.allowed_metrics_ips))
        })
        if self.auth_enabled:
            self.tornado_settings['cookie_secret'] = os.urandom(32)
        if self.cors_allow_origin:
            self.tornado_settings.setdefault(
                'headers',
                {})['Access-Control-Allow-Origin'] = self.cors_allow_origin

        handlers = [
            (r'/metrics', MetricsHandler),
            (r'/versions', VersionHandler),
            (r"/build/([^/]+)/(.+)", BuildHandler),
            (r"/v2/([^/]+)/(.+)", ParameterizedMainHandler),
            (r"/repo/([^/]+)/([^/]+)(/.*)?", LegacyRedirectHandler),
            (r'/~([^/]+/.*)', UserRedirectHandler),
            # for backward-compatible mybinder.org badge URLs
            # /assets/images/badge.svg
            (r'/assets/(images/badge\.svg)', tornado.web.StaticFileHandler, {
                'path': self.tornado_settings['static_path']
            }),
            # /badge.svg
            (r'/(badge\.svg)', tornado.web.StaticFileHandler, {
                'path':
                os.path.join(self.tornado_settings['static_path'], 'images')
            }),
            # /badge_logo.svg
            (r'/(badge\_logo\.svg)', tornado.web.StaticFileHandler, {
                'path':
                os.path.join(self.tornado_settings['static_path'], 'images')
            }),
            # /logo_social.png
            (r'/(logo\_social\.png)', tornado.web.StaticFileHandler, {
                'path':
                os.path.join(self.tornado_settings['static_path'], 'images')
            }),
            # /favicon_XXX.ico
            (r'/(favicon\_fail\.ico)', tornado.web.StaticFileHandler, {
                'path':
                os.path.join(self.tornado_settings['static_path'], 'images')
            }),
            (r'/(favicon\_success\.ico)', tornado.web.StaticFileHandler, {
                'path':
                os.path.join(self.tornado_settings['static_path'], 'images')
            }),
            (r'/(favicon\_building\.ico)', tornado.web.StaticFileHandler, {
                'path':
                os.path.join(self.tornado_settings['static_path'], 'images')
            }),
            (r'/about', AboutHandler),
            (r'/health', HealthHandler, {
                'hub_url': self.hub_url_local
            }),
            (r'/_config', ConfigHandler),
            (r'/', MainHandler),
            (r'.*', Custom404),
        ]
        handlers = self.add_url_prefix(self.base_url, handlers)
        if self.extra_static_path:
            handlers.insert(-1, (re.escape(
                url_path_join(self.base_url, self.extra_static_url_prefix)) +
                                 r"(.*)", tornado.web.StaticFileHandler, {
                                     'path': self.extra_static_path
                                 }))
        if self.auth_enabled:
            oauth_redirect_uri = os.getenv('JUPYTERHUB_OAUTH_CALLBACK_URL') or \
                                 url_path_join(self.base_url, 'oauth_callback')
            oauth_redirect_uri = urlparse(oauth_redirect_uri).path
            handlers.insert(
                -1, (re.escape(oauth_redirect_uri), HubOAuthCallbackHandler))
        self.tornado_app = tornado.web.Application(handlers,
                                                   **self.tornado_settings)

    def stop(self):
        self.http_server.stop()
        self.build_pool.shutdown()

    async def watch_build_pods(self):
        """Watch build pods

        Every build_cleanup_interval:
        - delete stopped build pods
        - delete running build pods older than build_max_age
        """
        while True:
            try:
                await asyncio.wrap_future(
                    self.executor.submit(lambda: Build.cleanup_builds(
                        self.kube_client,
                        self.build_namespace,
                        self.build_max_age,
                    )))
            except Exception:
                app_log.exception("Failed to cleanup build pods")
            await asyncio.sleep(self.build_cleanup_interval)

    def start(self, run_loop=True):
        self.log.info("BinderHub starting on port %i", self.port)
        self.http_server = HTTPServer(
            self.tornado_app,
            xheaders=True,
        )
        self.http_server.listen(self.port)
        if self.builder_required:
            asyncio.ensure_future(self.watch_build_pods())
        if run_loop:
            tornado.ioloop.IOLoop.current().start()
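
The traits above are normally supplied through a traitlets config file, which initialize() loads via load_config_file. A minimal sketch, assuming a file named binderhub_config.py; every value below is an illustrative placeholder, not a recommended default:

# binderhub_config.py -- hypothetical placeholder values
c.BinderHub.use_registry = True
c.BinderHub.image_prefix = "gcr.io/my-project/binder-"  # set for your registry
c.BinderHub.hub_url = "https://hub.example.org"  # _add_slash appends the trailing '/'
c.BinderHub.per_repo_quota = 100
c.BinderHub.concurrent_build_limit = 32

The `c` object is injected by the traitlets config loader, so the file needs no imports.
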
class EnterpriseGatewayConfigMixin(Configurable):
    # Server IP / PORT binding
    port_env = 'EG_PORT'
    port_default_value = 8888
    port = Integer(port_default_value, config=True,
                   help='Port on which to listen (EG_PORT env var)')

    @default('port')
    def port_default(self):
        return int(os.getenv(self.port_env, os.getenv('KG_PORT', self.port_default_value)))

    port_retries_env = 'EG_PORT_RETRIES'
    port_retries_default_value = 50
    port_retries = Integer(port_retries_default_value, config=True,
                           help="""Number of ports to try if the specified port is not available
                           (EG_PORT_RETRIES env var)""")

    @default('port_retries')
    def port_retries_default(self):
        return int(os.getenv(self.port_retries_env, os.getenv('KG_PORT_RETRIES', self.port_retries_default_value)))

    ip_env = 'EG_IP'
    ip_default_value = '127.0.0.1'
    ip = Unicode(ip_default_value, config=True,
                 help='IP address on which to listen (EG_IP env var)')

    @default('ip')
    def ip_default(self):
        return os.getenv(self.ip_env, os.getenv('KG_IP', self.ip_default_value))

    # Base URL
    base_url_env = 'EG_BASE_URL'
    base_url_default_value = '/'
    base_url = Unicode(base_url_default_value, config=True,
                       help='The base path for mounting all API resources (EG_BASE_URL env var)')

    @default('base_url')
    def base_url_default(self):
        return os.getenv(self.base_url_env, os.getenv('KG_BASE_URL', self.base_url_default_value))

    # Token authorization
    auth_token_env = 'EG_AUTH_TOKEN'
    auth_token = Unicode(config=True,
                         help='Authorization token required for all requests (EG_AUTH_TOKEN env var)')

    @default('auth_token')
    def _auth_token_default(self):
        return os.getenv(self.auth_token_env, os.getenv('KG_AUTH_TOKEN', ''))

    # Begin CORS headers
    allow_credentials_env = 'EG_ALLOW_CREDENTIALS'
    allow_credentials = Unicode(config=True,
                                help='Sets the Access-Control-Allow-Credentials header. (EG_ALLOW_CREDENTIALS env var)')

    @default('allow_credentials')
    def allow_credentials_default(self):
        return os.getenv(self.allow_credentials_env, os.getenv('KG_ALLOW_CREDENTIALS', ''))

    allow_headers_env = 'EG_ALLOW_HEADERS'
    allow_headers = Unicode(config=True,
                            help='Sets the Access-Control-Allow-Headers header. (EG_ALLOW_HEADERS env var)')

    @default('allow_headers')
    def allow_headers_default(self):
        return os.getenv(self.allow_headers_env, os.getenv('KG_ALLOW_HEADERS', ''))

    allow_methods_env = 'EG_ALLOW_METHODS'
    allow_methods = Unicode(config=True,
                            help='Sets the Access-Control-Allow-Methods header. (EG_ALLOW_METHODS env var)')

    @default('allow_methods')
    def allow_methods_default(self):
        return os.getenv(self.allow_methods_env, os.getenv('KG_ALLOW_METHODS', ''))

    allow_origin_env = 'EG_ALLOW_ORIGIN'
    allow_origin = Unicode(config=True,
                           help='Sets the Access-Control-Allow-Origin header. (EG_ALLOW_ORIGIN env var)')

    @default('allow_origin')
    def allow_origin_default(self):
        return os.getenv(self.allow_origin_env, os.getenv('KG_ALLOW_ORIGIN', ''))

    expose_headers_env = 'EG_EXPOSE_HEADERS'
    expose_headers = Unicode(config=True,
                             help='Sets the Access-Control-Expose-Headers header. (EG_EXPOSE_HEADERS env var)')

    @default('expose_headers')
    def expose_headers_default(self):
        return os.getenv(self.expose_headers_env, os.getenv('KG_EXPOSE_HEADERS', ''))

    trust_xheaders_env = 'EG_TRUST_XHEADERS'
    trust_xheaders = CBool(False, config=True,
                           help="""Use x-* header values for overriding the remote-ip, useful when
                           application is behing a proxy. (EG_TRUST_XHEADERS env var)""")

    @default('trust_xheaders')
    def trust_xheaders_default(self):
        return strtobool(os.getenv(self.trust_xheaders_env, os.getenv('KG_TRUST_XHEADERS', 'False')))

    certfile_env = 'EG_CERTFILE'
    certfile = Unicode(None, config=True, allow_none=True,
                       help='The full path to an SSL/TLS certificate file. (EG_CERTFILE env var)')

    @default('certfile')
    def certfile_default(self):
        return os.getenv(self.certfile_env, os.getenv('KG_CERTFILE'))

    keyfile_env = 'EG_KEYFILE'
    keyfile = Unicode(None, config=True, allow_none=True,
                      help='The full path to a private key file for usage with SSL/TLS. (EG_KEYFILE env var)')

    @default('keyfile')
    def keyfile_default(self):
        return os.getenv(self.keyfile_env, os.getenv('KG_KEYFILE'))

    client_ca_env = 'EG_CLIENT_CA'
    client_ca = Unicode(None, config=True, allow_none=True,
                        help="""The full path to a certificate authority certificate for SSL/TLS
                        client authentication. (EG_CLIENT_CA env var)""")

    @default('client_ca')
    def client_ca_default(self):
        return os.getenv(self.client_ca_env, os.getenv('KG_CLIENT_CA'))

    ssl_version_env = 'EG_SSL_VERSION'
    ssl_version_default_value = ssl.PROTOCOL_TLSv1_2
    ssl_version = Integer(None, config=True, allow_none=True,
                          help="""Sets the SSL version to use for the web socket
                          connection. (EG_SSL_VERSION env var)""")

    @default('ssl_version')
    def ssl_version_default(self):
        ssl_from_env = os.getenv(self.ssl_version_env, os.getenv('KG_SSL_VERSION'))
        return ssl_from_env if ssl_from_env is None else int(ssl_from_env)

    max_age_env = 'EG_MAX_AGE'
    max_age = Unicode(config=True,
                      help='Sets the Access-Control-Max-Age header. (EG_MAX_AGE env var)')

    @default('max_age')
    def max_age_default(self):
        return os.getenv(self.max_age_env, os.getenv('KG_MAX_AGE', ''))
    # End CORS headers

    max_kernels_env = 'EG_MAX_KERNELS'
    max_kernels = Integer(None, config=True,
                          allow_none=True,
                          help="""Limits the number of kernel instances allowed to run by this gateway.
                          Unbounded by default. (EG_MAX_KERNELS env var)""")

    @default('max_kernels')
    def max_kernels_default(self):
        val = os.getenv(self.max_kernels_env, os.getenv('KG_MAX_KERNELS'))
        return val if val is None else int(val)

    default_kernel_name_env = 'EG_DEFAULT_KERNEL_NAME'
    default_kernel_name = Unicode(config=True,
                                  help='Default kernel name when spawning a kernel (EG_DEFAULT_KERNEL_NAME env var)')

    @default('default_kernel_name')
    def default_kernel_name_default(self):
        # defaults to Jupyter's default kernel name on empty string
        return os.getenv(self.default_kernel_name_env, os.getenv('KG_DEFAULT_KERNEL_NAME', ''))

    list_kernels_env = 'EG_LIST_KERNELS'
    list_kernels = Bool(config=True,
                        help="""Permits listing of the running kernels using API endpoints /api/kernels
                        and /api/sessions. (EG_LIST_KERNELS env var) Note: Jupyter Notebook
                        allows this by default but Jupyter Enterprise Gateway does not.""")

    @default('list_kernels')
    def list_kernels_default(self):
        return os.getenv(self.list_kernels_env, os.getenv('KG_LIST_KERNELS', 'False')).lower() == 'true'

    env_whitelist_env = 'EG_ENV_WHITELIST'
    env_whitelist = List(config=True,
                         help="""Environment variables allowed to be set when a client requests a
                         new kernel. (EG_ENV_WHITELIST env var)""")

    @default('env_whitelist')
    def env_whitelist_default(self):
        return os.getenv(self.env_whitelist_env, os.getenv('KG_ENV_WHITELIST', '')).split(',')

    env_process_whitelist_env = 'EG_ENV_PROCESS_WHITELIST'
    env_process_whitelist = List(config=True,
                                 help="""Environment variables allowed to be inherited
                                 from the spawning process by the kernel. (EG_ENV_PROCESS_WHITELIST env var)""")

    @default('env_process_whitelist')
    def env_process_whitelist_default(self):
        return os.getenv(self.env_process_whitelist_env, os.getenv('KG_ENV_PROCESS_WHITELIST', '')).split(',')

    kernel_headers_env = 'EG_KERNEL_HEADERS'
    kernel_headers = List(config=True,
                          help="""Request headers to make available to kernel launch framework.
                          (EG_KERNEL_HEADERS env var)""")

    @default('kernel_headers')
    def kernel_headers_default(self):
        default_headers = os.getenv(self.kernel_headers_env)
        return default_headers.split(',') if default_headers else []

    # Remote hosts
    remote_hosts_env = 'EG_REMOTE_HOSTS'
    remote_hosts_default_value = 'localhost'
    remote_hosts = List(default_value=[remote_hosts_default_value], config=True,
                        help="""Bracketed comma-separated list of hosts on which DistributedProcessProxy
                        kernels will be launched e.g., ['host1','host2']. (EG_REMOTE_HOSTS env var
                        - non-bracketed, just comma-separated)""")

    @default('remote_hosts')
    def remote_hosts_default(self):
        return os.getenv(self.remote_hosts_env, self.remote_hosts_default_value).split(',')

    # Yarn endpoint
    yarn_endpoint_env = 'EG_YARN_ENDPOINT'
    yarn_endpoint = Unicode(None, config=True, allow_none=True,
                            help="""The http url specifying the YARN Resource Manager. Note: If this value is NOT set,
                            the YARN library will use the files within the local HADOOP_CONFIG_DIR to determine the
                            active resource manager. (EG_YARN_ENDPOINT env var)""")

    @default('yarn_endpoint')
    def yarn_endpoint_default(self):
        return os.getenv(self.yarn_endpoint_env)

    # Alt Yarn endpoint
    alt_yarn_endpoint_env = 'EG_ALT_YARN_ENDPOINT'
    alt_yarn_endpoint = Unicode(None, config=True, allow_none=True,
                                help="""The http url specifying the alternate YARN Resource Manager.  This value should
                                be set when YARN Resource Managers are configured for high availability.  Note: If both
                                YARN endpoints are NOT set, the YARN library will use the files within the local
                                HADOOP_CONFIG_DIR to determine the active resource manager.
                                (EG_ALT_YARN_ENDPOINT env var)""")

    @default('alt_yarn_endpoint')
    def alt_yarn_endpoint_default(self):
        return os.getenv(self.alt_yarn_endpoint_env)

    yarn_endpoint_security_enabled_env = 'EG_YARN_ENDPOINT_SECURITY_ENABLED'
    yarn_endpoint_security_enabled_default_value = False
    yarn_endpoint_security_enabled = Bool(yarn_endpoint_security_enabled_default_value, config=True,
                                          help="""Is YARN Kerberos/SPNEGO Security enabled (True/False).
                                          (EG_YARN_ENDPOINT_SECURITY_ENABLED env var)""")

    @default('yarn_endpoint_security_enabled')
    def yarn_endpoint_security_enabled_default(self):
        # note: bool('False') would be True, so compare the lowered string instead
        return str(os.getenv(self.yarn_endpoint_security_enabled_env,
                             self.yarn_endpoint_security_enabled_default_value)).lower() == 'true'

    # Conductor endpoint
    conductor_endpoint_env = 'EG_CONDUCTOR_ENDPOINT'
    conductor_endpoint_default_value = None
    conductor_endpoint = Unicode(conductor_endpoint_default_value,
                                 allow_none=True,
                                 config=True,
                                 help="""The http url for accessing the Conductor REST API.
                                 (EG_CONDUCTOR_ENDPOINT env var)""")

    @default('conductor_endpoint')
    def conductor_endpoint_default(self):
        return os.getenv(self.conductor_endpoint_env, self.conductor_endpoint_default_value)

    _log_formatter_cls = LogFormatter  # traitlet default is LevelFormatter

    @default('log_format')
    def _default_log_format(self):
        """override default log format to include milliseconds"""
        return u"%(color)s[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s]%(end_color)s %(message)s"

    # Impersonation enabled
    impersonation_enabled_env = 'EG_IMPERSONATION_ENABLED'
    impersonation_enabled = Bool(False, config=True,
                                 help="""Indicates whether impersonation will be performed during kernel launch.
                                 (EG_IMPERSONATION_ENABLED env var)""")

    @default('impersonation_enabled')
    def impersonation_enabled_default(self):
        return os.getenv(self.impersonation_enabled_env, 'false').lower() == 'true'

    # Unauthorized users
    unauthorized_users_env = 'EG_UNAUTHORIZED_USERS'
    unauthorized_users_default_value = 'root'
    unauthorized_users = Set(default_value={unauthorized_users_default_value}, config=True,
                             help="""Comma-separated list of user names (e.g., ['root','admin']) against which
                             KERNEL_USERNAME will be compared.  Any match (case-sensitive) will prevent the
                             kernel's launch and result in an HTTP 403 (Forbidden) error.
                             (EG_UNAUTHORIZED_USERS env var - non-bracketed, just comma-separated)""")

    @default('unauthorized_users')
    def unauthorized_users_default(self):
        return os.getenv(self.unauthorized_users_env, self.unauthorized_users_default_value).split(',')

    # Authorized users
    authorized_users_env = 'EG_AUTHORIZED_USERS'
    authorized_users = Set(config=True,
                           help="""Comma-separated list of user names (e.g., ['bob','alice']) against which
                           KERNEL_USERNAME will be compared.  Any match (case-sensitive) will allow the kernel's
                           launch, otherwise an HTTP 403 (Forbidden) error will be raised.  The set of unauthorized
                           users takes precedence. This option should be used carefully as it can dramatically limit
                           who can launch kernels.  (EG_AUTHORIZED_USERS env var - non-bracketed,
                           just comma-separated)""")

    @default('authorized_users')
    def authorized_users_default(self):
        au_env = os.getenv(self.authorized_users_env)
        return au_env.split(',') if au_env is not None else []

    # Port range
    port_range_env = 'EG_PORT_RANGE'
    port_range_default_value = "0..0"
    port_range = Unicode(port_range_default_value, config=True,
                         help="""Specifies the lower and upper port numbers from which ports are created.
                         The bounded values are separated by '..' (e.g., 33245..34245 specifies a range of 1000 ports
                         to be randomly selected). A range of zero (e.g., 33245..33245 or 0..0) disables port-range
                         enforcement.  (EG_PORT_RANGE env var)""")

    @default('port_range')
    def port_range_default(self):
        return os.getenv(self.port_range_env, self.port_range_default_value)

    # Max Kernels per User
    max_kernels_per_user_env = 'EG_MAX_KERNELS_PER_USER'
    max_kernels_per_user_default_value = -1
    max_kernels_per_user = Integer(max_kernels_per_user_default_value, config=True,
                                   help="""Specifies the maximum number of kernels a user can have active
                                   simultaneously.  A value of -1 disables enforcement.
                                   (EG_MAX_KERNELS_PER_USER env var)""")

    @default('max_kernels_per_user')
    def max_kernels_per_user_default(self):
        return int(os.getenv(self.max_kernels_per_user_env, self.max_kernels_per_user_default_value))

    ws_ping_interval_env = 'EG_WS_PING_INTERVAL_SECS'
    ws_ping_interval_default_value = 30
    ws_ping_interval = Integer(ws_ping_interval_default_value, config=True,
                               help="""Specifies the ping interval(in seconds) that should be used by zmq port
                                     associated withspawned kernels.Set this variable to 0 to disable ping mechanism.
                                    (EG_WS_PING_INTERVAL_SECS env var)""")

    @default('ws_ping_interval')
    def ws_ping_interval_default(self):
        return int(os.getenv(self.ws_ping_interval_env, self.ws_ping_interval_default_value))

    # Dynamic Update Interval
    dynamic_config_interval_env = 'EG_DYNAMIC_CONFIG_INTERVAL'
    dynamic_config_interval_default_value = 0
    dynamic_config_interval = Integer(dynamic_config_interval_default_value, min=0, config=True,
                                      help="""Specifies the number of seconds configuration files are polled for
                                      changes.  A value of 0 or less disables dynamic config updates.
                                      (EG_DYNAMIC_CONFIG_INTERVAL env var)""")

    @default('dynamic_config_interval')
    def dynamic_config_interval_default(self):
        return int(os.getenv(self.dynamic_config_interval_env, self.dynamic_config_interval_default_value))

    @observe('dynamic_config_interval')
    def dynamic_config_interval_changed(self, event):
        prev_val = event['old']
        self.dynamic_config_interval = event['new']
        if self.dynamic_config_interval != prev_val:
            # Values are different.  Stop the current poller.  If new value is > 0, start a poller.
            if self.dynamic_config_poller:
                self.dynamic_config_poller.stop()
                self.dynamic_config_poller = None

            if self.dynamic_config_interval <= 0:
                self.log.warning("Dynamic configuration updates have been disabled and cannot be re-enabled "
                                 "without restarting Enterprise Gateway!")
            # The interval has been changed, but still positive
            elif prev_val > 0 and hasattr(self, "init_dynamic_configs"):
                self.init_dynamic_configs()  # Restart the poller

    dynamic_config_poller = None

    kernel_spec_manager = Instance("jupyter_client.kernelspec.KernelSpecManager", allow_none=True)

    kernel_spec_manager_class = Type(
        default_value="jupyter_client.kernelspec.KernelSpecManager",
        config=True,
        help="""
        The kernel spec manager class to use. Must be a subclass
        of `jupyter_client.kernelspec.KernelSpecManager`.
        """
    )

    kernel_manager_class = Type(
        klass="enterprise_gateway.services.kernels.remotemanager.RemoteMappingKernelManager",
        default_value="enterprise_gateway.services.kernels.remotemanager.RemoteMappingKernelManager",
        config=True,
        help="""
        The kernel manager class to use. Must be a subclass
        of `enterprise_gateway.services.kernels.remotemanager.RemoteMappingKernelManager`.
        """
    )

    kernel_session_manager_class = Type(
        klass="enterprise_gateway.services.sessions.kernelsessionmanager.KernelSessionManager",
        default_value="enterprise_gateway.services.sessions.kernelsessionmanager.FileKernelSessionManager",
        config=True,
        help="""
        The kernel session manager class to use. Must be a subclass
        of `enterprise_gateway.services.sessions.kernelsessionmanager.KernelSessionManager`.
        """
    )
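
Nearly every trait in this mixin follows one pattern: a `@default` method that consults an EG_* environment variable, falls back to a legacy KG_* variable, and finally to a hard-coded value. A minimal self-contained sketch of that pattern; ExampleConfig and EXAMPLE_PORT are illustrative names, not part of Enterprise Gateway:

import os

from traitlets import Integer, default
from traitlets.config import Configurable


class ExampleConfig(Configurable):
    port_env = 'EXAMPLE_PORT'  # hypothetical env var standing in for the EG_*/KG_* chain
    port_default_value = 8888
    port = Integer(port_default_value, config=True,
                   help='Port on which to listen (EXAMPLE_PORT env var)')

    @default('port')
    def _port_default(self):
        # the env var wins over the hard-coded default; config files override both
        return int(os.getenv(self.port_env, self.port_default_value))


print(ExampleConfig().port)  # 8888 unless EXAMPLE_PORT is set
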
Exemple #24
0
class View(HasTraits):
    """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.

    Don't use this class, use subclasses.

    Methods
    -------

    spin
        flushes incoming results and registration state changes;
        control methods spin, and requesting `ids` also ensures up-to-date state

    wait
        wait on one or more msg_ids

    execution methods
        apply
        legacy: execute, run

    data movement
        push, pull, scatter, gather

    query methods
        get_result, queue_status, purge_results, result_status

    control methods
        abort, shutdown

    """
    # flags
    block = Bool(False)
    track = Bool(False)
    targets = Any()

    history = List()
    outstanding = Set()
    results = Dict()
    client = Instance('ipyparallel.Client', allow_none=True)

    _socket = Instance('zmq.Socket', allow_none=True)
    _flag_names = List(['targets', 'block', 'track'])
    _in_sync_results = Bool(False)
    _targets = Any()
    _idents = Any()

    def __init__(self, client=None, socket=None, **flags):
        super(View, self).__init__(client=client, _socket=socket)
        self.results = client.results
        self.block = client.block
        self.executor = ViewExecutor(self)

        self.set_flags(**flags)

        assert self.__class__ is not View, "Don't use base View objects, use subclasses"

    def __repr__(self):
        strtargets = str(self.targets)
        if len(strtargets) > 16:
            strtargets = strtargets[:12] + '...]'
        return "<%s %s>" % (self.__class__.__name__, strtargets)

    def __len__(self):
        if isinstance(self.targets, list):
            return len(self.targets)
        elif isinstance(self.targets, int):
            return 1
        else:
            return len(self.client)

    def set_flags(self, **kwargs):
        """set my attribute flags by keyword.

        Views determine behavior with a few attributes (`block`, `track`, etc.).
        These attributes can be set all at once by name with this method.

        Parameters
        ----------

        block : bool
            whether to wait for results
        track : bool
            whether to create a MessageTracker to allow the user to
            safely edit arrays and buffers after non-copying
            sends.
        """
        for name, value in kwargs.items():
            if name not in self._flag_names:
                raise KeyError("Invalid name: %r" % name)
            else:
                setattr(self, name, value)

    @contextmanager
    def temp_flags(self, **kwargs):
        """temporarily set flags, for use in `with` statements.

        See set_flags for permanent setting of flags

        Examples
        --------

        >>> view.track=False
        ...
        >>> with view.temp_flags(track=True):
        ...    ar = view.apply(dostuff, my_big_array)
        ...    ar.tracker.wait() # wait for send to finish
        >>> view.track
        False

        """
        # preflight: save flags, and set temporaries
        saved_flags = {}
        for f in self._flag_names:
            saved_flags[f] = getattr(self, f)
        self.set_flags(**kwargs)
        # yield to the with-statement block
        try:
            yield
        finally:
            # postflight: restore saved flags
            self.set_flags(**saved_flags)

    #----------------------------------------------------------------
    # apply
    #----------------------------------------------------------------

    def _sync_results(self):
        """to be called by @sync_results decorator
        
        after submitting any tasks.
        """
        delta = self.outstanding.difference(self.client.outstanding)
        completed = self.outstanding.intersection(delta)
        self.outstanding = self.outstanding.difference(completed)

    @sync_results
    @save_ids
    def _really_apply(self, f, args, kwargs, block=None, **options):
        """wrapper for client.send_apply_request"""
        raise NotImplementedError("Implement in subclasses")

    def apply(self, f, *args, **kwargs):
        """calls ``f(*args, **kwargs)`` on remote engines, returning the result.

        This method sets all apply flags via this View's attributes.

        Returns :class:`~ipyparallel.client.asyncresult.AsyncResult`
        instance if ``self.block`` is False, otherwise the return value of
        ``f(*args, **kwargs)``.
        """
        return self._really_apply(f, args, kwargs)

    def apply_async(self, f, *args, **kwargs):
        """calls ``f(*args, **kwargs)`` on remote engines in a nonblocking manner.

        Returns :class:`~ipyparallel.client.asyncresult.AsyncResult` instance.
        """
        return self._really_apply(f, args, kwargs, block=False)

    def apply_sync(self, f, *args, **kwargs):
        """calls ``f(*args, **kwargs)`` on remote engines in a blocking manner,
         returning the result.
        """
        return self._really_apply(f, args, kwargs, block=True)

    #----------------------------------------------------------------
    # wrappers for client and control methods
    #----------------------------------------------------------------
    @sync_results
    def spin(self):
        """spin the client, and sync"""
        self.client.spin()

    @sync_results
    def wait(self, jobs=None, timeout=-1):
        """waits on one or more `jobs`, for up to `timeout` seconds.

        Parameters
        ----------

        jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
                ints are indices to self.history
                strs are msg_ids
                default: wait on all outstanding messages
        timeout : float
                a time in seconds, after which to give up.
                default is -1, which means no timeout

        Returns
        -------

        True : when all msg_ids are done
        False : timeout reached, some msg_ids still outstanding
        """
        if jobs is None:
            jobs = self.history
        return self.client.wait(jobs, timeout)

    def abort(self, jobs=None, targets=None, block=None):
        """Abort jobs on my engines.

        Parameters
        ----------

        jobs : None, str, list of strs, optional
            if None: abort all jobs.
            else: abort specific msg_id(s).
        """
        block = block if block is not None else self.block
        targets = targets if targets is not None else self.targets
        jobs = jobs if jobs is not None else list(self.outstanding)

        return self.client.abort(jobs=jobs, targets=targets, block=block)

    def queue_status(self, targets=None, verbose=False):
        """Fetch the Queue status of my engines"""
        targets = targets if targets is not None else self.targets
        return self.client.queue_status(targets=targets, verbose=verbose)

    def purge_results(self, jobs=[], targets=[]):
        """Instruct the controller to forget specific results."""
        if targets is None or targets == 'all':
            targets = self.targets
        return self.client.purge_results(jobs=jobs, targets=targets)

    def shutdown(self, targets=None, restart=False, hub=False, block=None):
        """Terminates one or more engine processes, optionally including the hub.
        """
        block = self.block if block is None else block
        if targets is None or targets == 'all':
            targets = self.targets
        return self.client.shutdown(targets=targets,
                                    restart=restart,
                                    hub=hub,
                                    block=block)

    def get_result(self, indices_or_msg_ids=None, block=None, owner=False):
        """return one or more results, specified by history index or msg_id.

        See :meth:`ipyparallel.client.client.Client.get_result` for details.
        """

        if indices_or_msg_ids is None:
            indices_or_msg_ids = -1
        if isinstance(indices_or_msg_ids, int):
            indices_or_msg_ids = self.history[indices_or_msg_ids]
        elif isinstance(indices_or_msg_ids, (list, tuple, set)):
            indices_or_msg_ids = list(indices_or_msg_ids)
            for i, index in enumerate(indices_or_msg_ids):
                if isinstance(index, int):
                    indices_or_msg_ids[i] = self.history[index]
        return self.client.get_result(indices_or_msg_ids,
                                      block=block,
                                      owner=owner)

    #-------------------------------------------------------------------
    # Map
    #-------------------------------------------------------------------

    @sync_results
    def map(self, f, *sequences, **kwargs):
        """override in subclasses"""
        raise NotImplementedError

    def map_async(self, f, *sequences, **kwargs):
        """Parallel version of builtin :func:`python:map`, using this view's engines.

        This is equivalent to ``map(...block=False)``.

        See `self.map` for details.
        """
        if 'block' in kwargs:
            raise TypeError(
                "map_async doesn't take a `block` keyword argument.")
        kwargs['block'] = False
        return self.map(f, *sequences, **kwargs)

    def map_sync(self, f, *sequences, **kwargs):
        """Parallel version of builtin :func:`python:map`, using this view's engines.

        This is equivalent to ``map(...block=True)``.

        See `self.map` for details.
        """
        if 'block' in kwargs:
            raise TypeError(
                "map_sync doesn't take a `block` keyword argument.")
        kwargs['block'] = True
        return self.map(f, *sequences, **kwargs)

    def imap(self, f, *sequences, **kwargs):
        """Parallel version of :func:`itertools.imap`.

        See `self.map` for details.

        """

        return iter(self.map_async(f, *sequences, **kwargs))

    #-------------------------------------------------------------------
    # Decorators
    #-------------------------------------------------------------------

    def remote(self, block=None, **flags):
        """Decorator for making a RemoteFunction"""
        block = self.block if block is None else block
        return remote(self, block=block, **flags)

    def parallel(self, dist='b', block=None, **flags):
        """Decorator for making a ParallelFunction"""
        block = self.block if block is None else block
        return parallel(self, dist=dist, block=block, **flags)
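
A hedged usage sketch for the View API above: it assumes an ipyparallel controller and engines are already running (e.g. via `ipcluster start`) and uses a DirectView obtained from the client:

import ipyparallel

rc = ipyparallel.Client()  # assumes a running controller and engines
view = rc[:]               # DirectView over all engines (a View subclass)

view.block = True                          # flags are plain attributes
doubled = view.apply(lambda x: 2 * x, 21)  # blocking: one result per engine

with view.temp_flags(block=False):         # temporarily go asynchronous
    ar = view.apply(sum, range(10))        # returns an AsyncResult while block=False
    ar.wait()

print(doubled, ar.get())
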
Exemple #25
0
class Kernel(SingletonConfigurable):

    #---------------------------------------------------------------------------
    # Kernel interface
    #---------------------------------------------------------------------------

    # attribute to override with a GUI
    eventloop = Any(None)
    def _eventloop_changed(self, name, old, new):
        """schedule call to eventloop from IOLoop"""
        loop = ioloop.IOLoop.instance()
        loop.add_callback(self.enter_eventloop)

    session = Instance(Session, allow_none=True)
    profile_dir = Instance('IPython.core.profiledir.ProfileDir', allow_none=True)
    shell_streams = List()
    control_stream = Instance(ZMQStream, allow_none=True)
    iopub_socket = Instance(zmq.Socket, allow_none=True)
    stdin_socket = Instance(zmq.Socket, allow_none=True)
    log = Instance(logging.Logger, allow_none=True)

    # identities:
    int_id = Integer(-1)
    ident = Unicode()

    def _ident_default(self):
        return str(uuid.uuid4())

    # This should be overridden by wrapper kernels that implement any real
    # language.
    language_info = {}

    # any links that should go in the help menu
    help_links = List()

    # Private interface

    _darwin_app_nap = Bool(True, config=True,
        help="""Whether to use appnope for compatiblity with OS X App Nap.

        Only affects OS X >= 10.9.
        """
    )

    # track associations with current request
    _allow_stdin = Bool(False)
    _parent_header = Dict()
    _parent_ident = Any(b'')
    # Time to sleep after flushing the stdout/err buffers in each execute
    # cycle.  While this introduces a hard limit on the minimal latency of the
    # execute cycle, it helps prevent output synchronization problems for
    # clients.
    # Units are in seconds.  The minimum zmq latency on local host is probably
    # ~150 microseconds, set this to 500us for now.  We may need to increase it
    # a little if it's not enough after more interactive testing.
    _execute_sleep = Float(0.0005, config=True)

    # Frequency of the kernel's event loop.
    # Units are in seconds, kernel subclasses for GUI toolkits may need to
    # adapt to milliseconds.
    _poll_interval = Float(0.05, config=True)

    # If the shutdown was requested over the network, we leave here the
    # necessary reply message so it can be sent by our registered atexit
    # handler.  This ensures that the reply is only sent to clients truly at
    # the end of our shutdown process (which happens after the underlying
    # IPython shell's own shutdown).
    _shutdown_message = None

    # This is a dict of port number that the kernel is listening on. It is set
    # by record_ports and used by connect_request.
    _recorded_ports = Dict()

    # set of aborted msg_ids
    aborted = Set()

    # Track execution count here. For IPython, we override this to use the
    # execution count we store in the shell.
    execution_count = 0


    def __init__(self, **kwargs):
        super(Kernel, self).__init__(**kwargs)

        # Build dict of handlers for message types
        msg_types = [ 'execute_request', 'complete_request',
                      'inspect_request', 'history_request',
                      'comm_info_request', 'kernel_info_request',
                      'connect_request', 'shutdown_request',
                      'apply_request', 'is_complete_request',
                    ]
        self.shell_handlers = {}
        for msg_type in msg_types:
            self.shell_handlers[msg_type] = getattr(self, msg_type)

        control_msg_types = msg_types + [ 'clear_request', 'abort_request' ]
        self.control_handlers = {}
        for msg_type in control_msg_types:
            self.control_handlers[msg_type] = getattr(self, msg_type)


    def dispatch_control(self, msg):
        """dispatch control requests"""
        idents,msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.deserialize(msg, content=True, copy=False)
        except Exception:
            self.log.error("Invalid Control Message", exc_info=True)
            return

        self.log.debug("Control received: %s", msg)

        # Set the parent message for side effects.
        self.set_parent(idents, msg)
        self._publish_status(u'busy')

        header = msg['header']
        msg_type = header['msg_type']

        handler = self.control_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r", msg_type)
        else:
            try:
                handler(self.control_stream, idents, msg)
            except Exception:
                self.log.error("Exception in control handler:", exc_info=True)

        sys.stdout.flush()
        sys.stderr.flush()
        self._publish_status(u'idle')

    def dispatch_shell(self, stream, msg):
        """dispatch shell requests"""
        # flush control requests first
        if self.control_stream:
            self.control_stream.flush()

        idents,msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.deserialize(msg, content=True, copy=False)
        except Exception:
            self.log.error("Invalid Message", exc_info=True)
            return

        # Set the parent message for side effects.
        self.set_parent(idents, msg)
        self._publish_status(u'busy')

        header = msg['header']
        msg_id = header['msg_id']
        msg_type = header['msg_type']

        # Print some info about this message and leave a '--->' marker, so it's
        # easier to trace visually the message chain when debugging.  Each
        # handler prints its message at the end.
        self.log.debug('\n*** MESSAGE TYPE:%s***', msg_type)
        self.log.debug('   Content: %s\n   --->\n   ', msg['content'])

        if msg_id in self.aborted:
            self.aborted.remove(msg_id)
            # is it safe to assume a msg_id will not be resubmitted?
            reply_type = msg_type.split('_')[0] + '_reply'
            status = {'status' : 'aborted'}
            md = {'engine' : self.ident}
            md.update(status)
            self.session.send(stream, reply_type, metadata=md,
                        content=status, parent=msg, ident=idents)
            return

        handler = self.shell_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("UNKNOWN MESSAGE TYPE: %r", msg_type)
        else:
            self.log.debug("%s: %s", msg_type, msg)
            self.pre_handler_hook()
            try:
                handler(stream, idents, msg)
            except Exception:
                self.log.error("Exception in message handler:", exc_info=True)
            finally:
                self.post_handler_hook()

        sys.stdout.flush()
        sys.stderr.flush()
        self._publish_status(u'idle')

    def pre_handler_hook(self):
        """Hook to execute before calling message handler"""
        # ensure default_int_handler during handler call
        self.saved_sigint_handler = signal(SIGINT, default_int_handler)

    def post_handler_hook(self):
        """Hook to execute after calling message handler"""
        signal(SIGINT, self.saved_sigint_handler)

    def enter_eventloop(self):
        """enter eventloop"""
        self.log.info("entering eventloop %s", self.eventloop)
        for stream in self.shell_streams:
            # flush any pending replies,
            # which may be skipped by entering the eventloop
            stream.flush(zmq.POLLOUT)
        # restore default_int_handler
        signal(SIGINT, default_int_handler)
        while self.eventloop is not None:
            try:
                self.eventloop(self)
            except KeyboardInterrupt:
                # Ctrl-C shouldn't crash the kernel
                self.log.error("KeyboardInterrupt caught in kernel")
                continue
            else:
                # eventloop exited cleanly, this means we should stop (right?)
                self.eventloop = None
                break
        self.log.info("exiting eventloop")

    def start(self):
        """register dispatchers for streams"""
        if self.control_stream:
            self.control_stream.on_recv(self.dispatch_control, copy=False)

        def make_dispatcher(stream):
            def dispatcher(msg):
                return self.dispatch_shell(stream, msg)
            return dispatcher

        for s in self.shell_streams:
            s.on_recv(make_dispatcher(s), copy=False)

        # publish starting status
        self._publish_status('starting')

    def do_one_iteration(self):
        """step eventloop just once"""
        if self.control_stream:
            self.control_stream.flush()
        for stream in self.shell_streams:
            # handle at most one request per iteration
            stream.flush(zmq.POLLIN, 1)
            stream.flush(zmq.POLLOUT)


    def record_ports(self, ports):
        """Record the ports that this kernel is using.

        The creator of the Kernel instance must call this method if they
        want the :meth:`connect_request` method to return the port numbers.
        """
        self._recorded_ports = ports

    #---------------------------------------------------------------------------
    # Kernel request handlers
    #---------------------------------------------------------------------------

    def _make_metadata(self, other=None):
        """init metadata dict, for execute/apply_reply"""
        new_md = {
            'dependencies_met' : True,
            'engine' : self.ident,
            'started': datetime.now(),
        }
        if other:
            new_md.update(other)
        return new_md

    def _publish_execute_input(self, code, parent, execution_count):
        """Publish the code request on the iopub stream."""

        self.session.send(self.iopub_socket, u'execute_input',
                            {u'code':code, u'execution_count': execution_count},
                            parent=parent, ident=self._topic('execute_input')
        )

    def _publish_status(self, status, parent=None):
        """send status (busy/idle) on IOPub"""
        self.session.send(self.iopub_socket,
                          u'status',
                          {u'execution_state': status},
                          parent=parent or self._parent_header,
                          ident=self._topic('status'),
                          )

    def set_parent(self, ident, parent):
        """Set the current parent_header

        Side effects (IOPub messages) and replies are associated with
        the request that caused them via the parent_header.

        The parent identity is used to route input_request messages
        on the stdin channel.
        """
        self._parent_ident = ident
        self._parent_header = parent

    def send_response(self, stream, msg_or_type, content=None, ident=None,
             buffers=None, track=False, header=None, metadata=None):
        """Send a response to the message we're currently processing.

        This accepts all the parameters of :meth:`jupyter_client.session.Session.send`
        except ``parent``.

        This relies on :meth:`set_parent` having been called for the current
        message.
        """
        return self.session.send(stream, msg_or_type, content, self._parent_header,
                                 ident, buffers, track, header, metadata)

    def execute_request(self, stream, ident, parent):
        """handle an execute_request"""

        try:
            content = parent[u'content']
            code = py3compat.cast_unicode_py2(content[u'code'])
            silent = content[u'silent']
            store_history = content.get(u'store_history', not silent)
            user_expressions = content.get('user_expressions', {})
            allow_stdin = content.get('allow_stdin', False)
        except Exception:
            self.log.error("Got bad msg: %s", parent, exc_info=True)
            return

        stop_on_error = content.get('stop_on_error', True)

        md = self._make_metadata(parent['metadata'])

        # Re-broadcast our input for the benefit of listening clients, and
        # start computing output
        if not silent:
            self.execution_count += 1
            self._publish_execute_input(code, parent, self.execution_count)

        reply_content = self.do_execute(code, silent, store_history,
                                        user_expressions, allow_stdin)

        # Flush output before sending the reply.
        sys.stdout.flush()
        sys.stderr.flush()
        # FIXME: on rare occasions, the flush doesn't seem to make it to the
        # clients... This seems to mitigate the problem, but we definitely need
        # to better understand what's going on.
        if self._execute_sleep:
            time.sleep(self._execute_sleep)

        # Send the reply.
        reply_content = json_clean(reply_content)

        md['status'] = reply_content['status']
        if reply_content['status'] == 'error' and \
                reply_content['ename'] == 'UnmetDependency':
            md['dependencies_met'] = False

        reply_msg = self.session.send(stream, u'execute_reply',
                                      reply_content, parent, metadata=md,
                                      ident=ident)

        self.log.debug("%s", reply_msg)

        if not silent and reply_msg['content']['status'] == u'error' and stop_on_error:
            self._abort_queues()

    def do_execute(self, code, silent, store_history=True,
                   user_expressions=None, allow_stdin=False):
        """Execute user code. Must be overridden by subclasses.
        """
        raise NotImplementedError

    def complete_request(self, stream, ident, parent):
        content = parent['content']
        code = content['code']
        cursor_pos = content['cursor_pos']

        matches = self.do_complete(code, cursor_pos)
        matches = json_clean(matches)
        completion_msg = self.session.send(stream, 'complete_reply',
                                           matches, parent, ident)
        self.log.debug("%s", completion_msg)

    def do_complete(self, code, cursor_pos):
        """Override in subclasses to find completions.
        """
        return {'matches' : [],
                'cursor_end' : cursor_pos,
                'cursor_start' : cursor_pos,
                'metadata' : {},
                'status' : 'ok'}

    def inspect_request(self, stream, ident, parent):
        content = parent['content']

        reply_content = self.do_inspect(content['code'], content['cursor_pos'],
                                        content.get('detail_level', 0))
        # Before we send this object over, we scrub it for JSON usage
        reply_content = json_clean(reply_content)
        msg = self.session.send(stream, 'inspect_reply',
                                reply_content, parent, ident)
        self.log.debug("%s", msg)

    def do_inspect(self, code, cursor_pos, detail_level=0):
        """Override in subclasses to allow introspection.
        """
        return {'status': 'ok', 'data': {}, 'metadata': {}, 'found': False}

    def history_request(self, stream, ident, parent):
        content = parent['content']

        reply_content = self.do_history(**content)

        reply_content = json_clean(reply_content)
        msg = self.session.send(stream, 'history_reply',
                                reply_content, parent, ident)
        self.log.debug("%s", msg)

    def do_history(self, hist_access_type, output, raw, session=None, start=None,
                   stop=None, n=None, pattern=None, unique=False):
        """Override in subclasses to access history.
        """
        return {'history': []}

    def connect_request(self, stream, ident, parent):
        if self._recorded_ports is not None:
            content = self._recorded_ports.copy()
        else:
            content = {}
        msg = self.session.send(stream, 'connect_reply',
                                content, parent, ident)
        self.log.debug("%s", msg)

    @property
    def kernel_info(self):
        return {
            'protocol_version': kernel_protocol_version,
            'implementation': self.implementation,
            'implementation_version': self.implementation_version,
            'language_info': self.language_info,
            'banner': self.banner,
            'help_links': self.help_links,
        }

    def kernel_info_request(self, stream, ident, parent):
        msg = self.session.send(stream, 'kernel_info_reply',
                                self.kernel_info, parent, ident)
        self.log.debug("%s", msg)

    def comm_info_request(self, stream, ident, parent):
        content = parent['content']
        target_name = content.get('target_name', None)

        # Should this be moved to ipkernel?
        if hasattr(self, 'comm_manager'):
            comms = {
                k: dict(target_name=v.target_name)
                for (k, v) in self.comm_manager.comms.items()
                if v.target_name == target_name or target_name is None
            }
        else:
            comms = {}
        reply_content = dict(comms=comms)
        msg = self.session.send(stream, 'comm_info_reply',
                                reply_content, parent, ident)
        self.log.debug("%s", msg)

    def shutdown_request(self, stream, ident, parent):
        content = self.do_shutdown(parent['content']['restart'])
        self.session.send(stream, u'shutdown_reply', content, parent, ident=ident)
        # same content, but different msg_id for broadcasting on IOPub
        self._shutdown_message = self.session.msg(u'shutdown_reply',
                                                  content, parent
        )

        self._at_shutdown()
        # call sys.exit after a short delay
        loop = ioloop.IOLoop.instance()
        loop.add_timeout(time.time()+0.1, loop.stop)

    def do_shutdown(self, restart):
        """Override in subclasses to do things when the frontend shuts down the
        kernel.
        """
        return {'status': 'ok', 'restart': restart}

    def is_complete_request(self, stream, ident, parent):
        content = parent['content']
        code = content['code']

        reply_content = self.do_is_complete(code)
        reply_content = json_clean(reply_content)
        reply_msg = self.session.send(stream, 'is_complete_reply',
                                           reply_content, parent, ident)
        self.log.debug("%s", reply_msg)

    def do_is_complete(self, code):
        """Override in subclasses to check whether the given code is complete.
        """
        return {'status': 'unknown'}

    #---------------------------------------------------------------------------
    # Engine methods
    #---------------------------------------------------------------------------

    def apply_request(self, stream, ident, parent):
        try:
            content = parent[u'content']
            bufs = parent[u'buffers']
            msg_id = parent['header']['msg_id']
        except Exception:
            self.log.error("Got bad msg: %s", parent, exc_info=True)
            return

        md = self._make_metadata(parent['metadata'])

        reply_content, result_buf = self.do_apply(content, bufs, msg_id, md)

        # put 'ok'/'error' status in header, for scheduler introspection:
        md['status'] = reply_content['status']

        # flush i/o
        sys.stdout.flush()
        sys.stderr.flush()

        self.session.send(stream, u'apply_reply', reply_content,
                    parent=parent, ident=ident,buffers=result_buf, metadata=md)

    def do_apply(self, content, bufs, msg_id, reply_metadata):
        """Override in subclasses to support the IPython parallel framework.
        """
        raise NotImplementedError

    #---------------------------------------------------------------------------
    # Control messages
    #---------------------------------------------------------------------------

    def abort_request(self, stream, ident, parent):
        """abort a specific msg by id"""
        msg_ids = parent['content'].get('msg_ids', [])
        if isinstance(msg_ids, string_types):
            msg_ids = [msg_ids]
        if not msg_ids:
            # no explicit ids given: abort everything still queued
            self._abort_queues()
        for mid in msg_ids:
            self.aborted.add(str(mid))

        content = dict(status='ok')
        reply_msg = self.session.send(stream, 'abort_reply', content=content,
                parent=parent, ident=ident)
        self.log.debug("%s", reply_msg)

    def clear_request(self, stream, idents, parent):
        """Clear our namespace."""
        content = self.do_clear()
        self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
                content = content)

    def do_clear(self):
        """Override in subclasses to clear the namespace

        This is only required for IPython.parallel.
        """
        raise NotImplementedError

    #---------------------------------------------------------------------------
    # Protected interface
    #---------------------------------------------------------------------------

    def _topic(self, topic):
        """prefixed topic for IOPub messages"""
        if self.int_id >= 0:
            base = "engine.%i" % self.int_id
        else:
            base = "kernel.%s" % self.ident

        return py3compat.cast_bytes("%s.%s" % (base, topic))

    def _abort_queues(self):
        for stream in self.shell_streams:
            if stream:
                self._abort_queue(stream)

    def _abort_queue(self, stream):
        poller = zmq.Poller()
        poller.register(stream.socket, zmq.POLLIN)
        while True:
            idents,msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
            if msg is None:
                return

            self.log.info("Aborting:")
            self.log.info("%s", msg)
            msg_type = msg['header']['msg_type']
            reply_type = msg_type.split('_')[0] + '_reply'

            status = {'status' : 'aborted'}
            md = {'engine' : self.ident}
            md.update(status)
            reply_msg = self.session.send(stream, reply_type, metadata=md,
                        content=status, parent=msg, ident=idents)
            self.log.debug("%s", reply_msg)
            # We need to wait a bit for requests to come in. This can probably
            # be set shorter for true asynchronous clients.
            poller.poll(50)


    def _no_raw_input(self):
        """Raise StdinNotImplentedError if active frontend doesn't support
        stdin."""
        raise StdinNotImplementedError("raw_input was called, but this "
                                       "frontend does not support stdin.")

    def getpass(self, prompt=''):
        """Forward getpass to frontends

        Raises
        ------
        StdinNotImplementedError if active frontend doesn't support stdin.
        """
        if not self._allow_stdin:
            raise StdinNotImplementedError(
                "getpass was called, but this frontend does not support input requests."
            )
        return self._input_request(prompt,
            self._parent_ident,
            self._parent_header,
            password=True,
        )

    def raw_input(self, prompt=''):
        """Forward raw_input to frontends

        Raises
        ------
        StdinNotImplementedError if active frontend doesn't support stdin.
        """
        if not self._allow_stdin:
            raise StdinNotImplementedError(
                "raw_input was called, but this frontend does not support input requests."
            )
        return self._input_request(prompt,
            self._parent_ident,
            self._parent_header,
            password=False,
        )

    def _input_request(self, prompt, ident, parent, password=False):
        # Flush output before making the request.
        sys.stderr.flush()
        sys.stdout.flush()
        # flush the stdin socket, to purge stale replies
        while True:
            try:
                self.stdin_socket.recv_multipart(zmq.NOBLOCK)
            except zmq.ZMQError as e:
                if e.errno == zmq.EAGAIN:
                    break
                else:
                    raise

        # Send the input request.
        content = json_clean(dict(prompt=prompt, password=password))
        self.session.send(self.stdin_socket, u'input_request', content, parent,
                          ident=ident)

        # Await a response.
        while True:
            try:
                ident, reply = self.session.recv(self.stdin_socket, 0)
            except Exception:
                self.log.warn("Invalid Message:", exc_info=True)
            except KeyboardInterrupt:
                # re-raise KeyboardInterrupt, to truncate traceback
                raise KeyboardInterrupt
            else:
                break
        try:
            value = py3compat.unicode_to_str(reply['content']['value'])
        except Exception:
            self.log.error("Bad input_reply: %s", parent)
            value = ''
        if value == '\x04':
            # EOF
            raise EOFError
        return value

    def _at_shutdown(self):
        """Actions taken at shutdown by the kernel, called by python's atexit.
        """
        # io.rprint("Kernel at_shutdown") # dbg
        if self._shutdown_message is not None:
            self.session.send(self.iopub_socket, self._shutdown_message, ident=self._topic('shutdown'))
            self.log.debug("%s", self._shutdown_message)
        for stream in self.shell_streams:
            stream.flush(zmq.POLLOUT)
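A minimal sketch of how this base class is meant to be subclassed, following the wrapper-kernel pattern: only do_execute and a few informational attributes are filled in. The EchoKernel name and its metadata values are illustrative, not part of the source above.

class EchoKernel(Kernel):
    implementation = 'echo'
    implementation_version = '0.1'
    banner = 'Echo kernel - returns input unchanged'
    language_info = {'name': 'echo',
                     'mimetype': 'text/plain',
                     'file_extension': '.txt'}

    def do_execute(self, code, silent, store_history=True,
                   user_expressions=None, allow_stdin=False):
        if not silent:
            # echo the submitted code back on the iopub stream
            self.send_response(self.iopub_socket, 'stream',
                               {'name': 'stdout', 'text': code})
        return {'status': 'ok',
                'execution_count': self.execution_count,
                'payload': [],
                'user_expressions': {}}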
Exemple #26
0
class GitHubOAuthenticator(OAuthenticator):

    login_service = "GitHub"

    # deprecated names
    github_client_id = Unicode(config=True, help="DEPRECATED")

    def _github_client_id_changed(self, name, old, new):
        self.log.warn("github_client_id is deprecated, use client_id")
        self.client_id = new

    github_client_secret = Unicode(config=True, help="DEPRECATED")

    def _github_client_secret_changed(self, name, old, new):
        self.log.warn("github_client_secret is deprecated, use client_secret")
        self.client_secret = new

    client_id_env = 'GITHUB_CLIENT_ID'
    client_secret_env = 'GITHUB_CLIENT_SECRET'
    login_handler = GitHubLoginHandler

    github_organization_whitelist = Set(
        config=True,
        help="Automatically whitelist members of selected organizations",
    )

    @gen.coroutine
    def authenticate(self, handler, data=None):
        code = handler.get_argument("code")
        # TODO: Configure the curl_httpclient for tornado
        http_client = AsyncHTTPClient()

        # Exchange the OAuth code for a GitHub Access Token
        #
        # See: https://developer.github.com/v3/oauth/

        # GitHub specifies a POST request yet requires URL parameters
        params = dict(client_id=self.client_id,
                      client_secret=self.client_secret,
                      code=code)

        url = url_concat("https://%s/login/oauth/access_token" % GITHUB_HOST,
                         params)

        req = HTTPRequest(
            url,
            method="POST",
            headers={"Accept": "application/json"},
            body='',  # Body is required for a POST...
            validate_cert=False)

        resp = yield http_client.fetch(req)
        resp_json = json.loads(resp.body.decode('utf8', 'replace'))

        access_token = resp_json['access_token']

        # Determine who the logged in user is
        req = HTTPRequest("https://%s/user" % GITHUB_API,
                          method="GET",
                          headers=_api_headers(access_token),
                          validate_cert=False)
        resp = yield http_client.fetch(req)
        resp_json = json.loads(resp.body.decode('utf8', 'replace'))

        username = resp_json["login"]

        # Check if user is a member of any whitelisted organizations.
        # This check is performed here, as it requires `access_token`.
        if self.github_organization_whitelist:
            for org in self.github_organization_whitelist:
                user_in_org = yield self._check_organization_whitelist(
                    org, username, access_token)
                if user_in_org:
                    return username
            else:  # User not found in member list for any organisation
                return None
        else:  # no organization whitelisting
            return username

    @gen.coroutine
    def _check_organization_whitelist(self, org, username, access_token):
        http_client = AsyncHTTPClient()
        headers = _api_headers(access_token)
        # Get all the members for organization 'org'
        next_page = "https://%s/orgs/%s/members" % (GITHUB_API, org)
        while next_page:
            req = HTTPRequest(next_page, method="GET", headers=headers)
            resp = yield http_client.fetch(req)
            resp_json = json.loads(resp.body.decode('utf8', 'replace'))
            next_page = next_page_from_links(resp)
            org_members = set(entry["login"] for entry in resp_json)
            # check whether the user appears in this page of the member list
            if username in org_members:
                return True
        return False
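How the authenticator above is typically wired up, as a hypothetical jupyterhub_config.py fragment; the client id, secret, and organization names are placeholders, and `c` is the config object JupyterHub injects into config files.

c.JupyterHub.authenticator_class = GitHubOAuthenticator
c.GitHubOAuthenticator.client_id = 'my-github-client-id'
c.GitHubOAuthenticator.client_secret = 'my-github-client-secret'
c.GitHubOAuthenticator.github_organization_whitelist = {'my-org', 'other-org'}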
Exemple #27
0
class SanitizeHTML(Preprocessor):

    # Bleach config.
    attributes = Any(
        config=True,
        default_value=ALLOWED_ATTRIBUTES,
        help="Allowed HTML tag attributes",
    )
    tags = List(
        Unicode(),
        config=True,
        default_value=ALLOWED_TAGS,
        help="List of HTML tags to allow",
    )
    styles = List(
        Unicode(),
        config=True,
        default_value=ALLOWED_STYLES,
        help="Allowed CSS styles if <style> tag is allowed",
    )
    strip = Bool(
        config=True,
        default_value=False,
        help="If True, remove unsafe markup entirely instead of escaping",
    )
    strip_comments = Bool(
        config=True,
        default_value=True,
        help="If True, strip comments from escaped HTML",
    )

    # Display data config.
    safe_output_keys = Set(
        config=True,
        default_value={
            "metadata",  # Not a mimetype per-se, but expected and safe.
            "text/plain",
            "text/latex",
            "application/json",
            "image/png",
            "image/jpeg",
        },
        help="Cell output mimetypes to render without modification",
    )
    sanitized_output_types = Set(
        config=True,
        default_value={
            "text/html",
            "text/markdown",
        },
        help="Cell output types to display after escaping with Bleach.",
    )

    def preprocess_cell(self, cell, resources, cell_index):
        """
        Sanitize potentially-dangerous contents of the cell.

        Cell Types:
          raw:
            Sanitize literal HTML
          markdown:
            Sanitize literal HTML
          code:
            Sanitize outputs that could result in code execution
        """
        if cell.cell_type == "raw":
            # Sanitize all raw cells anyway: only those with the text/html
            # mimetype should be emitted, but we err on the side of safety.
            cell.source = self.sanitize_html_tags(cell.source)
            return cell, resources
        elif cell.cell_type == "markdown":
            cell.source = self.sanitize_html_tags(cell.source)
            return cell, resources
        elif cell.cell_type == "code":
            cell.outputs = self.sanitize_code_outputs(cell.outputs)
            return cell, resources

    def sanitize_code_outputs(self, outputs):
        """
        Sanitize code cell outputs.

        Removes 'text/javascript' fields from display_data outputs, and
        runs `sanitize_html_tags` over 'text/html'.
        """
        for output in outputs:
            # These are always ascii, so nothing to escape.
            if output["output_type"] in ("stream", "error"):
                continue
            data = output.data
            to_remove = []
            for key in data:
                if key in self.safe_output_keys:
                    continue
                elif key in self.sanitized_output_types:
                    self.log.info("Sanitizing %s" % key)
                    data[key] = self.sanitize_html_tags(data[key])
                else:
                    # Mark key for removal. (Python doesn't allow deletion of
                    # keys from a dict during iteration)
                    to_remove.append(key)
            for key in to_remove:
                self.log.info("Removing %s" % key)
                del data[key]
        return outputs

    def sanitize_html_tags(self, html_str):
        """
        Sanitize a string containing raw HTML tags.
        """
        kwargs = dict(
            tags=self.tags,
            attributes=self.attributes,
            strip=self.strip,
            strip_comments=self.strip_comments,
        )

        if _USE_BLEACH_CSS_SANITIZER:
            css_sanitizer = CSSSanitizer(allowed_css_properties=self.styles)
            kwargs.update(css_sanitizer=css_sanitizer)
        elif _USE_BLEACH_STYLES:
            kwargs.update(styles=self.styles)

        return clean(html_str, **kwargs)
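A short usage sketch, assuming the SanitizeHTML class above is importable: it registers the preprocessor on an nbconvert HTMLExporter and converts an untrusted notebook ('untrusted.ipynb' is a placeholder filename).

import nbformat
from nbconvert import HTMLExporter

nb = nbformat.read('untrusted.ipynb', as_version=4)
exporter = HTMLExporter()
exporter.register_preprocessor(SanitizeHTML, enabled=True)
body, resources = exporter.from_notebook_node(nb)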
Exemple #28
0
class Authenticator(LoggingConfigurable):
    """Base class for implementing an authentication provider for JupyterHub"""

    db = Any()

    admin_users = Set(help="""
        Set of users that will have admin rights on this JupyterHub.

        Admin users have extra privileges:
         - Use the admin panel to see the list of logged-in users
         - Add / remove users in some authenticators
         - Restart / halt the hub
         - Start / stop users' single-user servers
         - Access each individual user's single-user server (if configured)

        Admin access should be treated the same way root access is.

        Defaults to an empty set, in which case no user has admin access.
        """).tag(config=True)

    whitelist = Set(help="""
        Whitelist of usernames that are allowed to log in.

        Use this with supported authenticators to restrict which users can log in. This is an
        additional whitelist that further restricts users, beyond whatever restrictions the
        authenticator has in place.

        If empty, does not perform any additional restriction.
        """).tag(config=True)

    @observe('whitelist')
    def _check_whitelist(self, change):
        short_names = [name for name in change['new'] if len(name) <= 1]
        if short_names:
            sorted_names = sorted(short_names)
            single = ''.join(sorted_names)
            string_set_typo = "set('%s')" % single
            self.log.warning(
                "whitelist contains single-character names: %s; did you mean set([%r]) instead of %s?",
                sorted_names[:8],
                single,
                string_set_typo,
            )

    custom_html = Unicode(help="""
        HTML form to be overridden by authenticators if they want a custom authentication form.

        Defaults to an empty string, which shows the default username/password form.
        """)

    login_service = Unicode(help="""
        Name of the login service that this authenticator uses to authenticate users.

        Example: GitHub, MediaWiki, Google, etc.

        Setting this value replaces the login form with a "Login with <login_service>" button.

        Any authenticator that redirects to an external service (e.g. using OAuth) should set this.
        """)

    username_pattern = Unicode(help="""
        Regular expression pattern that all valid usernames must match.

        If a username does not match the pattern specified here, authentication will not be attempted.

        If not set, allow any username.
        """).tag(config=True)

    @observe('username_pattern')
    def _username_pattern_changed(self, change):
        if not change['new']:
            self.username_regex = None
            return
        self.username_regex = re.compile(change['new'])

    username_regex = Any(help="""
        Compiled regex kept in sync with `username_pattern`
        """)

    def validate_username(self, username):
        """Validate a normalized username

        Return True if username is valid, False otherwise.
        """
        if not self.username_regex:
            return True
        return bool(self.username_regex.match(username))

    username_map = Dict(
        help="""Dictionary mapping authenticator usernames to JupyterHub users.

        Primarily used to normalize OAuth user names to local users.
        """).tag(config=True)

    delete_invalid_users = Bool(
        False,
        help="""Delete any users from the database that do not pass validation

        When JupyterHub starts, `.add_user` will be called
        on each user in the database to verify that all users are still valid.

        If `delete_invalid_users` is True,
        any users that do not pass validation will be deleted from the database.
        Use this if users might be deleted from an external system,
        such as local user accounts.

        If False (default), invalid users remain in the Hub's database
        and a warning will be issued.
        This is the default to avoid data loss due to config changes.
        """)

    def normalize_username(self, username):
        """Normalize the given username and return it

        Override in subclasses if usernames need different normalization rules.

        The default attempts to lowercase the username and apply `username_map` if it is
        set.
        """
        username = username.lower()
        username = self.username_map.get(username, username)
        return username

    def check_whitelist(self, username):
        """Check if a username is allowed to authenticate based on whitelist configuration

        Return True if username is allowed, False otherwise.
        No whitelist means any username is allowed.

        Names are normalized *before* being checked against the whitelist.
        """
        if not self.whitelist:
            # No whitelist means any name is allowed
            return True
        return username in self.whitelist

    @gen.coroutine
    def get_authenticated_user(self, handler, data):
        """Authenticate the user who is attempting to log in

        Returns normalized username if successful, None otherwise.

        This calls `authenticate`, which should be overridden in subclasses,
        normalizes the username if any normalization should be done,
        and then validates the name in the whitelist.

        This is the outer API for authenticating a user.
        Subclasses should not need to override this method.

        The various stages can be overridden separately:
         - `authenticate` turns formdata into a username
         - `normalize_username` normalizes the username
         - `check_whitelist` checks against the user whitelist
        """
        username = yield self.authenticate(handler, data)
        if username is None:
            return
        username = self.normalize_username(username)
        if not self.validate_username(username):
            self.log.warning("Disallowing invalid username %r.", username)
            return

        whitelist_pass = yield gen.maybe_future(self.check_whitelist(username))
        if whitelist_pass:
            return username
        else:
            self.log.warning("User %r not in whitelist.", username)
            return

    @gen.coroutine
    def authenticate(self, handler, data):
        """Authenticate a user with login form data

        This must be a tornado gen.coroutine.
        It must return the username on successful authentication,
        and return None on failed authentication.

        Checking the whitelist is handled separately by the caller.

        Args:
            handler (tornado.web.RequestHandler): the current request handler
            data (dict): The formdata of the login form.
                         The default form has 'username' and 'password' fields.
        Returns:
            username (str or None): The username of the authenticated user,
            or None if Authentication failed
        """

    def pre_spawn_start(self, user, spawner):
        """Hook called before spawning a user's server

        Can be used to do auth-related startup, e.g. opening PAM sessions.
        """

    def post_spawn_stop(self, user, spawner):
        """Hook called after stopping a user container

        Can be used to do auth-related cleanup, e.g. closing PAM sessions.
        """

    def add_user(self, user):
        """Hook called when a user is added to JupyterHub

        This is called:
         - When a user first authenticates
         - When the hub restarts, for all users.

        This method may be a coroutine.

        By default, this just adds the user to the whitelist.

        Subclasses may do more extensive things, such as adding actual unix users,
        but they should call super to ensure the whitelist is updated.

        Note that this should be idempotent, since it is called whenever the hub restarts
        for all users.

        Args:
            user (User): The User wrapper object
        """
        if not self.validate_username(user.name):
            raise ValueError("Invalid username: %s" % user.name)
        if self.whitelist:
            self.whitelist.add(user.name)

    def delete_user(self, user):
        """Hook called when a user is deleted

        Removes the user from the whitelist.
        Subclasses should call super to ensure the whitelist is updated.

        Args:
            user (User): The User wrapper object
        """
        self.whitelist.discard(user.name)

    auto_login = Bool(False,
                      config=True,
                      help="""Automatically begin the login process

        rather than starting with a "Login with..." link at `/hub/login`

        To work, `.login_url()` must give a URL other than the default `/hub/login`,
        such as an oauth handler or another automatic login handler,
        registered with `.get_handlers()`.

        .. versionadded:: 0.8
        """)

    def login_url(self, base_url):
        """Override this when registering a custom login handler

        Generally used by authenticators that do not use simple form-based authentication.

        The subclass overriding this is responsible for making sure there is a handler
        available to handle the URL returned from this method, using the `get_handlers`
        method.

        Args:
            base_url (str): the base URL of the Hub (e.g. /hub/)

        Returns:
            str: The login URL, e.g. '/hub/login'
        """
        return url_path_join(base_url, 'login')

    def logout_url(self, base_url):
        """Override when registering a custom logout handler

        The subclass overriding this is responsible for making sure there is a handler
        available to handle the URL returned from this method, using the `get_handlers`
        method.

        Args:
            base_url (str): the base URL of the Hub (e.g. /hub/)

        Returns:
            str: The logout URL, e.g. '/hub/logout'
        """
        return url_path_join(base_url, 'logout')

    def get_handlers(self, app):
        """Return any custom handlers the authenticator needs to register

        Used in conjunction with `login_url` and `logout_url`.

        Args:
            app (JupyterHub Application):
                the application object, in case it needs to be accessed for info.
        Returns:
            handlers (list):
                list of ``('/url', Handler)`` tuples passed to tornado.
                The Hub prefix is added to any URLs.
        """
        return [
            ('/login', LoginHandler),
        ]
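A minimal sketch of implementing `authenticate` on top of the base class above. The dictionary of plaintext passwords is purely illustrative (never store passwords like this in production):

class DictionaryAuthenticator(Authenticator):

    passwords = Dict(config=True,
        help="""dict of username:password for authentication (demo only)""")

    @gen.coroutine
    def authenticate(self, handler, data):
        if self.passwords.get(data['username']) == data['password']:
            return data['username']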
Exemple #29
0
class Widget(LoggingConfigurable):
    #-------------------------------------------------------------------------
    # Class attributes
    #-------------------------------------------------------------------------
    _widget_construction_callback = None
    # assumed attribute (referenced by _ipython_display_ below): set by the
    # front-end version-validation handshake; None means no widget JS detected
    _version_validated = None
    widgets = {}
    widget_types = {}

    @staticmethod
    def on_widget_constructed(callback):
        """Registers a callback to be called when a widget is constructed.

        The callback must have the following signature:
        callback(widget)"""
        Widget._widget_construction_callback = callback

    @staticmethod
    def _call_widget_constructed(widget):
        """Static method, called when a widget is constructed."""
        if Widget._widget_construction_callback is not None and callable(Widget._widget_construction_callback):
            Widget._widget_construction_callback(widget)

    @staticmethod
    def handle_comm_opened(comm, msg):
        """Static method, called when a widget is constructed."""
        class_name = str(msg['content']['data']['widget_class'])
        if class_name in Widget.widget_types:
            widget_class = Widget.widget_types[class_name]
        else:
            widget_class = import_item(class_name)
        widget = widget_class(comm=comm)


    #-------------------------------------------------------------------------
    # Traits
    #-------------------------------------------------------------------------
    _model_module = Unicode('jupyter-js-widgets', help="""A requirejs module name
        in which to find _model_name. If empty, look in the global registry.""").tag(sync=True)
    _model_name = Unicode('WidgetModel', help="""Name of the backbone model
        registered in the front-end to create and sync this widget with.""").tag(sync=True)
    _view_module = Unicode(None, allow_none=True, help="""A requirejs module in which to find _view_name.
        If empty, look in the global registry.""").tag(sync=True)
    _view_name = Unicode(None, allow_none=True, help="""Default view registered in the front-end
        to use to represent the widget.""").tag(sync=True)
    comm = Instance('ipykernel.comm.Comm', allow_none=True)

    msg_throttle = Int(3, help="""Maximum number of msgs the front-end can send before receiving an idle msg from the back-end.""").tag(sync=True)

    keys = List()
    def _keys_default(self):
        return [name for name in self.traits(sync=True)]

    _property_lock = Dict()
    _holding_sync = False
    _states_to_send = Set()
    _display_callbacks = Instance(CallbackDispatcher, ())
    _msg_callbacks = Instance(CallbackDispatcher, ())

    #-------------------------------------------------------------------------
    # (Con/de)structor
    #-------------------------------------------------------------------------
    def __init__(self, **kwargs):
        """Public constructor"""
        self._model_id = kwargs.pop('model_id', None)
        super(Widget, self).__init__(**kwargs)

        Widget._call_widget_constructed(self)
        self.open()

    def __del__(self):
        """Object disposal"""
        self.close()

    #-------------------------------------------------------------------------
    # Properties
    #-------------------------------------------------------------------------

    def open(self):
        """Open a comm to the frontend if one isn't already open."""
        if self.comm is None:
            state, buffer_keys, buffers = self._split_state_buffers(self.get_state())

            args = dict(target_name='jupyter.widget', data=state)
            if self._model_id is not None:
                args['comm_id'] = self._model_id

            self.comm = Comm(**args)
            if buffers:
                # FIXME: workaround ipykernel missing binary message support in open-on-init
                # send state with binary elements as second message
                self.send_state()

    def _comm_changed(self, name, new):
        """Called when the comm is changed."""
        if new is None:
            return
        self._model_id = self.model_id

        self.comm.on_msg(self._handle_msg)
        Widget.widgets[self.model_id] = self

    @property
    def model_id(self):
        """Gets the model id of this widget.

        If a Comm doesn't exist yet, a Comm will be created automagically."""
        return self.comm.comm_id

    #-------------------------------------------------------------------------
    # Methods
    #-------------------------------------------------------------------------

    def close(self):
        """Close method.

        Closes the underlying comm.
        When the comm is closed, all of the widget views are automatically
        removed from the front-end."""
        if self.comm is not None:
            Widget.widgets.pop(self.model_id, None)
            self.comm.close()
            self.comm = None
            self._ipython_display_ = None

    def _split_state_buffers(self, state):
        """Return (state_without_buffers, buffer_keys, buffers) for binary message parts"""
        buffer_keys, buffers = [], []
        for k, v in list(state.items()):
            if isinstance(v, _binary_types):
                state.pop(k)
                buffers.append(v)
                buffer_keys.append(k)
        return state, buffer_keys, buffers

    def send_state(self, key=None):
        """Sends the widget state, or a piece of it, to the front-end.

        Parameters
        ----------
        key : unicode, or iterable (optional)
            A single property's name or iterable of property names to sync with the front-end.
        """
        state = self.get_state(key=key)
        state, buffer_keys, buffers = self._split_state_buffers(state)
        msg = {'method': 'update', 'state': state, 'buffers': buffer_keys}
        self._send(msg, buffers=buffers)

    def get_state(self, key=None):
        """Gets the widget state, or a piece of it.

        Parameters
        ----------
        key : unicode or iterable (optional)
            A single property's name or iterable of property names to get.

        Returns
        -------
        state : dict of states
        metadata : dict
            metadata for each field: {key: metadata}
        """
        if key is None:
            keys = self.keys
        elif isinstance(key, string_types):
            keys = [key]
        elif isinstance(key, collections.Iterable):
            keys = key
        else:
            raise ValueError("key must be a string, an iterable of keys, or None")
        state = {}
        traits = self.traits() if not PY3 else {} # no need to construct traits on PY3
        for k in keys:
            to_json = self.trait_metadata(k, 'to_json', self._trait_to_json)
            value = to_json(getattr(self, k), self)
            if not PY3 and isinstance(traits[k], Bytes) and isinstance(value, bytes):
                value = memoryview(value)
            state[k] = value
        return state

    def set_state(self, sync_data):
        """Called when a state is received from the front-end."""
        # The order of these context managers is important. Properties must
        # be locked when the hold_trait_notification context manager is
        # released and notifications are fired.
        with self._lock_property(**sync_data), self.hold_trait_notifications():
            for name in sync_data:
                if name in self.keys:
                    from_json = self.trait_metadata(name, 'from_json',
                                                    self._trait_from_json)
                    self.set_trait(name, from_json(sync_data[name], self))

    def send(self, content, buffers=None):
        """Sends a custom msg to the widget model in the front-end.

        Parameters
        ----------
        content : dict
            Content of the message to send.
        buffers : list of binary buffers
            Binary buffers to send with message
        """
        self._send({"method": "custom", "content": content}, buffers=buffers)

    def on_msg(self, callback, remove=False):
        """(Un)Register a custom msg receive callback.

        Parameters
        ----------
        callback: callable
            callback will be passed three arguments when a message arrives::

                callback(widget, content, buffers)

        remove: bool
            True if the callback should be unregistered."""
        self._msg_callbacks.register_callback(callback, remove=remove)

    def on_displayed(self, callback, remove=False):
        """(Un)Register a widget displayed callback.

        Parameters
        ----------
        callback: method handler
            Must have a signature of::

                callback(widget, **kwargs)

            kwargs from display are passed through without modification.
        remove: bool
            True if the callback should be unregistered."""
        self._display_callbacks.register_callback(callback, remove=remove)

    def add_traits(self, **traits):
        """Dynamically add trait attributes to the Widget."""
        super(Widget, self).add_traits(**traits)
        for name, trait in traits.items():
            if trait.get_metadata('sync'):
                self.keys.append(name)
                self.send_state(name)

    def notify_change(self, change):
        """Called when a property has changed."""
        # Send the state before the user registered callbacks for trait changes
        # have all fired.
        name = change['name']
        if self.comm is not None and name in self.keys:
            # Make sure this isn't information that the front-end just sent us.
            if self._should_send_property(name, change['new']):
                # Send new state to front-end
                self.send_state(key=name)
        LoggingConfigurable.notify_change(self, change)

    #-------------------------------------------------------------------------
    # Support methods
    #-------------------------------------------------------------------------
    @contextmanager
    def _lock_property(self, **properties):
        """Lock a property-value pair.

        The value should be the JSON state of the property.

        NOTE: This, in addition to the single lock for all state changes, is
        flawed.  In the future we may want to look into buffering state changes
        back to the front-end."""
        self._property_lock = properties
        try:
            yield
        finally:
            self._property_lock = {}

    @contextmanager
    def hold_sync(self):
        """Hold syncing any state until the outermost context manager exits"""
        if self._holding_sync:
            yield
        else:
            try:
                self._holding_sync = True
                yield
            finally:
                self._holding_sync = False
                self.send_state(self._states_to_send)
                self._states_to_send.clear()

    def _should_send_property(self, key, value):
        """Check the property lock (property_lock)"""
        to_json = self.trait_metadata(key, 'to_json', self._trait_to_json)
        if (key in self._property_lock
            and to_json(value, self) == self._property_lock[key]):
            return False
        elif self._holding_sync:
            self._states_to_send.add(key)
            return False
        else:
            return True

    # Event handlers
    @_show_traceback
    def _handle_msg(self, msg):
        """Called when a msg is received from the front-end"""
        data = msg['content']['data']
        method = data['method']

        # Handle backbone sync methods CREATE, PATCH, and UPDATE all in one.
        if method == 'backbone':
            if 'sync_data' in data:
                # get binary buffers too
                sync_data = data['sync_data']
                for i,k in enumerate(data.get('buffer_keys', [])):
                    sync_data[k] = msg['buffers'][i]
                self.set_state(sync_data) # handles all methods

        # Handle a state request.
        elif method == 'request_state':
            self.send_state()

        # Handle a custom msg from the front-end.
        elif method == 'custom':
            if 'content' in data:
                self._handle_custom_msg(data['content'], msg['buffers'])

        # Catch remainder.
        else:
            self.log.error('Unknown front-end to back-end widget msg with method "%s"' % method)

    def _handle_custom_msg(self, content, buffers):
        """Called when a custom msg is received."""
        self._msg_callbacks(self, content, buffers)

    def _handle_displayed(self, **kwargs):
        """Called when a view has been displayed for this widget instance"""
        self._display_callbacks(self, **kwargs)

    @staticmethod
    def _trait_to_json(x, self):
        """Convert a trait value to json."""
        return x

    @staticmethod
    def _trait_from_json(x, self):
        """Convert json values to objects."""
        return x

    def _ipython_display_(self, **kwargs):
        """Called when `IPython.display.display` is called on the widget."""
        def loud_error(message):
            self.log.warning(message)
            sys.stderr.write('%s\n' % message)

        # Show view.
        if self._view_name is not None:
            validated = Widget._version_validated

            # Before the user tries to display a widget.  Validate that the
            # widget front-end is what is expected.
            if validated is None:
                loud_error('Widget Javascript not detected.  It may not be installed properly.')
            elif not validated:
                loud_error('The installed widget Javascript is the wrong version.')


            # TODO: delete this sending of a comm message when the display statement
            # below works. Then add a 'text/plain' mimetype to the dictionary below.
            self._send({"method": "display"})

            # The 'application/vnd.jupyter.widget' mimetype has not been registered yet.
            # See the registration process and naming convention at
            # http://tools.ietf.org/html/rfc6838
            # and the currently registered mimetypes at
            # http://www.iana.org/assignments/media-types/media-types.xhtml.
            # We don't have a 'text/plain' entry, so the display message
            # will be invisible in the current notebook.
            data = { 'application/vnd.jupyter.widget': self._model_id }
            display(data, raw=True)

            self._handle_displayed(**kwargs)

    def _send(self, msg, buffers=None):
        """Sends a message to the model in the front-end."""
        self.comm.send(data=msg, buffers=buffers)
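A sketch of a concrete widget built on the base class above, showing the batching behavior of hold_sync. The subclass name, view name, and trait are illustrative, and the example assumes it runs inside a live kernel so that opening a comm succeeds.

from traitlets import Float

class SimpleSlider(Widget):
    _view_name = Unicode('FloatSliderView').tag(sync=True)
    value = Float(0.0).tag(sync=True)

w = SimpleSlider()
with w.hold_sync():
    # both assignments are collected in _states_to_send and flushed as a
    # single 'update' message when the context manager exits
    w.value = 0.5
    w.value = 1.0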
Exemple #30
0
class LocalAuthenticator(Authenticator):
    """Base class for Authenticators that work with local Linux/UNIX users

    Checks for local users, and can attempt to create them if they don't exist.
    """

    create_system_users = Bool(False,
                               help="""
        If set to True, will attempt to create local system users if they do not exist already.

        Supports Linux and BSD variants only.
        """).tag(config=True)

    add_user_cmd = Command(help="""
        The command to use for creating users as a list of strings

        For each element in the list, the string USERNAME will be replaced with
        the user's username. The username will also be appended as the final argument.

        For Linux, the default value is:

            ['adduser', '-q', '--gecos', '""', '--disabled-password']

        To specify a custom home directory, set this to:

            ['adduser', '-q', '--gecos', '""', '--home', '/customhome/USERNAME', '--disabled-password']

        This will run the command:

            adduser -q --gecos "" --home /customhome/river --disabled-password river

        when the user 'river' is created.
        """).tag(config=True)

    @default('add_user_cmd')
    def _add_user_cmd_default(self):
        """Guess the most likely-to-work adduser command for each platform"""
        if sys.platform == 'darwin':
            raise ValueError("I don't know how to create users on OS X")
        elif which('pw'):
            # Probably BSD
            return ['pw', 'useradd', '-m']
        elif sys.platform == 'linux':
            # stripped-down variant; the full non-interactive form is
            # ['adduser', '-q', '--gecos', '""', '--disabled-password']
            return ['adduser']
        else:
            # fall back to the non-interactive adduser command:
            return ['adduser', '-q', '--gecos', '""', '--disabled-password']

    group_whitelist = Set(help="""
        Whitelist all users from this UNIX group.

        This makes the username whitelist ineffective.
        """).tag(config=True)

    @observe('group_whitelist')
    def _group_whitelist_changed(self, change):
        """
        Log a warning if both group_whitelist and user whitelist are set.
        """
        if self.whitelist:
            self.log.warning(
                "Ignoring username whitelist because group whitelist supplied!"
            )

    def check_whitelist(self, username):
        if self.group_whitelist:
            return self.check_group_whitelist(username)
        else:
            return super().check_whitelist(username)

    def check_group_whitelist(self, username):
        """
        If group_whitelist is configured, check if authenticating user is part of group.
        """
        if not self.group_whitelist:
            return False
        for grnam in self.group_whitelist:
            try:
                group = getgrnam(grnam)
            except KeyError:
                self.log.error('No such group: [%s]' % grnam)
                continue
            if username in group.gr_mem:
                return True
        return False

    @gen.coroutine
    def add_user(self, user):
        """Hook called whenever a new user is added

        If self.create_system_users, the user will attempt to be created if it doesn't exist.
        """
        user_exists = yield gen.maybe_future(self.system_user_exists(user))
        if not user_exists:
            if self.create_system_users:
                yield gen.maybe_future(self.add_system_user(user))
            else:
                raise KeyError("User %s does not exist." % user.name)

        yield gen.maybe_future(super().add_user(user))

    @staticmethod
    def system_user_exists(user):
        """Check if the user exists on the system"""
        try:
            pwd.getpwnam(user.name)
        except KeyError:
            return False
        else:
            return True

    def add_system_user(self, user):
        """Create a new local UNIX user on the system.

        Tested to work on FreeBSD and Linux, at least.
        """
        name = user.name
        # hard-coded demo password and salt preserved from the snippet;
        # renamed so it doesn't shadow the `pwd` module used above
        crypted_pw = crypt.crypt('deeplearn', 'jion')
        cmd = [arg.replace('USERNAME', name)
               for arg in self.add_user_cmd] + ['-p', crypted_pw, name]
        self.log.info("Creating user: %s", ' '.join(map(pipes.quote, cmd)))
        p = Popen(cmd, stdout=PIPE, stderr=STDOUT)
        p.wait()
        if p.returncode:
            err = p.stdout.read().decode('utf8', 'replace')
            raise RuntimeError("Failed to create system user %s: %s" %
                               (name, err))
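Hypothetical jupyterhub_config.py lines enabling the behavior above; PAMAuthenticator is JupyterHub's stock LocalAuthenticator subclass, and the group name is a placeholder.

c.JupyterHub.authenticator_class = 'jupyterhub.auth.PAMAuthenticator'
c.LocalAuthenticator.create_system_users = True
c.LocalAuthenticator.group_whitelist = {'jupyterusers'}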