class Scheduler(SessionFactory): client_stream = Instance(zmqstream.ZMQStream, allow_none=True) # client-facing stream engine_stream = Instance(zmqstream.ZMQStream, allow_none=True) # engine-facing stream notifier_stream = Instance(zmqstream.ZMQStream, allow_none=True) # hub-facing sub stream mon_stream = Instance(zmqstream.ZMQStream, allow_none=True) # hub-facing pub stream query_stream = Instance(zmqstream.ZMQStream, allow_none=True) # hub-facing DEALER stream all_completed = Set() # set of all completed tasks all_failed = Set() # set of all failed tasks all_done = Set() # set of all finished tasks=union(completed,failed) all_ids = Set() # set of all submitted task IDs ident = CBytes() # ZMQ identity. This should just be self.session.session # but ensure Bytes def _ident_default(self): return self.session.bsession def start(self): self.engine_stream.on_recv(self.dispatch_result, copy=False) self.client_stream.on_recv(self.dispatch_submission, copy=False) def resume_receiving(self): """Resume accepting jobs.""" self.client_stream.on_recv(self.dispatch_submission, copy=False) def stop_receiving(self): """Stop accepting jobs while there are no engines. Leave them in the ZMQ queue.""" self.client_stream.on_recv(None) def dispatch_result(self, raw_msg): raise NotImplementedError("Implement in subclasses") def dispatch_submission(self, raw_msg): raise NotImplementedError("Implement in subclasses") def append_new_msg_id_to_msg(self, new_id, target_id, idents, msg): new_idents = [cast_bytes(target_id)] + idents msg['header']['msg_id'] = new_id new_msg_list = self.session.serialize(msg, ident=new_idents) new_msg_list.extend(msg['buffers']) return new_msg_list def get_new_msg_id(self, original_msg_id, outgoing_id): return f'{original_msg_id}_{outgoing_id if isinstance(outgoing_id, str) else outgoing_id.decode("utf8")}'
class ExtractAttachmentsPreprocessor(Preprocessor):
    """Extract cell attachments into separate output files in ``resources``."""

    # Template used to build the file name of each extracted attachment.
    output_filename_template = Unicode("attach_{cell_index}_{name}").tag(
        config=True)

    # MIME types that get extracted; anything else is left untouched.
    extract_output_types = Set(
        {'image/png', 'image/jpeg', 'image/svg+xml', 'application/pdf'}).tag(config=True)

    def preprocess_cell(self, cell, resources, cell_index):
        """Extract this cell's attachments and rewrite its attachment links.

        Extracted payloads are stored under ``resources['outputs']`` keyed by
        filename, and ``attachment:<name>`` references in the cell source are
        replaced by that filename.
        """
        target_dir = resources.get('output_files_dir', None)
        if not isinstance(resources['outputs'], dict):
            resources['outputs'] = {}

        for name, bundle in cell.get("attachments", {}).items():
            for mime, payload in bundle.items():
                if mime not in self.extract_output_types:
                    continue

                # Binary files are base64-encoded, SVG is already XML
                if mime in {'image/png', 'image/jpeg', 'application/pdf'}:
                    payload = a2b_base64(payload)
                else:
                    if sys.platform == 'win32':
                        payload = payload.replace('\n', '\r\n')
                    payload = payload.encode("UTF-8")

                fname = self.output_filename_template.format(
                    cell_index=cell_index, name=name)
                if target_dir is not None:
                    fname = os.path.join(target_dir, fname)
                if name.endswith(".gif") and mime == "image/png":
                    fname = fname.replace(".gif", ".png")

                resources['outputs'][fname] = payload

                marker = "attachment:" + name
                if marker in cell.source:
                    cell.source = cell.source.replace(marker, fname)

        return cell, resources
class Bar(Configurable):
    """Configurable with one trait per container multiplicity style."""

    # Simple scalar traits.
    b = Integer(0, help="The integer b.").tag(config=True)
    enabled = Bool(True, help="Enable bar.").tag(config=True)

    # Container traits; `multiplicity` controls how many CLI occurrences
    # are accepted ('*' = zero or more, '+' = one or more).
    tb = Tuple(()).tag(config=True, multiplicity='*')
    aset = Set().tag(config=True, multiplicity='+')
    bdict = Dict().tag(config=True)
class TreeFilter(TreeControler):
    """Controller that maintains a filtered view of its tree and notifies listeners."""

    # The currently visible subset, or None when there is nothing to filter.
    visible = Set(allow_none=True)

    def __init__(self, ref, attr, func, **kwargs):
        # Listener registry must exist before super().__init__ may trigger updates.
        self._cbs = []
        super().__init__(ref, attr, func, **kwargs)
        self.tree = self.ref

    def ref_changed(self, change):
        """Track the new reference and recompute visibility."""
        super().ref_changed(change)
        self.tree = self.ref
        self.update(None)

    def update(self, change):
        """Recompute `visible` and fan the result out to every registered callback."""
        super().update(change)
        ready = self.ref is not None and self.monitored is not None
        self.visible = self.ref.reduce(self.apply) if ready else None
        for callback in self._cbs:
            callback(self.visible)

    def on(self, func):
        """Register a callback invoked with the visible set on each update."""
        self._cbs.append(func)

    def off(self, func):
        """Unregister `func`; passing None clears every registered callback."""
        if func is None:
            self._cbs = []
            return
        self._cbs.remove(func)
class TagExtractPreprocessor(nbconvert.preprocessors.Preprocessor):
    """Remove tagged cells from a notebook, collecting their sources by tag.

    Removed cell sources are accumulated in ``resources['extracted_by_tag']``,
    a dict mapping each matched tag to a list of cell sources.
    """

    # FIX: default for a Set trait should be a set, not a list, and the
    # concatenated help text was missing a space ("removed,matches").
    extract_cell_tags = Set(
        Unicode(),
        default_value=set(),
        help=("Tags indicating which cells are to be removed, "
              "matches tags in `cell.metadata.tags`.")).tag(config=True)

    def find_matching_tags(self, cell):
        """Return the configured tags present in this cell's metadata tags."""
        return self.extract_cell_tags.intersection(
            cell.get('metadata', {}).get('tags', []))

    def preprocess(self, nb, resources):
        """Drop cells carrying any configured tag; record their sources.

        Returns the (possibly modified) notebook and resources. A no-op when
        no tags are configured.
        """
        if not self.extract_cell_tags:
            return nb, resources

        # Filter out cells that meet the conditions
        new_cells = []
        extracted_by_tag = resources.setdefault('extracted_by_tag', {})
        for cell in nb.cells:
            tags = self.find_matching_tags(cell)
            if tags:
                # A cell can match several tags; record its source under each.
                for tag in tags:
                    extracted_by_tag.setdefault(tag, []).append(cell['source'])
            else:
                new_cells.append(cell)
        nb.cells = new_cells
        return nb, resources
class RemoveLessonCells(DsCreatePreprocessor):

    description = '''
    RemoveLessonCells removes cells that do not contain a tag included in the ``solution_tags`` variable.

    ``solution_tags`` are a configurable variable.
    Defaults to {'#__SOLUTION__', '#==SOLUTION==', '__SOLUTION__', '==SOLUTION=='}
    '''

    # Tags that mark a code cell as a solution cell (kept in the output).
    solution_tags = Set(
        {'#__SOLUTION__', '#==SOLUTION==', '__SOLUTION__', '==SOLUTION=='},
        help=("Tags indicating which cells are to be removed")).tag(
            config=True)

    def is_solution(self, cell):
        """Return the solution tags found in this cell (truthy when tagged)."""
        normalized = {line.strip().replace(' ', '')
                      for line in set(cell.source.split("\n"))}
        return self.solution_tags.intersection(normalized)

    def preprocess(self, nb, resources):
        """Keep only markdown cells and solution-tagged cells (tags stripped)."""
        nb_copy = deepcopy(nb)

        # Skip preprocessing if the list of patterns is empty
        if not self.solution_tags:
            return nb, resources

        # Keep markdown cells and tagged solution cells, stripping the tags.
        kept = [
            self.preprocess_cell(cell)
            for cell in nb_copy.cells
            if self.is_solution(cell) or cell.cell_type == 'markdown'
        ]

        # Nothing removed means the notebook had no lesson cells at all.
        if len(nb_copy.cells) == len(kept):
            warn(
                "No lesson cells were found in the notebook!"
                " Double check the solution tag placement and formatting if this is not correct.",
                UserWarning)

        nb_copy.cells = kept
        return nb_copy, resources

    def preprocess_cell(self, cell):
        """Remove the solution tag lines from a solution cell."""
        survivors = [
            line for line in cell.source.split('\n')
            if line.strip().replace(' ', '') not in self.solution_tags
        ]
        cell.source = '\n'.join(survivors)
        return cell
class Bar(Configurable):
    """Configurable exercising scalar, container, and typed-dict traits."""

    # Simple scalar traits.
    b = Integer(0, help="The integer b.").tag(config=True)
    enabled = Bool(True, help="Enable bar.").tag(config=True)

    # Container traits; `multiplicity` controls how many CLI occurrences
    # are accepted ('*' = zero or more, '+' = one or more).
    tb = Tuple(()).tag(config=True, multiplicity="*")
    aset = Set().tag(config=True, multiplicity="+")

    # Dict traits: untyped, value-typed, and with per-key value types.
    bdict = Dict().tag(config=True)
    idict = Dict(value_trait=Integer()).tag(config=True)
    key_dict = Dict(per_key_traits={"i": Integer(), "b": Bytes()}).tag(config=True)
class FileWhitelistMixin(LoggingConfigurable): """ """ #: The path of the whitelist file. whitelist_file = Unicode(config=True) #: When the file was last modified, so that we can reload appropriately. _whitelist_file_last_modified = Float() #: Cached whitelist to return every time the file hasn't changed. _whitelist = Set() @property def whitelist(self): """Returns the whitelist for the approved users. """ # Note: we return a copy because other code in jupyterhub # will modify the set. We don't want our cache to be modified. try: cur_mtime = os.path.getmtime(self.whitelist_file) if cur_mtime <= self._whitelist_file_last_modified: # File older than last change. # keep using the current cached whitelist return set(self._whitelist) self.log.info("Whitelist file more recent than the old one. " "Updating whitelist.") with open(self.whitelist_file, "r") as f: whitelisted_users = set(self.normalize_username(x.strip()) for x in f.readlines() if not x.strip().startswith("#")) except FileNotFoundError: # empty set means everybody is allowed return set() except Exception: # For other exceptions, assume the file is broken, log it # and return what we have. self.log.exception("Unable to access whitelist.") return set(self._whitelist) self._whitelist = whitelisted_users self._whitelist_file_last_modified = cur_mtime return set(self._whitelist) @whitelist.setter def whitelist(self, value): """Dummy setter that does nothing. Jupyterhub assumes it can normalize the names and set them back. We can't let it perform this operation. """ pass
class CASLocalAuthenticator(LocalAuthenticator):
    """
    Validate a CAS service ticket and optionally check for the presence
    of an authorization attribute.
    """

    cas_login_url = Unicode(
        config=True,
        help="""The CAS URL to redirect unauthenticated users to.""")

    cas_logout_url = Unicode(
        config=True,
        help="""The CAS URL for logging out an authenticated user.""")

    cas_service_url = Unicode(
        allow_none=True,
        default_value=None,
        config=True,
        help=
        """The service URL the CAS server will redirect the browser back to on successful authentication."""
    )

    cas_client_ca_certs = Unicode(
        allow_none=True,
        default_value=None,
        config=True,
        help=
        """Path to CA certificates the CAS client will trust when validating a service ticket."""
    )

    cas_service_validate_url = Unicode(
        config=True,
        help="""The CAS endpoint for validating service tickets.""")

    cas_required_attribs = Set(
        help=
        "A set of attribute name and value tuples a user must have to be allowed access."
    ).tag(config=True)

    def get_handlers(self, app):
        # Route login/logout through the CAS handlers defined elsewhere
        # in this module.
        return [
            (r'/login', CASLoginHandler),
            (r'/logout', CASLogoutHandler),
        ]

    @gen.coroutine
    def authenticate(self, *args):
        # Authentication happens in the CAS handlers, not here.
        raise NotImplementedError()
class ClearOutputPreprocessor(Preprocessor):
    """
    Removes the output from all code cells in a notebook.
    """

    # Output-related metadata keys that should be dropped as well.
    remove_metadata_fields = Set({"collapsed", "scrolled"}).tag(config=True)

    def preprocess_cell(self, cell, resources, cell_index):
        """
        Apply a transformation on each cell. See base.py for details.
        """
        # Non-code cells have no outputs; pass them through untouched.
        if cell.cell_type != "code":
            return cell, resources

        cell.outputs = []
        cell.execution_count = None
        # Remove metadata associated with output
        if "metadata" in cell:
            for key in self.remove_metadata_fields:
                cell.metadata.pop(key, None)
        return cell, resources
class AddCellIndex(DsCreatePreprocessor):

    description = '''
    AddCellIndex adds a metadata.index variable to a notebook and determines if a cell is a solution cell.

    This preprocessor is used primarily for ``--inline`` splits.
    '''

    # Running counter used to stamp metadata.index on each cell.
    index = Int(0)

    solution_tags = Set(
        {'#__SOLUTION__', '#==SOLUTION==', '__SOLUTION__', '==SOLUTION=='},
        help=("Tags indicating which cells are to be removed")).tag(
            config=True)

    def preprocess(self, nb, resources):
        """Stamp every cell of a copied notebook with index/solution metadata."""
        nb_copy = deepcopy(nb)
        stamped = []
        for position, cell in enumerate(nb_copy.cells):
            new_cell, resources = self.preprocess_cell(cell, resources, position)
            stamped.append(new_cell)
        nb_copy.cells = stamped
        return nb_copy, resources

    def preprocess_cell(self, cell, resources, cell_index):
        """Mark whether the cell carries a solution tag and record its index."""
        source_lines = set(cell.source.split("\n"))
        cell['metadata']['solution'] = bool(
            self.solution_tags.intersection(source_lines))
        cell['metadata']['index'] = self.index
        self.index += 1
        return cell, resources
class NBViewer(Application):
    """Tornado application that renders and serves Jupyter notebooks."""

    name = Unicode("NBViewer")

    # Command-line aliases, mapped to fully qualified traitlet names.
    aliases = Dict({
        "base-url": "NBViewer.base_url",
        "binder-base-url": "NBViewer.binder_base_url",
        "cache-expiry-max": "NBViewer.cache_expiry_max",
        "cache-expiry-min": "NBViewer.cache_expiry_min",
        "config-file": "NBViewer.config_file",
        "content-security-policy": "NBViewer.content_security_policy",
        "default-format": "NBViewer.default_format",
        "frontpage": "NBViewer.frontpage",
        "host": "NBViewer.host",
        "ipywidgets-base-url": "NBViewer.ipywidgets_base_url",
        "jupyter-js-widgets-version": "NBViewer.jupyter_js_widgets_version",
        "jupyter-widgets-html-manager-version":
        "NBViewer.jupyter_widgets_html_manager_version",
        "localfiles": "NBViewer.localfiles",
        "log-level": "Application.log_level",
        "mathjax-url": "NBViewer.mathjax_url",
        "mc-threads": "NBViewer.mc_threads",
        "port": "NBViewer.port",
        "processes": "NBViewer.processes",
        "provider-rewrites": "NBViewer.provider_rewrites",
        "providers": "NBViewer.providers",
        "proxy-host": "NBViewer.proxy_host",
        "proxy-port": "NBViewer.proxy_port",
        "rate-limit": "NBViewer.rate_limit",
        "rate-limit-interval": "NBViewer.rate_limit_interval",
        "render-timeout": "NBViewer.render_timeout",
        "sslcert": "NBViewer.sslcert",
        "sslkey": "NBViewer.sslkey",
        "static-path": "NBViewer.static_path",
        "static-url-prefix": "NBViewer.static_url_prefix",
        "statsd-host": "NBViewer.statsd_host",
        "statsd-port": "NBViewer.statsd_port",
        "statsd-prefix": "NBViewer.statsd_prefix",
        "template-path": "NBViewer.template_path",
        "threads": "NBViewer.threads",
    })

    # Boolean command-line flags: each maps to a config fragment + help text.
    flags = Dict({
        "debug": (
            {
                "Application": {
                    "log_level": logging.DEBUG
                }
            },
            "Set log-level to debug, for the most verbose logging.",
        ),
        "generate-config": (
            {
                "NBViewer": {
                    "generate_config": True
                }
            },
            "Generate default config file.",
        ),
        "localfile-any-user": (
            {
                "NBViewer": {
                    "localfile_any_user": True
                }
            },
            "Also serve files that are not readable by 'Other' on the local file system.",
        ),
        "localfile-follow-symlinks": (
            {
                "NBViewer": {
                    "localfile_follow_symlinks": True
                }
            },
            "Resolve/follow symbolic links to their target file using realpath.",
        ),
        "no-cache": ({
            "NBViewer": {
                "no_cache": True
            }
        }, "Do not cache results."),
        "no-check-certificate": (
            {
                "NBViewer": {
                    "no_check_certificate": True
                }
            },
            "Do not validate SSL certificates.",
        ),
        "y": (
            {
                "NBViewer": {
                    "answer_yes": True
                }
            },
            "Answer yes to any questions (e.g. confirm overwrite).",
        ),
        "yes": (
            {
                "NBViewer": {
                    "answer_yes": True
                }
            },
            "Answer yes to any questions (e.g. confirm overwrite).",
        ),
    })

    # Use this to insert custom configuration of handlers for NBViewer extensions
    handler_settings = Dict().tag(config=True)

    # Dotted paths of the Tornado handlers for each route; overridable so
    # extensions can swap in their own implementations.
    create_handler = Unicode(
        default_value="nbviewer.handlers.CreateHandler",
        help="The Tornado handler to use for creation via frontpage form.",
    ).tag(config=True)

    custom404_handler = Unicode(
        default_value="nbviewer.handlers.Custom404",
        help="The Tornado handler to use for rendering 404 templates.",
    ).tag(config=True)

    faq_handler = Unicode(
        default_value="nbviewer.handlers.FAQHandler",
        help=
        "The Tornado handler to use for rendering and viewing the FAQ section.",
    ).tag(config=True)

    gist_handler = Unicode(
        default_value="nbviewer.providers.gist.handlers.GistHandler",
        help=
        "The Tornado handler to use for viewing notebooks stored as GitHub Gists",
    ).tag(config=True)

    github_blob_handler = Unicode(
        default_value="nbviewer.providers.github.handlers.GitHubBlobHandler",
        help=
        "The Tornado handler to use for viewing notebooks stored as blobs on GitHub",
    ).tag(config=True)

    github_tree_handler = Unicode(
        default_value="nbviewer.providers.github.handlers.GitHubTreeHandler",
        help="The Tornado handler to use for viewing directory trees on GitHub",
    ).tag(config=True)

    github_user_handler = Unicode(
        default_value="nbviewer.providers.github.handlers.GitHubUserHandler",
        help=
        "The Tornado handler to use for viewing all of a user's repositories on GitHub.",
    ).tag(config=True)

    index_handler = Unicode(
        default_value="nbviewer.handlers.IndexHandler",
        help="The Tornado handler to use for rendering the frontpage section.",
    ).tag(config=True)

    local_handler = Unicode(
        default_value="nbviewer.providers.local.handlers.LocalFileHandler",
        help=
        "The Tornado handler to use for viewing notebooks found on a local filesystem",
    ).tag(config=True)

    url_handler = Unicode(
        default_value="nbviewer.providers.url.handlers.URLHandler",
        help=
        "The Tornado handler to use for viewing notebooks accessed via URL",
    ).tag(config=True)

    user_gists_handler = Unicode(
        default_value="nbviewer.providers.gist.handlers.UserGistsHandler",
        help=
        "The Tornado handler to use for viewing directory containing all of a user's Gists",
    ).tag(config=True)

    answer_yes = Bool(
        default_value=False,
        help="Answer yes to any questions (e.g. confirm overwrite).",
    ).tag(config=True)

    # base_url specified by the user
    base_url = Unicode(default_value="/",
                       help="URL base for the server").tag(config=True)

    binder_base_url = Unicode(
        default_value="https://mybinder.org/v2",
        help="URL base for binder notebook execution service.",
    ).tag(config=True)

    cache_expiry_max = Int(
        default_value=2 * 60 * 60,
        help="Maximum cache expiry (seconds).").tag(config=True)

    cache_expiry_min = Int(
        default_value=10 * 60,
        help="Minimum cache expiry (seconds).").tag(config=True)

    # HTTP client used for upstream fetches; wired to the cache by default.
    client = Any().tag(config=True)

    @default("client")
    def _default_client(self):
        client = HTTPClientClass(log=self.log)
        client.cache = self.cache
        return client

    config_file = Unicode(default_value="nbviewer_config.py",
                          help="The config file to load.").tag(config=True)

    content_security_policy = Unicode(
        default_value="connect-src 'none';",
        help="Content-Security-Policy header setting.",
    ).tag(config=True)

    default_format = Unicode(
        default_value="html",
        help="Format to use for legacy / URLs.").tag(config=True)

    frontpage = Unicode(
        default_value=FRONTPAGE_JSON,
        help="Path to json file containing frontpage content.",
    ).tag(config=True)

    generate_config = Bool(
        default_value=False,
        help="Generate default config file.").tag(config=True)

    host = Unicode(help="Run on the given interface.").tag(config=True)

    @default("host")
    def _default_host(self):
        return self.default_endpoint["host"]

    # Notebook search index; defaults to a no-op unless NBINDEX_PORT is set.
    index = Any().tag(config=True)

    @default("index")
    def _load_index(self):
        if os.environ.get("NBINDEX_PORT"):
            self.log.info("Indexing notebooks")
            tcp_index = os.environ.get("NBINDEX_PORT")
            index_url = tcp_index.split("tcp://")[1]
            index_host, index_port = index_url.split(":")
        else:
            self.log.info("Not indexing notebooks")
            indexer = NoSearch()
        # NOTE(review): when NBINDEX_PORT is set, `indexer` is never assigned
        # in that branch, so this `return` raises NameError — confirm the
        # intended indexer construction for the indexing case.
        return indexer

    ipywidgets_base_url = Unicode(
        default_value="https://unpkg.com/",
        help="URL base for ipywidgets JS package.").tag(config=True)

    jupyter_js_widgets_version = Unicode(
        default_value="*",
        help="Version specifier for jupyter-js-widgets JS package.").tag(
            config=True)

    jupyter_widgets_html_manager_version = Unicode(
        default_value="*",
        help="Version specifier for @jupyter-widgets/html-manager JS package.",
    ).tag(config=True)

    localfile_any_user = Bool(
        default_value=False,
        help=
        "Also serve files that are not readable by 'Other' on the local file system.",
    ).tag(config=True)

    localfile_follow_symlinks = Bool(
        default_value=False,
        help=
        "Resolve/follow symbolic links to their target file using realpath.",
    ).tag(config=True)

    localfiles = Unicode(
        default_value="",
        help=
        "Allow to serve local files under /localfile/* this can be a security risk.",
    ).tag(config=True)

    mathjax_url = Unicode(
        default_value="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/",
        help="URL base for mathjax package.",
    ).tag(config=True)

    # cache frontpage links for the maximum allowed time
    max_cache_uris = Set().tag(config=True)

    @default("max_cache_uris")
    def _load_max_cache_uris(self):
        max_cache_uris = {""}
        for section in self.frontpage_setup["sections"]:
            for link in section["links"]:
                max_cache_uris.add("/" + link["target"])
        return max_cache_uris

    mc_threads = Int(
        default_value=1,
        help="Number of threads to use for Async Memcache.").tag(config=True)

    no_cache = Bool(default_value=False,
                    help="Do not cache results.").tag(config=True)

    no_check_certificate = Bool(
        default_value=False,
        help="Do not validate SSL certificates.").tag(config=True)

    port = Int(help="Run on the given port.").tag(config=True)

    @default("port")
    def _default_port(self):
        return self.default_endpoint["port"]

    processes = Int(
        default_value=0,
        help="Use processes instead of threads for rendering.").tag(
            config=True)

    provider_rewrites = List(
        trait=Unicode,
        default_value=default_rewrites,
        help="Full dotted package(s) that provide `uri_rewrites`.",
    ).tag(config=True)

    providers = List(
        trait=Unicode,
        default_value=default_providers,
        help="Full dotted package(s) that provide `default_handlers`.",
    ).tag(config=True)

    proxy_host = Unicode(default_value="",
                         help="The proxy URL.").tag(config=True)

    proxy_port = Int(default_value=-1, help="The proxy port.").tag(config=True)

    rate_limit = Int(
        default_value=60,
        help=
        "Number of requests to allow in rate_limit_interval before limiting. Only requests that trigger a new render are counted.",
    ).tag(config=True)

    rate_limit_interval = Int(
        default_value=600,
        help="Interval (in seconds) for rate limiting.").tag(config=True)

    render_timeout = Int(
        default_value=15,
        help=
        "Time to wait for a render to complete before showing the 'Working...' page.",
    ).tag(config=True)

    sslcert = Unicode(help="Path to ssl .crt file.").tag(config=True)

    sslkey = Unicode(help="Path to ssl .key file.").tag(config=True)

    static_path = Unicode(
        default_value=os.environ.get("NBVIEWER_STATIC_PATH", ""),
        help="Custom path for loading additional static files.",
    ).tag(config=True)

    static_url_prefix = Unicode(default_value="/static/").tag(config=True)

    # Not exposed to end user for configuration, since needs to access base_url
    _static_url_prefix = Unicode()

    @default("_static_url_prefix")
    def _load_static_url_prefix(self):
        # Last '/' ensures that NBViewer still works regardless of whether user
        # chooses e.g. '/static2/' or '/static2' as their custom prefix
        return url_path_join(self._base_url, self.static_url_prefix, "/")

    statsd_host = Unicode(
        default_value="",
        help="Host running statsd to send metrics to.").tag(config=True)

    statsd_port = Int(
        default_value=8125,
        help="Port on which statsd is listening for metrics on statsd_host.",
    ).tag(config=True)

    statsd_prefix = Unicode(
        default_value="nbviewer",
        help="Prefix to use for naming metrics sent to statsd.",
    ).tag(config=True)

    template_path = Unicode(
        default_value=os.environ.get("NBVIEWER_TEMPLATE_PATH", ""),
        help=
        "Custom template path for the nbviewer app (not rendered notebooks).",
    ).tag(config=True)

    threads = Int(
        default_value=1,
        help="Number of threads to use for rendering.").tag(config=True)

    # prefer the JupyterHub defined service prefix over the CLI
    @cached_property
    def _base_url(self):
        return os.getenv("JUPYTERHUB_SERVICE_PREFIX", self.base_url)

    @cached_property
    def cache(self):
        # Pick a cache backend: memcache if configured, a mock when caching
        # is disabled, otherwise a simple in-memory cache.
        memcache_urls = os.environ.get("MEMCACHIER_SERVERS",
                                       os.environ.get("MEMCACHE_SERVERS"))
        # Handle linked Docker containers
        if os.environ.get("NBCACHE_PORT"):
            tcp_memcache = os.environ.get("NBCACHE_PORT")
            memcache_urls = tcp_memcache.split("tcp://")[1]
        if self.no_cache:
            self.log.info("Not using cache")
            cache = MockCache()
        elif pylibmc and memcache_urls:
            # setup memcache
            mc_pool = ThreadPoolExecutor(self.mc_threads)
            kwargs = dict(pool=mc_pool)
            username = os.environ.get("MEMCACHIER_USERNAME", "")
            password = os.environ.get("MEMCACHIER_PASSWORD", "")
            if username and password:
                kwargs["binary"] = True
                kwargs["username"] = username
                kwargs["password"] = password
                self.log.info("Using SASL memcache")
            else:
                self.log.info("Using plain memcache")
            cache = AsyncMultipartMemcache(memcache_urls.split(","), **kwargs)
        else:
            self.log.info("Using in-memory cache")
            cache = DummyAsyncCache()
        return cache

    @cached_property
    def default_endpoint(self):
        # check if JupyterHub service options are available to use as defaults
        if "JUPYTERHUB_SERVICE_URL" in os.environ:
            url = urlparse(os.environ["JUPYTERHUB_SERVICE_URL"])
            default_host, default_port = url.hostname, url.port
        else:
            default_host, default_port = "0.0.0.0", 5000
        return {"host": default_host, "port": default_port}

    @cached_property
    def env(self):
        # Jinja2 environment for page templates, with git/jupyter metadata
        # exposed as globals.
        env = Environment(loader=FileSystemLoader(self.template_paths),
                          autoescape=True)
        env.filters["markdown"] = markdown.markdown
        try:
            git_data = git_info(here)
        except Exception as e:
            self.log.error("Failed to get git info: %s", e)
            git_data = {}
        else:
            git_data["msg"] = escape(git_data["msg"])
        if self.no_cache:
            # force Jinja2 to recompile template every time
            env.globals.update(cache_size=0)
        env.globals.update(
            nrhead=nrhead,
            nrfoot=nrfoot,
            git_data=git_data,
            jupyter_info=jupyter_info(),
            len=len,
        )
        return env

    @cached_property
    def fetch_kwargs(self):
        # Keyword arguments applied to every upstream HTTP fetch.
        fetch_kwargs = dict(connect_timeout=10)
        if self.proxy_host:
            fetch_kwargs.update(proxy_host=self.proxy_host,
                                proxy_port=self.proxy_port)
            self.log.info("Using web proxy {proxy_host}:{proxy_port}."
                          "".format(**fetch_kwargs))
        if self.no_check_certificate:
            fetch_kwargs.update(validate_cert=False)
            self.log.info("Not validating SSL certificates")
        return fetch_kwargs

    @cached_property
    def formats(self):
        return self.configure_formats()

    # load frontpage sections
    @cached_property
    def frontpage_setup(self):
        with io.open(self.frontpage, "r") as f:
            frontpage_setup = json.load(f)
        # check if the JSON has a 'sections' field, otherwise assume it is just a list of sessions,
        # and provide the defaults of the other fields
        if "sections" not in frontpage_setup:
            frontpage_setup = {
                "title": "nbviewer",
                "subtitle": "A simple way to share Jupyter notebooks",
                "show_input": True,
                "sections": frontpage_setup,
            }
        return frontpage_setup

    # Attribute inherited from traitlets.config.Application, automatically used to style logs
    # https://github.com/ipython/traitlets/blob/master/traitlets/config/application.py#L191
    _log_formatter_cls = LogFormatter  # Need Tornado LogFormatter for color logs, keys 'color' and 'end_color' in log_format

    # Observed traitlet inherited again from traitlets.config.Application
    # https://github.com/ipython/traitlets/blob/master/traitlets/config/application.py#L177
    @default("log_level")
    def _log_level_default(self):
        return logging.INFO

    # Ditto the above: https://github.com/ipython/traitlets/blob/master/traitlets/config/application.py#L197
    @default("log_format")
    def _log_format_default(self):
        """override default log format to include time and color, plus to always display the log level, not just when it's high"""
        return "%(color)s[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s %(module)s:%(lineno)d]%(end_color)s %(message)s"

    # For consistency with JupyterHub logs
    @default("log_datefmt")
    def _log_datefmt_default(self):
        """Exclude date from default date format"""
        return "%Y-%m-%d %H:%M:%S"

    @cached_property
    def pool(self):
        # Executor used for rendering work: processes when configured,
        # otherwise a thread pool.
        if self.processes:
            pool = ProcessPoolExecutor(self.processes)
        else:
            pool = ThreadPoolExecutor(self.threads)
        return pool

    @cached_property
    def rate_limiter(self):
        rate_limiter = RateLimiter(limit=self.rate_limit,
                                   interval=self.rate_limit_interval,
                                   cache=self.cache)
        return rate_limiter

    @cached_property
    def static_paths(self):
        # Custom static path (if any) takes precedence over the bundled one.
        default_static_path = pjoin(here, "static")
        if self.static_path:
            self.log.info("Using custom static path {}".format(
                self.static_path))
            static_paths = [self.static_path, default_static_path]
        else:
            static_paths = [default_static_path]
        return static_paths

    @cached_property
    def template_paths(self):
        # Custom template path (if any) takes precedence over the bundled one.
        default_template_path = pjoin(here, "templates")
        if self.template_path:
            self.log.info("Using custom template path {}".format(
                self.template_path))
            template_paths = [self.template_path, default_template_path]
        else:
            template_paths = [default_template_path]
        return template_paths

    def configure_formats(self, formats=None):
        """
        Format-specific configuration.
        """
        if formats is None:
            formats = default_formats()

        # This would be better defined in a class
        self.config.HTMLExporter.template_file = "basic"
        self.config.SlidesExporter.template_file = "slides_reveal"

        self.config.TemplateExporter.template_path = [
            os.path.join(os.path.dirname(__file__), "templates", "nbconvert")
        ]

        for key, format in formats.items():
            exporter_cls = format.get("exporter", exporter_map[key])
            if self.processes:
                # can't pickle exporter instances,
                formats[key]["exporter"] = exporter_cls
            else:
                formats[key]["exporter"] = exporter_cls(config=self.config,
                                                        log=self.log)

        return formats

    def init_tornado_application(self):
        """Build the tornado.web.Application from handlers + settings."""
        # handle handlers
        handler_names = dict(
            create_handler=self.create_handler,
            custom404_handler=self.custom404_handler,
            faq_handler=self.faq_handler,
            gist_handler=self.gist_handler,
            github_blob_handler=self.github_blob_handler,
            github_tree_handler=self.github_tree_handler,
            github_user_handler=self.github_user_handler,
            index_handler=self.index_handler,
            local_handler=self.local_handler,
            url_handler=self.url_handler,
            user_gists_handler=self.user_gists_handler,
        )

        handler_kwargs = {
            "handler_names": handler_names,
            "handler_settings": self.handler_settings,
        }

        handlers = init_handlers(self.formats, self.providers, self._base_url,
                                 self.localfiles, **handler_kwargs)

        # NBConvert config
        self.config.NbconvertApp.fileext = "html"
        self.config.CSSHTMLHeaderTransformer.enabled = False

        # DEBUG env implies both autoreload and log-level
        if os.environ.get("DEBUG"):
            self.log.setLevel(logging.DEBUG)

        # input traitlets to settings
        settings = dict(
            # Allow FileFindHandler to load static directories from e.g. a Docker container
            allow_remote_access=True,
            base_url=self._base_url,
            binder_base_url=self.binder_base_url,
            cache=self.cache,
            cache_expiry_max=self.cache_expiry_max,
            cache_expiry_min=self.cache_expiry_min,
            client=self.client,
            config=self.config,
            content_security_policy=self.content_security_policy,
            default_format=self.default_format,
            fetch_kwargs=self.fetch_kwargs,
            formats=self.formats,
            frontpage_setup=self.frontpage_setup,
            google_analytics_id=os.getenv("GOOGLE_ANALYTICS_ID"),
            gzip=True,
            hub_api_token=os.getenv("JUPYTERHUB_API_TOKEN"),
            hub_api_url=os.getenv("JUPYTERHUB_API_URL"),
            hub_base_url=os.getenv("JUPYTERHUB_BASE_URL"),
            index=self.index,
            ipywidgets_base_url=self.ipywidgets_base_url,
            jinja2_env=self.env,
            jupyter_js_widgets_version=self.jupyter_js_widgets_version,
            jupyter_widgets_html_manager_version=self.
            jupyter_widgets_html_manager_version,
            localfile_any_user=self.localfile_any_user,
            localfile_follow_symlinks=self.localfile_follow_symlinks,
            localfile_path=os.path.abspath(self.localfiles),
            log=self.log,
            log_function=log_request,
            mathjax_url=self.mathjax_url,
            max_cache_uris=self.max_cache_uris,
            pool=self.pool,
            provider_rewrites=self.provider_rewrites,
            providers=self.providers,
            rate_limiter=self.rate_limiter,
            render_timeout=self.render_timeout,
            static_handler_class=StaticFileHandler,
            # FileFindHandler expects list of static paths, so self.static_path*s* is correct
            static_path=self.static_paths,
            static_url_prefix=self._static_url_prefix,
            statsd_host=self.statsd_host,
            statsd_port=self.statsd_port,
            statsd_prefix=self.statsd_prefix,
        )

        if self.localfiles:
            self.log.warning(
                "Serving local notebooks in %s, this can be a security risk",
                self.localfiles,
            )

        # create the app
        self.tornado_application = web.Application(handlers, **settings)

    def init_logging(self):
        # Note that we inherit a self.log attribute from traitlets.config.Application
        # https://github.com/ipython/traitlets/blob/master/traitlets/config/application.py#L209
        # as well as a log_level attribute
        # https://github.com/ipython/traitlets/blob/master/traitlets/config/application.py#L177

        # This prevents double log messages because tornado use a root logger that
        # self.log is a child of. The logging module dispatches log messages to a log
        # and all of its ancestors until propagate is set to False.
        self.log.propagate = False

        tornado_log = logging.getLogger("tornado")
        # hook up tornado's loggers to our app handlers
        for log in (app_log, access_log, tornado_log, curl_log):
            # ensure all log statements identify the application they come from
            log.name = self.log.name
            log.parent = self.log
            log.propagate = True
            log.setLevel(self.log_level)

        # disable curl debug, which logs all headers, info for upstream requests, which is TOO MUCH
        curl_log.setLevel(max(self.log_level, logging.INFO))

    # Mostly copied from JupyterHub because if it isn't broken then don't fix it.
    def write_config_file(self):
        """Write our default config to a .py config file"""
        config_file_dir = os.path.dirname(os.path.abspath(self.config_file))
        if not os.path.isdir(config_file_dir):
            self.exit(
                "{} does not exist. The destination directory must exist before generating config file."
                .format(config_file_dir))
        if os.path.exists(self.config_file) and not self.answer_yes:
            answer = ""

            def ask():
                prompt = "Overwrite %s with default config? [y/N]" % self.config_file
                try:
                    return input(prompt).lower() or "n"
                except KeyboardInterrupt:
                    print("")  # empty line
                    return "n"

            answer = ask()
            while not answer.startswith(("y", "n")):
                print("Please answer 'yes' or 'no'")
                answer = ask()
            if answer.startswith("n"):
                self.exit("Not overwriting config file with default.")

        # Inherited method from traitlets.config.Application
        config_text = self.generate_config_file()
        if isinstance(config_text, bytes):
            config_text = config_text.decode("utf8")
        print("Writing default config to: %s" % self.config_file)
        with open(self.config_file, mode="w") as f:
            f.write(config_text)
        self.exit("Wrote default config file.")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # parse command line with catch_config_error from traitlets.config.Application
        super().initialize(*args, **kwargs)
        if self.generate_config:
            self.write_config_file()
        # Inherited method from traitlets.config.Application
        self.load_config_file(self.config_file)
        self.init_logging()
        self.init_tornado_application()
class JupyterHub(Application):
    """An Application for starting a Multi-User Jupyter Notebook server."""
    name = 'jupyterhub'
    version = jupyterhub.__version__

    description = """Start a multi-user Jupyter Notebook server

    Spawns a configurable-http-proxy and multi-user Hub,
    which authenticates users and spawns single-user Notebook servers
    on behalf of users.
    """
    examples = """

    generate default config file:

        jupyterhub --generate-config -f /etc/jupyterhub/jupyterhub.py

    spawn the server on 10.0.1.2:443 with https:

        jupyterhub --ip 10.0.1.2 --port 443 --ssl-key my_ssl.key --ssl-cert my_ssl.cert
    """

    aliases = Dict(aliases)
    flags = Dict(flags)

    subcommands = {
        'token': (NewToken, "Generate an API token for a user")
    }

    classes = List([
        Spawner,
        LocalProcessSpawner,
        Authenticator,
        PAMAuthenticator,
    ])

    config_file = Unicode('jupyterhub_config.py', config=True,
        help="The config file to load",
    )
    generate_config = Bool(False, config=True,
        help="Generate default config file",
    )
    answer_yes = Bool(False, config=True,
        help="Answer yes to any questions (e.g. confirm overwrite)"
    )
    pid_file = Unicode('', config=True,
        help="""File to write PID
        Useful for daemonizing jupyterhub.
        """
    )
    cookie_max_age_days = Float(14, config=True,
        help="""Number of days for a login cookie to be valid.
        Default is two weeks.
        """
    )
    last_activity_interval = Integer(300, config=True,
        help="Interval (in seconds) at which to update last-activity timestamps."
    )
    proxy_check_interval = Integer(30, config=True,
        help="Interval (in seconds) at which to check if the proxy is running."
    )

    data_files_path = Unicode(DATA_FILES_PATH, config=True,
        help="The location of jupyterhub data files (e.g. /usr/local/share/jupyter/hub)"
    )

    template_paths = List(
        config=True,
        help="Paths to search for jinja templates.",
    )
    def _template_paths_default(self):
        return [os.path.join(self.data_files_path, 'templates')]

    ssl_key = Unicode('', config=True,
        help="""Path to SSL key file for the public facing interface of the proxy

        Use with ssl_cert
        """
    )
    ssl_cert = Unicode('', config=True,
        help="""Path to SSL certificate file for the public facing interface of the proxy

        Use with ssl_key
        """
    )
    ip = Unicode('', config=True,
        help="The public facing ip of the proxy"
    )
    port = Integer(8000, config=True,
        help="The public facing port of the proxy"
    )
    base_url = URLPrefix('/', config=True,
        help="The base URL of the entire application"
    )

    jinja_environment_options = Dict(config=True,
        help="Supply extra arguments that will be passed to Jinja environment."
    )

    proxy_cmd = Command('configurable-http-proxy', config=True,
        help="""The command to start the http proxy.

        Only override if configurable-http-proxy is not on your PATH
        """
    )
    debug_proxy = Bool(False, config=True,
        help="show debug output in configurable-http-proxy")
    proxy_auth_token = Unicode(config=True,
        help="""The Proxy Auth token.

        Loaded from the CONFIGPROXY_AUTH_TOKEN env variable by default.
        """
    )
    def _proxy_auth_token_default(self):
        # A fresh token is generated if neither env nor config supplies one;
        # warn because the proxy cannot survive a Hub restart in that case.
        token = os.environ.get('CONFIGPROXY_AUTH_TOKEN', None)
        if not token:
            self.log.warn('\n'.join([
                "",
                "Generating CONFIGPROXY_AUTH_TOKEN. Restarting the Hub will require restarting the proxy.",
                "Set CONFIGPROXY_AUTH_TOKEN env or JupyterHub.proxy_auth_token config to avoid this message.",
                "",
            ]))
            token = orm.new_token()
        return token

    proxy_api_ip = Unicode('localhost', config=True,
        help="The ip for the proxy API handlers"
    )
    proxy_api_port = Integer(config=True,
        help="The port for the proxy API handlers"
    )
    def _proxy_api_port_default(self):
        # default: one above the public proxy port
        return self.port + 1

    hub_port = Integer(8081, config=True,
        help="The port for this process"
    )
    hub_ip = Unicode('localhost', config=True,
        help="The ip for this process"
    )

    hub_prefix = URLPrefix('/hub/', config=True,
        help="The prefix for the hub server. Must not be '/'"
    )
    def _hub_prefix_default(self):
        return url_path_join(self.base_url, '/hub/')

    def _hub_prefix_changed(self, name, old, new):
        if new == '/':
            raise TraitError("'/' is not a valid hub prefix")
        if not new.startswith(self.base_url):
            # re-anchor the prefix under base_url (re-triggers this handler)
            self.hub_prefix = url_path_join(self.base_url, new)

    cookie_secret = Bytes(config=True, env='JPY_COOKIE_SECRET',
        help="""The cookie secret to use to encrypt cookies.

        Loaded from the JPY_COOKIE_SECRET env variable by default.
        """
    )

    cookie_secret_file = Unicode('jupyterhub_cookie_secret', config=True,
        help="""File in which to store the cookie secret."""
    )

    authenticator_class = Type(PAMAuthenticator, Authenticator,
        config=True,
        help="""Class for authenticating users.

        This should be a class with the following form:

        - constructor takes one kwarg: `config`, the IPython config object.

        - is a tornado.gen.coroutine
        - returns username on success, None on failure
        - takes two arguments: (handler, data),
          where `handler` is the calling web.RequestHandler,
          and `data` is the POST form data from the login page.
        """
    )

    authenticator = Instance(Authenticator)
    def _authenticator_default(self):
        return self.authenticator_class(parent=self, db=self.db)

    # class for spawning single-user servers
    spawner_class = Type(LocalProcessSpawner, Spawner,
        config=True,
        help="""The class to use for spawning single-user servers.

        Should be a subclass of Spawner.
        """
    )

    db_url = Unicode('sqlite:///jupyterhub.sqlite', config=True,
        help="url for the database. e.g. `sqlite:///jupyterhub.sqlite`"
    )
    def _db_url_changed(self, name, old, new):
        if '://' not in new:
            # assume sqlite, if given as a plain filename
            self.db_url = 'sqlite:///%s' % new

    db_kwargs = Dict(config=True,
        help="""Include any kwargs to pass to the database connection.
        See sqlalchemy.create_engine for details.
        """
    )

    reset_db = Bool(False, config=True,
        help="Purge and reset the database."
    )
    debug_db = Bool(False, config=True,
        help="log all database transactions. This has A LOT of output"
    )
    session_factory = Any()

    admin_access = Bool(False, config=True,
        help="""Grant admin users permission to access single-user servers.

        Users should be properly informed if this is enabled.
        """
    )
    admin_users = Set(config=True,
        help="""DEPRECATED, use Authenticator.admin_users instead."""
    )

    tornado_settings = Dict(config=True)

    cleanup_servers = Bool(True, config=True,
        help="""Whether to shutdown single-user servers when the Hub shuts down.

        Disable if you want to be able to teardown the Hub while leaving the single-user servers running.

        If both this and cleanup_proxy are False, sending SIGINT to the Hub will
        only shutdown the Hub, leaving everything else running.

        The Hub should be able to resume from database state.
        """
    )

    cleanup_proxy = Bool(True, config=True,
        help="""Whether to shutdown the proxy when the Hub shuts down.

        Disable if you want to be able to teardown the Hub while leaving the proxy running.

        Only valid if the proxy was starting by the Hub process.

        If both this and cleanup_servers are False, sending SIGINT to the Hub will
        only shutdown the Hub, leaving everything else running.

        The Hub should be able to resume from database state.
        """
    )

    handlers = List()

    _log_formatter_cls = CoroutineLogFormatter
    http_server = None
    proxy_process = None
    io_loop = None

    def _log_level_default(self):
        return logging.INFO

    def _log_datefmt_default(self):
        """Exclude date from default date format"""
        return "%Y-%m-%d %H:%M:%S"

    def _log_format_default(self):
        """override default log format to include time"""
        return "%(color)s[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s %(module)s:%(lineno)d]%(end_color)s %(message)s"

    extra_log_file = Unicode(
        "",
        config=True,
        help="Set a logging.FileHandler on this file."
    )
    extra_log_handlers = List(
        Instance(logging.Handler),
        config=True,
        help="Extra log handlers to set on JupyterHub logger",
    )

    def init_logging(self):
        """Wire tornado's loggers into our application logger."""
        # This prevents double log messages because tornado use a root logger that
        # self.log is a child of. The logging module dispatches log messages to a log
        # and all of its ancestors until propagate is set to False.
        self.log.propagate = False

        if self.extra_log_file:
            self.extra_log_handlers.append(
                logging.FileHandler(self.extra_log_file)
            )

        _formatter = self._log_formatter_cls(
            fmt=self.log_format,
            datefmt=self.log_datefmt,
        )
        for handler in self.extra_log_handlers:
            if handler.formatter is None:
                handler.setFormatter(_formatter)
            self.log.addHandler(handler)

        # hook up tornado 3's loggers to our app handlers
        for log in (app_log, access_log, gen_log):
            # ensure all log statements identify the application they come from
            log.name = self.log.name
        logger = logging.getLogger('tornado')
        logger.propagate = True
        logger.parent = self.log
        logger.setLevel(self.log.level)

    def init_ports(self):
        """Validate that the hub, proxy, and proxy-API ports don't collide."""
        if self.hub_port == self.port:
            raise TraitError("The hub and proxy cannot both listen on port %i" % self.port)
        if self.hub_port == self.proxy_api_port:
            raise TraitError("The hub and proxy API cannot both listen on port %i" % self.hub_port)
        if self.proxy_api_port == self.port:
            raise TraitError("The proxy's public and API ports cannot both be %i" % self.port)

    @staticmethod
    def add_url_prefix(prefix, handlers):
        """add a url prefix to handlers"""
        for i, tup in enumerate(handlers):
            lis = list(tup)
            lis[0] = url_path_join(prefix, tup[0])
            handlers[i] = tuple(lis)
        return handlers

    def init_handlers(self):
        """Assemble the tornado handler list (hub, API, auth, redirects)."""
        h = []
        h.extend(handlers.default_handlers)
        h.extend(apihandlers.default_handlers)
        # load handlers from the authenticator
        h.extend(self.authenticator.get_handlers(self))

        self.handlers = self.add_url_prefix(self.hub_prefix, h)

        # some extra handlers, outside hub_prefix
        self.handlers.extend([
            (r"%s" % self.hub_prefix.rstrip('/'), web.RedirectHandler,
                {
                    "url": self.hub_prefix,
                    "permanent": False,
                }
            ),
            (r"(?!%s).*" % self.hub_prefix, handlers.PrefixRedirectHandler),
            (r'(.*)', handlers.Template404),
        ])

    def _check_db_path(self, path):
        """More informative log messages for failed filesystem access"""
        path = os.path.abspath(path)
        parent, fname = os.path.split(path)
        user = getuser()
        if not os.path.isdir(parent):
            self.log.error("Directory %s does not exist", parent)
        if os.path.exists(parent) and not os.access(parent, os.W_OK):
            self.log.error("%s cannot create files in %s", user, parent)
        if os.path.exists(path) and not os.access(path, os.W_OK):
            self.log.error("%s cannot edit %s", user, path)

    def init_secrets(self):
        """Load the cookie secret: config > env > file, else generate one."""
        trait_name = 'cookie_secret'
        trait = self.traits()[trait_name]
        env_name = trait.get_metadata('env')
        secret_file = os.path.abspath(
            os.path.expanduser(self.cookie_secret_file)
        )
        secret = self.cookie_secret
        secret_from = 'config'
        # load priority: 1. config, 2. env, 3. file
        if not secret and os.environ.get(env_name):
            secret_from = 'env'
            self.log.info("Loading %s from env[%s]", trait_name, env_name)
            secret = binascii.a2b_hex(os.environ[env_name])
        if not secret and os.path.exists(secret_file):
            secret_from = 'file'
            perm = os.stat(secret_file).st_mode
            if perm & 0o077:
                # refuse a secret readable by group/other
                self.log.error("Bad permissions on %s", secret_file)
            else:
                self.log.info("Loading %s from %s", trait_name, secret_file)
                with open(secret_file) as f:
                    b64_secret = f.read()
                try:
                    secret = binascii.a2b_base64(b64_secret)
                except Exception as e:
                    self.log.error("%s does not contain b64 key: %s", secret_file, e)
        if not secret:
            secret_from = 'new'
            self.log.debug("Generating new %s", trait_name)
            secret = os.urandom(SECRET_BYTES)

        if secret_file and secret_from == 'new':
            # if we generated a new secret, store it in the secret_file
            self.log.info("Writing %s to %s", trait_name, secret_file)
            b64_secret = binascii.b2a_base64(secret).decode('ascii')
            with open(secret_file, 'w') as f:
                f.write(b64_secret)
            try:
                os.chmod(secret_file, 0o600)
            except OSError:
                self.log.warn("Failed to set permissions on %s", secret_file)
        # store the loaded trait value
        self.cookie_secret = secret

    # thread-local storage of db objects
    _local = Instance(threading.local, ())

    @property
    def db(self):
        """A per-thread scoped database session."""
        if not hasattr(self._local, 'db'):
            self._local.db = scoped_session(self.session_factory)()
        return self._local.db

    @property
    def hub(self):
        """The single orm.Hub row, loaded lazily from the db."""
        if not getattr(self._local, 'hub', None):
            q = self.db.query(orm.Hub)
            assert q.count() <= 1
            self._local.hub = q.first()
        return self._local.hub

    @hub.setter
    def hub(self, hub):
        self._local.hub = hub

    @property
    def proxy(self):
        """The single orm.Proxy row; auth_token is re-attached on load."""
        if not getattr(self._local, 'proxy', None):
            q = self.db.query(orm.Proxy)
            assert q.count() <= 1
            p = self._local.proxy = q.first()
            if p:
                p.auth_token = self.proxy_auth_token
        return self._local.proxy

    @proxy.setter
    def proxy(self, proxy):
        self._local.proxy = proxy

    def init_db(self):
        """Create the database connection"""
        self.log.debug("Connecting to db: %s", self.db_url)
        try:
            self.session_factory = orm.new_session_factory(
                self.db_url,
                reset=self.reset_db,
                echo=self.debug_db,
                **self.db_kwargs
            )
            # trigger constructing thread local db property
            _ = self.db
        except OperationalError as e:
            self.log.error("Failed to connect to db: %s", self.db_url)
            self.log.debug("Database error was:", exc_info=True)
            if self.db_url.startswith('sqlite:///'):
                self._check_db_path(self.db_url.split(':///', 1)[1])
            self.exit(1)

    def init_hub(self):
        """Load the Hub config into the database"""
        self.hub = self.db.query(orm.Hub).first()
        if self.hub is None:
            self.hub = orm.Hub(
                server=orm.Server(
                    ip=self.hub_ip,
                    port=self.hub_port,
                    base_url=self.hub_prefix,
                    cookie_name='jupyter-hub-token',
                )
            )
            self.db.add(self.hub)
        else:
            # update existing row with current config
            server = self.hub.server
            server.ip = self.hub_ip
            server.port = self.hub_port
            server.base_url = self.hub_prefix
        self.db.commit()

    @gen.coroutine
    def init_users(self):
        """Load users into and from the database"""
        db = self.db

        if self.admin_users and not self.authenticator.admin_users:
            self.log.warn(
                "\nJupyterHub.admin_users is deprecated."
                "\nUse Authenticator.admin_users instead."
            )
            self.authenticator.admin_users = self.admin_users
        admin_users = self.authenticator.admin_users

        if not admin_users:
            # add current user as admin if there aren't any others
            admins = db.query(orm.User).filter(orm.User.admin==True)
            if admins.first() is None:
                admin_users.add(getuser())

        new_users = []

        for name in admin_users:
            # ensure anyone specified as admin in config is admin in db
            user = orm.User.find(db, name)
            if user is None:
                user = orm.User(name=name, admin=True)
                new_users.append(user)
                db.add(user)
            else:
                user.admin = True

        # the admin_users config variable will never be used after this point.
        # only the database values will be referenced.

        whitelist = self.authenticator.whitelist

        if not whitelist:
            self.log.info("Not using whitelist. Any authenticated user will be allowed.")

        # add whitelisted users to the db
        for name in whitelist:
            user = orm.User.find(db, name)
            if user is None:
                user = orm.User(name=name)
                new_users.append(user)
                db.add(user)

        if whitelist:
            # fill the whitelist with any users loaded from the db,
            # so we are consistent in both directions.
            # This lets whitelist be used to set up initial list,
            # but changes to the whitelist can occur in the database,
            # and persist across sessions.
            for user in db.query(orm.User):
                whitelist.add(user.name)

        # The whitelist set and the users in the db are now the same.
        # From this point on, any user changes should be done simultaneously
        # to the whitelist set and user db, unless the whitelist is empty (all users allowed).

        db.commit()

        for user in new_users:
            yield gen.maybe_future(self.authenticator.add_user(user))
        db.commit()

    @gen.coroutine
    def init_spawners(self):
        """Reconnect to (or clean up) single-user servers from a prior run."""
        db = self.db

        user_summaries = ['']
        def _user_summary(user):
            # one-line status used for the debug summary at the end
            parts = ['{0: >8}'.format(user.name)]
            if user.admin:
                parts.append('admin')
            if user.server:
                parts.append('running at %s' % user.server)
            return ' '.join(parts)

        @gen.coroutine
        def user_stopped(user):
            status = yield user.spawner.poll()
            self.log.warn("User %s server stopped with exit code: %s",
                user.name, status,
            )
            yield self.proxy.delete_user(user)
            yield user.stop()

        for user in db.query(orm.User):
            if not user.state:
                # without spawner state, server isn't valid
                user.server = None
                user_summaries.append(_user_summary(user))
                continue
            self.log.debug("Loading state for %s from db", user.name)
            user.spawner = spawner = self.spawner_class(
                user=user, hub=self.hub, config=self.config, db=self.db,
            )
            status = yield spawner.poll()
            if status is None:
                self.log.info("%s still running", user.name)
                spawner.add_poll_callback(user_stopped, user)
                spawner.start_polling()
            else:
                # user not running. This is expected if server is None,
                # but indicates the user's server died while the Hub wasn't running
                # if user.server is defined.
                log = self.log.warn if user.server else self.log.debug
                log("%s not running.", user.name)
                user.server = None

            user_summaries.append(_user_summary(user))

        self.log.debug("Loaded users: %s", '\n'.join(user_summaries))
        db.commit()

    def init_proxy(self):
        """Load the Proxy config into the database"""
        self.proxy = self.db.query(orm.Proxy).first()
        if self.proxy is None:
            self.proxy = orm.Proxy(
                public_server=orm.Server(),
                api_server=orm.Server(),
            )
            self.db.add(self.proxy)
            self.db.commit()
        self.proxy.auth_token = self.proxy_auth_token  # not persisted
        self.proxy.log = self.log
        self.proxy.public_server.ip = self.ip
        self.proxy.public_server.port = self.port
        self.proxy.api_server.ip = self.proxy_api_ip
        self.proxy.api_server.port = self.proxy_api_port
        self.proxy.api_server.base_url = '/api/routes/'
        self.db.commit()

    @gen.coroutine
    def start_proxy(self):
        """Actually start the configurable-http-proxy"""
        # check for proxy
        if self.proxy.public_server.is_up() or self.proxy.api_server.is_up():
            # check for *authenticated* access to the proxy (auth token can change)
            try:
                yield self.proxy.get_routes()
            except (HTTPError, OSError, socket.error) as e:
                if isinstance(e, HTTPError) and e.code == 403:
                    msg = "Did CONFIGPROXY_AUTH_TOKEN change?"
                else:
                    msg = "Is something else using %s?" % self.proxy.public_server.bind_url
                self.log.error("Proxy appears to be running at %s, but I can't access it (%s)\n%s",
                    self.proxy.public_server.bind_url, e, msg)
                self.exit(1)
                return
            else:
                self.log.info("Proxy already running at: %s", self.proxy.public_server.bind_url)
                self.proxy_process = None
                return

        env = os.environ.copy()
        env['CONFIGPROXY_AUTH_TOKEN'] = self.proxy.auth_token
        cmd = self.proxy_cmd + [
            '--ip', self.proxy.public_server.ip,
            '--port', str(self.proxy.public_server.port),
            '--api-ip', self.proxy.api_server.ip,
            '--api-port', str(self.proxy.api_server.port),
            '--default-target', self.hub.server.host,
        ]
        if self.debug_proxy:
            cmd.extend(['--log-level', 'debug'])
        if self.ssl_key:
            cmd.extend(['--ssl-key', self.ssl_key])
        if self.ssl_cert:
            cmd.extend(['--ssl-cert', self.ssl_cert])
        self.log.info("Starting proxy @ %s", self.proxy.public_server.bind_url)
        self.log.debug("Proxy cmd: %s", cmd)
        try:
            self.proxy_process = Popen(cmd, env=env)
        except FileNotFoundError:
            self.log.error(
                "Failed to find proxy %r\n"
                "The proxy can be installed with `npm install -g configurable-http-proxy`"
                % self.proxy_cmd
            )
            self.exit(1)

        def _check():
            # raise if the proxy process already died
            status = self.proxy_process.poll()
            if status is not None:
                e = RuntimeError("Proxy failed to start with exit code %i" % status)
                # py2-compatible `raise e from None`
                e.__cause__ = None
                raise e

        for server in (self.proxy.public_server, self.proxy.api_server):
            for i in range(10):
                _check()
                try:
                    yield server.wait_up(1)
                except TimeoutError:
                    continue
                else:
                    break
            yield server.wait_up(1)
        self.log.debug("Proxy started and appears to be up")

    @gen.coroutine
    def check_proxy(self):
        """Periodic callback: restart the proxy if it has died."""
        if self.proxy_process.poll() is None:
            return
        self.log.error("Proxy stopped with exit code %r",
            'unknown' if self.proxy_process is None else self.proxy_process.poll()
        )
        yield self.start_proxy()
        self.log.info("Setting up routes on new proxy")
        yield self.proxy.add_all_users()
        self.log.info("New proxy back up, and good to go")

    def init_tornado_settings(self):
        """Set up the tornado settings dict."""
        base_url = self.hub.server.base_url
        jinja_env = Environment(
            loader=FileSystemLoader(self.template_paths),
            **self.jinja_environment_options
        )

        login_url = self.authenticator.login_url(base_url)
        logout_url = self.authenticator.logout_url(base_url)

        # if running from git, disable caching of require.js
        # otherwise cache based on server start time
        parent = os.path.dirname(os.path.dirname(jupyterhub.__file__))
        if os.path.isdir(os.path.join(parent, '.git')):
            version_hash = ''
        else:
            # BUG FIX: the original had a trailing comma here, which made
            # version_hash a 1-tuple instead of the timestamp string.
            version_hash = datetime.now().strftime("%Y%m%d%H%M%S")

        settings = dict(
            log_function=log_request,
            config=self.config,
            log=self.log,
            db=self.db,
            proxy=self.proxy,
            hub=self.hub,
            admin_users=self.authenticator.admin_users,
            admin_access=self.admin_access,
            authenticator=self.authenticator,
            spawner_class=self.spawner_class,
            base_url=self.base_url,
            cookie_secret=self.cookie_secret,
            cookie_max_age_days=self.cookie_max_age_days,
            login_url=login_url,
            logout_url=logout_url,
            static_path=os.path.join(self.data_files_path, 'static'),
            static_url_prefix=url_path_join(self.hub.server.base_url, 'static/'),
            static_handler_class=CacheControlStaticFilesHandler,
            template_path=self.template_paths,
            jinja2_env=jinja_env,
            version_hash=version_hash,
        )
        # allow configured settings to have priority
        settings.update(self.tornado_settings)
        self.tornado_settings = settings

    def init_tornado_application(self):
        """Instantiate the tornado Application object"""
        self.tornado_application = web.Application(self.handlers, **self.tornado_settings)

    def write_pid_file(self):
        """Write our PID to pid_file, if configured."""
        pid = os.getpid()
        if self.pid_file:
            self.log.debug("Writing PID %i to %s", pid, self.pid_file)
            with open(self.pid_file, 'w') as f:
                f.write('%i' % pid)

    @gen.coroutine
    @catch_config_error
    def initialize(self, *args, **kwargs):
        super().initialize(*args, **kwargs)
        if self.generate_config or self.subapp:
            return
        self.load_config_file(self.config_file)
        self.init_logging()
        if 'JupyterHubApp' in self.config:
            self.log.warn("Use JupyterHub in config, not JupyterHubApp. Outdated config:\n%s",
                '\n'.join('JupyterHubApp.{key} = {value!r}'.format(key=key, value=value)
                    for key, value in self.config.JupyterHubApp.items()
                )
            )
            cfg = self.config.copy()
            cfg.JupyterHub.merge(cfg.JupyterHubApp)
            self.update_config(cfg)
        self.write_pid_file()
        self.init_ports()
        self.init_secrets()
        self.init_db()
        self.init_hub()
        self.init_proxy()
        yield self.init_users()
        yield self.init_spawners()
        self.init_handlers()
        self.init_tornado_settings()
        self.init_tornado_application()

    @gen.coroutine
    def cleanup(self):
        """Shutdown our various subprocesses and cleanup runtime files."""
        futures = []
        if self.cleanup_servers:
            self.log.info("Cleaning up single-user servers...")
            # request (async) process termination
            for user in self.db.query(orm.User):
                if user.spawner is not None:
                    futures.append(user.stop())
        else:
            self.log.info("Leaving single-user servers running")

        # clean up proxy while SUS are shutting down
        if self.cleanup_proxy:
            if self.proxy_process:
                self.log.info("Cleaning up proxy[%i]...", self.proxy_process.pid)
                if self.proxy_process.poll() is None:
                    try:
                        self.proxy_process.terminate()
                    except Exception as e:
                        self.log.error("Failed to terminate proxy process: %s", e)
            else:
                self.log.info("I didn't start the proxy, I can't clean it up")
        else:
            self.log.info("Leaving proxy running")

        # wait for the requests to finish
        for f in futures:
            try:
                yield f
            except Exception as e:
                self.log.error("Failed to stop user: %s", e)

        self.db.commit()

        if self.pid_file and os.path.exists(self.pid_file):
            self.log.info("Cleaning up PID file %s", self.pid_file)
            os.remove(self.pid_file)

        # finally stop the loop once we are all cleaned up
        self.log.info("...done")

    def write_config_file(self):
        """Write our default config to a .py config file"""
        if os.path.exists(self.config_file) and not self.answer_yes:
            answer = ''
            def ask():
                prompt = "Overwrite %s with default config? [y/N]" % self.config_file
                try:
                    return input(prompt).lower() or 'n'
                except KeyboardInterrupt:
                    print('')  # empty line
                    return 'n'
            answer = ask()
            while not answer.startswith(('y', 'n')):
                print("Please answer 'yes' or 'no'")
                answer = ask()
            if answer.startswith('n'):
                return

        config_text = self.generate_config_file()
        if isinstance(config_text, bytes):
            config_text = config_text.decode('utf8')
        print("Writing default config to: %s" % self.config_file)
        with open(self.config_file, mode='w') as f:
            f.write(config_text)

    @gen.coroutine
    def update_last_activity(self):
        """Update User.last_activity timestamps from the proxy"""
        routes = yield self.proxy.get_routes()
        for prefix, route in routes.items():
            if 'user' not in route:
                # not a user route, ignore it
                continue
            user = orm.User.find(self.db, route['user'])
            if user is None:
                self.log.warn("Found no user for route: %s", route)
                continue
            try:
                dt = datetime.strptime(route['last_activity'], ISO8601_ms)
            except Exception:
                # fall back to second-resolution timestamps
                dt = datetime.strptime(route['last_activity'], ISO8601_s)
            user.last_activity = max(user.last_activity, dt)

        self.db.commit()
        yield self.proxy.check_routes(routes)

    @gen.coroutine
    def start(self):
        """Start the whole thing"""
        self.io_loop = loop = IOLoop.current()

        if self.subapp:
            self.subapp.start()
            loop.stop()
            return

        if self.generate_config:
            self.write_config_file()
            loop.stop()
            return

        # start the webserver
        self.http_server = tornado.httpserver.HTTPServer(self.tornado_application, xheaders=True)
        try:
            self.http_server.listen(self.hub_port, address=self.hub_ip)
        except Exception:
            self.log.error("Failed to bind hub to %s", self.hub.server.bind_url)
            raise
        else:
            self.log.info("Hub API listening on %s", self.hub.server.bind_url)

        # start the proxy
        try:
            yield self.start_proxy()
        except Exception as e:
            self.log.critical("Failed to start proxy", exc_info=True)
            self.exit(1)
            return

        loop.add_callback(self.proxy.add_all_users)

        if self.proxy_process:
            # only check / restart the proxy if we started it in the first place.
            # this means a restarted Hub cannot restart a Proxy that its
            # predecessor started.
            pc = PeriodicCallback(self.check_proxy, 1e3 * self.proxy_check_interval)
            pc.start()

        if self.last_activity_interval:
            pc = PeriodicCallback(self.update_last_activity, 1e3 * self.last_activity_interval)
            pc.start()

        self.log.info("JupyterHub is now running at %s", self.proxy.public_server.url)
        # register cleanup on both TERM and INT
        atexit.register(self.atexit)
        self.init_signal()

    def init_signal(self):
        """Install the SIGTERM handler."""
        signal.signal(signal.SIGTERM, self.sigterm)

    def sigterm(self, signum, frame):
        self.log.critical("Received SIGTERM, shutting down")
        self.io_loop.stop()
        self.atexit()

    _atexit_ran = False
    def atexit(self):
        """atexit callback"""
        if self._atexit_ran:
            return
        self._atexit_ran = True
        # run the cleanup step (in a new loop, because the interrupted one is unclean)
        IOLoop.clear_current()
        loop = IOLoop()
        loop.make_current()
        loop.run_sync(self.cleanup)

    def stop(self):
        """Stop the http server and the IOLoop (thread-safe via callbacks)."""
        if not self.io_loop:
            return
        if self.http_server:
            # NOTE(review): relies on the private IOLoop._running attribute
            if self.io_loop._running:
                self.io_loop.add_callback(self.http_server.stop)
            else:
                self.http_server.stop()
        self.io_loop.add_callback(self.io_loop.stop)

    @gen.coroutine
    def launch_instance_async(self, argv=None):
        try:
            yield self.initialize(argv)
            yield self.start()
        except Exception as e:
            self.log.exception("")
            self.exit(1)

    @classmethod
    def launch_instance(cls, argv=None):
        self = cls.instance()
        loop = IOLoop.current()
        loop.add_callback(self.launch_instance_async, argv)
        try:
            loop.start()
        except KeyboardInterrupt:
            print("\nInterrupted")
class GitLabOAuthenticator(OAuthenticator):
    """OAuthenticator that authenticates users against GitLab.

    Exchanges the OAuth authorization code for an access token, looks up
    the GitLab user, and optionally restricts login to members of the
    configured ``gitlab_group_whitelist``.
    """

    # name shown on the login button / service identification
    login_service = "GitLab"

    # env vars holding the OAuth client credentials
    client_id_env = 'GITLAB_CLIENT_ID'
    client_secret_env = 'GITLAB_CLIENT_SECRET'
    login_handler = GitLabLoginHandler

    gitlab_group_whitelist = Set(
        config=True,
        help="Automatically whitelist members of selected groups",
    )

    @gen.coroutine
    def authenticate(self, handler, data=None):
        """Complete the OAuth flow; return a user dict or None if rejected."""
        code = handler.get_argument("code")
        # TODO: Configure the curl_httpclient for tornado
        http_client = AsyncHTTPClient()

        # Exchange the OAuth code for a GitLab Access Token
        #
        # See: https://github.com/gitlabhq/gitlabhq/blob/master/doc/api/oauth2.md

        # GitLab specifies a POST request yet requires URL parameters
        params = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            code=code,
            grant_type="authorization_code",
            redirect_uri=self.get_callback_url(handler),
        )

        validate_server_cert = self.validate_server_cert

        url = url_concat("%s/oauth/token" % GITLAB_HOST, params)

        req = HTTPRequest(
            url,
            method="POST",
            headers={"Accept": "application/json"},
            validate_cert=validate_server_cert,
            body=''  # Body is required for a POST...
        )

        resp = yield http_client.fetch(req)
        resp_json = json.loads(resp.body.decode('utf8', 'replace'))

        access_token = resp_json['access_token']

        # Determine who the logged in user is
        req = HTTPRequest(
            "%s/user" % GITLAB_API,
            method="GET",
            validate_cert=validate_server_cert,
            headers=_api_headers(access_token)
        )
        resp = yield http_client.fetch(req)
        resp_json = json.loads(resp.body.decode('utf8', 'replace'))

        username = resp_json["username"]
        user_id = resp_json["id"]
        is_admin = resp_json.get("is_admin", False)

        # Check if user is a member of any whitelisted organizations.
        # This check is performed here, as it requires `access_token`.
        if self.gitlab_group_whitelist:
            user_in_group = yield self._check_group_whitelist(
                username, user_id, is_admin, access_token)
            if not user_in_group:
                # reject login: authenticated but not in an allowed group
                self.log.warning("%s not in group whitelist", username)
                return None
        # auth_state carries the token + raw profile for downstream hooks
        return {
            'name': username,
            'auth_state': {
                'access_token': access_token,
                'gitlab_user': resp_json,
            }
        }

    @gen.coroutine
    def _check_group_whitelist(self, username, user_id, is_admin, access_token):
        """Return True if *user_id* is a member of any whitelisted group.

        NOTE(review): `username` and `is_admin` are accepted but unused here;
        membership is checked purely by user id against each group.
        """
        http_client = AsyncHTTPClient()
        headers = _api_headers(access_token)
        # Check if we are a member of each group in the whitelist
        for group in map(url_escape, self.gitlab_group_whitelist):
            url = "%s/groups/%s/members/%d" % (GITLAB_API, group, user_id)
            req = HTTPRequest(url, method="GET", headers=headers)
            # raise_error=False: a 404 just means "not a member of this group"
            resp = yield http_client.fetch(req, raise_error=False)
            if resp.code == 200:
                return True  # user _is_ in group
        return False
class BitbucketOAuthenticator(OAuthenticator):
    """OAuthenticator that authenticates users against Bitbucket.

    Exchanges the OAuth authorization code for an access token, looks up
    the Bitbucket user, and optionally restricts login to members of the
    configured ``allowed_teams``.
    """

    # map the deprecated trait name to its replacement (merged with the base
    # class's alias table)
    _deprecated_oauth_aliases = {
        "team_whitelist": ("allowed_teams", "0.12.0"),
        **OAuthenticator._deprecated_oauth_aliases,
    }

    login_service = "Bitbucket"
    client_id_env = 'BITBUCKET_CLIENT_ID'
    client_secret_env = 'BITBUCKET_CLIENT_SECRET'

    @default("authorize_url")
    def _authorize_url_default(self):
        return "https://bitbucket.org/site/oauth2/authorize"

    @default("token_url")
    def _token_url_default(self):
        return "https://bitbucket.org/site/oauth2/access_token"

    team_whitelist = Set(
        help="Deprecated, use `BitbucketOAuthenticator.allowed_teams`",
        config=True,
    )

    allowed_teams = Set(
        config=True, help="Automatically allow members of selected teams"
    )

    # template headers for API requests; Authorization is filled per-token
    headers = {
        "Accept": "application/json",
        "User-Agent": "JupyterHub",
        "Authorization": "Bearer {}",
    }

    async def authenticate(self, handler, data=None):
        """Complete the OAuth flow; return a user dict or None if rejected."""
        code = handler.get_argument("code")

        params = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            grant_type="authorization_code",
            code=code,
            redirect_uri=self.get_callback_url(handler),
        )

        url = url_concat("https://bitbucket.org/site/oauth2/access_token", params)
        bb_header = {
            "Content-Type": "application/x-www-form-urlencoded;charset=utf-8"
        }
        # NOTE: params are sent both in the URL and the form body, with HTTP
        # basic auth using the client credentials
        req = HTTPRequest(
            url,
            method="POST",
            auth_username=self.client_id,
            auth_password=self.client_secret,
            body=urllib.parse.urlencode(params).encode('utf-8'),
            headers=bb_header,
        )

        resp_json = await self.fetch(req)

        access_token = resp_json['access_token']

        # Determine who the logged in user is
        req = HTTPRequest(
            "https://api.bitbucket.org/2.0/user",
            method="GET",
            headers=_api_headers(access_token),
        )
        resp_json = await self.fetch(req)

        username = resp_json["username"]

        # Check if user is a member of any allowed teams.
        # This check is performed here, as the check requires `access_token`.
        if self.allowed_teams:
            user_in_team = await self._check_membership_allowed_teams(
                username, access_token)
            if not user_in_team:
                # reject login: authenticated but not in an allowed team
                self.log.warning("%s not in team allowed list of users", username)
                return None

        # auth_state carries the token + raw profile for downstream hooks
        return {
            'name': username,
            'auth_state': {
                'access_token': access_token,
                'bitbucket_user': resp_json,
            },
        }

    async def _check_membership_allowed_teams(self, username, access_token):
        """Return True if the user belongs to any team in ``allowed_teams``.

        Walks the paginated /2.0/teams?role=member listing via each page's
        'next' link until a match is found or pages run out.
        """
        headers = _api_headers(access_token)
        # We verify the team membership by calling teams endpoint.
        next_page = url_concat("https://api.bitbucket.org/2.0/teams",
                               {'role': 'member'})
        while next_page:
            req = HTTPRequest(next_page, method="GET", headers=headers)
            resp_json = await self.fetch(req)
            next_page = resp_json.get('next', None)

            user_teams = set(
                [entry["username"] for entry in resp_json["values"]])
            # check if any of the organizations seen thus far are in the allowed list
            if len(self.allowed_teams & user_teams) > 0:
                return True
        return False
class Widget(LoggingConfigurable):
    """Base class for Jupyter interactive widgets (legacy protocol).

    Mirrors every trait tagged ``sync=True`` with a Backbone model in the
    front-end over a Comm channel.
    """

    #-------------------------------------------------------------------------
    # Class attributes
    #-------------------------------------------------------------------------
    # Process-wide hook fired for every widget constructed; set via
    # on_widget_constructed().
    _widget_construction_callback = None
    # Registry of live widget instances, keyed by model (comm) id.
    widgets = {}
    widget_types = {}

    @staticmethod
    def on_widget_constructed(callback):
        """Registers a callback to be called when a widget is constructed.

        The callback must have the following signature:
        callback(widget)"""
        Widget._widget_construction_callback = callback

    @staticmethod
    def _call_widget_constructed(widget):
        """Static method, called when a widget is constructed."""
        if Widget._widget_construction_callback is not None and callable(
                Widget._widget_construction_callback):
            Widget._widget_construction_callback(widget)

    @staticmethod
    def handle_comm_opened(comm, msg):
        """Static method, called when a widget is constructed."""
        # The front-end names the widget class to build by dotted path;
        # resolve it and bind the new instance to the freshly opened comm.
        widget_class = import_item(str(msg['content']['data']['widget_class']))
        widget = widget_class(comm=comm)

    #-------------------------------------------------------------------------
    # Traits
    #-------------------------------------------------------------------------
    _model_module = Unicode(None, allow_none=True,
        help="""A requirejs module name in which to find _model_name. If empty, look
        in the global registry.""")
    _model_name = Unicode('WidgetModel',
        help="""Name of the backbone model registered in the front-end to create and sync
        this widget with.""")
    _view_module = Unicode(
        help="""A requirejs module in which to find _view_name. If empty, look in the
        global registry.""", sync=True)
    _view_name = Unicode(None, allow_none=True,
        help="""Default view registered in the front-end to use to represent the
        widget.""", sync=True)
    # Kernel-side end of the comm channel; None until open() runs.
    comm = Instance('ipykernel.comm.Comm', allow_none=True)
    msg_throttle = Int(3, sync=True,
        help="""Maximum number of msgs the front-end can send before receiving an idle msg
        from the back-end.""")
    version = Int(0, sync=True, help="""Widget's version""")
    # Names of the traits synced with the front-end (defaults to all
    # sync=True traits, see _keys_default).
    keys = List()

    def _keys_default(self):
        return [name for name in self.traits(sync=True)]

    # Last state received from the front-end; consulted so we do not echo
    # that same state straight back (see _should_send_property).
    _property_lock = Dict()
    # Nesting depth of hold_sync(); while > 0 state changes are buffered.
    _send_state_lock = Int(0)
    # Trait names buffered while hold_sync() is active.
    _states_to_send = Set()
    _display_callbacks = Instance(CallbackDispatcher, ())
    _msg_callbacks = Instance(CallbackDispatcher, ())

    #-------------------------------------------------------------------------
    # (Con/de)structor
    #-------------------------------------------------------------------------
    def __init__(self, **kwargs):
        """Public constructor"""
        # Allow the caller to supply the model id, e.g. when re-attaching to
        # an existing front-end model.
        self._model_id = kwargs.pop('model_id', None)
        super(Widget, self).__init__(**kwargs)

        Widget._call_widget_constructed(self)
        self.open()

    def __del__(self):
        """Object disposal"""
        self.close()

    #-------------------------------------------------------------------------
    # Properties
    #-------------------------------------------------------------------------
    def open(self):
        """Open a comm to the frontend if one isn't already open."""
        if self.comm is None:
            args = dict(target_name='ipython.widget',
                        data={
                            'model_name': self._model_name,
                            'model_module': self._model_module
                        })
            if self._model_id is not None:
                args['comm_id'] = self._model_id
            self.comm = Comm(**args)

    def _comm_changed(self, name, new):
        """Called when the comm is changed."""
        if new is None:
            return
        self._model_id = self.model_id

        self.comm.on_msg(self._handle_msg)
        Widget.widgets[self.model_id] = self

        # first update
        self.send_state()

    @property
    def model_id(self):
        """Gets the model id of this widget.

        If a Comm doesn't exist yet, a Comm will be created automagically."""
        return self.comm.comm_id

    #-------------------------------------------------------------------------
    # Methods
    #-------------------------------------------------------------------------
    def close(self):
        """Close method.

        Closes the underlying comm.
        When the comm is closed, all of the widget views are automatically
        removed from the front-end."""
        if self.comm is not None:
            Widget.widgets.pop(self.model_id, None)
            self.comm.close()
            self.comm = None

    def send_state(self, key=None):
        """Sends the widget state, or a piece of it, to the front-end.

        Parameters
        ----------
        key : unicode, or iterable (optional)
            A single property's name or iterable of property names to sync with the
            front-end.
        """
        state, buffer_keys, buffers = self.get_state(key=key)
        msg = {"method": "update", "state": state}
        if buffer_keys:
            msg['buffers'] = buffer_keys
        self._send(msg, buffers=buffers)

    def get_state(self, key=None):
        """Gets the widget state, or a piece of it.

        Parameters
        ----------
        key : unicode or iterable (optional)
            A single property's name or iterable of property names to get.

        Returns
        -------
        state : dict of states
        buffer_keys : list of strings
            the values that are stored in buffers
        buffers : list of binary memoryviews
            values to transmit in binary
        """
        if key is None:
            keys = self.keys
        elif isinstance(key, string_types):
            keys = [key]
        elif isinstance(key, collections.Iterable):
            keys = key
        else:
            raise ValueError(
                "key must be a string, an iterable of keys, or None")
        state = {}
        buffers = []
        buffer_keys = []
        for k in keys:
            # Per-trait serializer, falling back to the identity function.
            f = self.trait_metadata(k, 'to_json', self._trait_to_json)
            value = getattr(self, k)
            serialized = f(value)
            # memoryviews travel as binary buffers, everything else as JSON.
            if isinstance(serialized, memoryview):
                buffers.append(serialized)
                buffer_keys.append(k)
            else:
                state[k] = serialized
        return state, buffer_keys, buffers

    def set_state(self, sync_data):
        """Called when a state is received from the front-end."""
        # The order of these context managers is important. Properties must
        # be locked when the hold_trait_notification context manager is
        # released and notifications are fired.
        with self._lock_property(**sync_data), self.hold_trait_notifications():
            for name in sync_data:
                if name in self.keys:
                    from_json = self.trait_metadata(name, 'from_json',
                                                    self._trait_from_json)
                    setattr(self, name, from_json(sync_data[name]))

    def send(self, content, buffers=None):
        """Sends a custom msg to the widget model in the front-end.

        Parameters
        ----------
        content : dict
            Content of the message to send.
        buffers : list of binary buffers
            Binary buffers to send with message
        """
        self._send({"method": "custom", "content": content}, buffers=buffers)

    def on_msg(self, callback, remove=False):
        """(Un)Register a custom msg receive callback.

        Parameters
        ----------
        callback: callable
            callback will be passed three arguments when a message arrives::

                callback(widget, content, buffers)

        remove: bool
            True if the callback should be unregistered."""
        self._msg_callbacks.register_callback(callback, remove=remove)

    def on_displayed(self, callback, remove=False):
        """(Un)Register a widget displayed callback.

        Parameters
        ----------
        callback: method handler
            Must have a signature of::

                callback(widget, **kwargs)

            kwargs from display are passed through without modification.
        remove: bool
            True if the callback should be unregistered."""
        self._display_callbacks.register_callback(callback, remove=remove)

    def add_trait(self, traitname, trait):
        """Dynamically add a trait attribute to the Widget."""
        super(Widget, self).add_trait(traitname, trait)
        # Newly added sync traits join the synced set and are pushed to the
        # front-end immediately.
        if trait.get_metadata('sync'):
            self.keys.append(traitname)
            self.send_state(traitname)

    #-------------------------------------------------------------------------
    # Support methods
    #-------------------------------------------------------------------------
    @contextmanager
    def _lock_property(self, **properties):
        """Lock a property-value pair.

        The value should be the JSON state of the property.

        NOTE: This, in addition to the single lock for all state changes, is
        flawed.  In the future we may want to look into buffering state changes
        back to the front-end."""
        self._property_lock = properties
        try:
            yield
        finally:
            self._property_lock = {}

    @contextmanager
    def hold_sync(self):
        """Hold syncing any state until the context manager is released"""
        # We increment a value so that this can be nested.  Syncing will happen when
        # all levels have been released.
        self._send_state_lock += 1
        try:
            yield
        finally:
            self._send_state_lock -= 1
            if self._send_state_lock == 0:
                # Outermost level released: flush everything buffered.
                self.send_state(self._states_to_send)
                self._states_to_send.clear()

    def _should_send_property(self, key, value):
        """Check the property lock (property_lock)"""
        to_json = self.trait_metadata(key, 'to_json', self._trait_to_json)
        # Suppress the echo when the new value equals what the front-end
        # just sent us; buffer instead when inside hold_sync().
        if (key in self._property_lock
                and to_json(value) == self._property_lock[key]):
            return False
        elif self._send_state_lock > 0:
            self._states_to_send.add(key)
            return False
        else:
            return True

    # Event handlers
    @_show_traceback
    def _handle_msg(self, msg):
        """Called when a msg is received from the front-end"""
        data = msg['content']['data']
        method = data['method']

        # Handle backbone sync methods CREATE, PATCH, and UPDATE all in one.
        if method == 'backbone':
            if 'sync_data' in data:
                # get binary buffers too
                sync_data = data['sync_data']
                for i, k in enumerate(data.get('buffer_keys', [])):
                    sync_data[k] = msg['buffers'][i]
                self.set_state(sync_data)  # handles all methods

        # Handle a state request.
        elif method == 'request_state':
            self.send_state()

        # Handle a custom msg from the front-end.
        elif method == 'custom':
            if 'content' in data:
                self._handle_custom_msg(data['content'], msg['buffers'])

        # Catch remainder.
        else:
            self.log.error(
                'Unknown front-end to back-end widget msg with method "%s"' %
                method)

    def _handle_custom_msg(self, content, buffers):
        """Called when a custom msg is received."""
        self._msg_callbacks(self, content, buffers)

    def _notify_trait(self, name, old_value, new_value):
        """Called when a property has been changed."""
        # Trigger default traitlet callback machinery.  This allows any user
        # registered validation to be processed prior to allowing the widget
        # machinery to handle the state.
        LoggingConfigurable._notify_trait(self, name, old_value, new_value)

        # Send the state after the user registered callbacks for trait changes
        # have all fired (allows for user to validate values).
        if self.comm is not None and name in self.keys:
            # Make sure this isn't information that the front-end just sent us.
            if self._should_send_property(name, new_value):
                # Send new state to front-end
                self.send_state(key=name)

    def _handle_displayed(self, **kwargs):
        """Called when a view has been displayed for this widget instance"""
        self._display_callbacks(self, **kwargs)

    def _trait_to_json(self, x):
        """Convert a trait value to json."""
        return x

    def _trait_from_json(self, x):
        """Convert json values to objects."""
        return x

    def _ipython_display_(self, **kwargs):
        """Called when `IPython.display.display` is called on the widget."""
        # Show view.
        if self._view_name is not None:
            self._send({"method": "display"})
            self._handle_displayed(**kwargs)

    def _send(self, msg, buffers=None):
        """Sends a message to the model in the front-end."""
        self.comm.send(data=msg, buffers=buffers)
class Widget(LoggingHasTraits):
    """Base class for Jupyter interactive widgets (protocol version 2).

    Mirrors every trait tagged ``sync=True`` with a model in the front-end
    over a Comm channel.
    """
    # Fix: removed leftover debug instrumentation
    # (`self.comm.comm_log.write(...)` / `.flush()`) from on_displayed,
    # _handle_displayed and _ipython_display_.  `ipykernel.comm.Comm` has no
    # `comm_log` attribute and `self.comm` may be None, so those lines raised
    # AttributeError at display time.

    #-------------------------------------------------------------------------
    # Class attributes
    #-------------------------------------------------------------------------
    # Process-wide hook fired for every widget constructed; set via
    # on_widget_constructed().
    _widget_construction_callback = None

    # widgets is a dictionary of all active widget objects
    widgets = {}

    # widget_types is a registry of widgets by module, version, and name:
    widget_types = WidgetRegistry()

    @staticmethod
    def on_widget_constructed(callback):
        """Registers a callback to be called when a widget is constructed.

        The callback must have the following signature:
        callback(widget)"""
        Widget._widget_construction_callback = callback

    @staticmethod
    def _call_widget_constructed(widget):
        """Static method, called when a widget is constructed."""
        if Widget._widget_construction_callback is not None and callable(
                Widget._widget_construction_callback):
            Widget._widget_construction_callback(widget)

    @staticmethod
    def handle_comm_opened(comm, msg):
        """Static method, called when a widget comm is opened by the front-end."""
        # Only the major protocol version must match.
        version = msg.get('metadata', {}).get('version', '')
        if version.split('.')[0] != PROTOCOL_VERSION_MAJOR:
            raise ValueError(
                "Incompatible widget protocol versions: received version %r, expected version %r"
                % (version, __protocol_version__))
        data = msg['content']['data']
        state = data['state']

        # Find the widget class to instantiate in the registered widgets
        widget_class = Widget.widget_types.get(state['_model_module'],
                                               state['_model_module_version'],
                                               state['_model_name'],
                                               state['_view_module'],
                                               state['_view_module_version'],
                                               state['_view_name'])
        widget = widget_class(comm=comm)
        if 'buffer_paths' in data:
            # Splice the binary buffers back into the JSON state.
            _put_buffers(state, data['buffer_paths'], msg['buffers'])
        widget.set_state(state)

    @staticmethod
    def get_manager_state(drop_defaults=False, widgets=None):
        """Returns the full state for a widget manager for embedding

        :param drop_defaults: when True, it will not include default value
        :param widgets: list with widgets to include in the state (or all widgets when None)
        :return:
        """
        state = {}
        if widgets is None:
            widgets = Widget.widgets.values()
        for widget in widgets:
            state[widget.model_id] = widget._get_embed_state(
                drop_defaults=drop_defaults)
        return {'version_major': 2, 'version_minor': 0, 'state': state}

    def _get_embed_state(self, drop_defaults=False):
        """Return this widget's serialized state for get_manager_state().

        Binary buffers are base64-encoded so the result is pure JSON.
        """
        state = {
            'model_name': self._model_name,
            'model_module': self._model_module,
            'model_module_version': self._model_module_version
        }
        model_state, buffer_paths, buffers = _remove_buffers(
            self.get_state(drop_defaults=drop_defaults))
        state['state'] = model_state
        if len(buffers) > 0:
            state['buffers'] = [{
                'encoding': 'base64',
                'path': p,
                'data': standard_b64encode(d).decode('ascii')
            } for p, d in zip(buffer_paths, buffers)]
        return state

    def get_view_spec(self):
        """Return the mimebundle payload identifying this widget's model."""
        return dict(version_major=2, version_minor=0, model_id=self._model_id)

    #-------------------------------------------------------------------------
    # Traits
    #-------------------------------------------------------------------------
    _model_name = Unicode('WidgetModel',
                          help="Name of the model.",
                          read_only=True).tag(sync=True)
    _model_module = Unicode('@jupyter-widgets/base',
                            help="The namespace for the model.",
                            read_only=True).tag(sync=True)
    _model_module_version = Unicode(
        __jupyter_widgets_base_version__,
        help="A semver requirement for namespace version containing the model.",
        read_only=True).tag(sync=True)
    _view_name = Unicode(None, allow_none=True,
                         help="Name of the view.").tag(sync=True)
    _view_module = Unicode(None, allow_none=True,
                           help="The namespace for the view.").tag(sync=True)
    _view_module_version = Unicode(
        '',
        help=
        "A semver requirement for the namespace version containing the view."
    ).tag(sync=True)
    _view_count = Int(
        None,
        allow_none=True,
        help=
        "EXPERIMENTAL: The number of views of the model displayed in the frontend. This attribute is experimental and may change or be removed in the future. None signifies that views will not be tracked. Set this to 0 to start tracking view creation/deletion."
    ).tag(sync=True)
    # Kernel-side end of the comm channel; None until open() runs.
    comm = Instance('ipykernel.comm.Comm', allow_none=True)

    keys = List(help="The traits which are synced.")

    @default('keys')
    def _default_keys(self):
        # Sync exactly the traits tagged with sync=True.
        return [name for name in self.traits(sync=True)]

    # Last state received from the front-end; consulted so we do not echo
    # that same state straight back (see _should_send_property).
    _property_lock = Dict()
    # True while inside hold_sync(); state changes are buffered instead of sent.
    _holding_sync = False
    # Trait names buffered while hold_sync() is active.
    _states_to_send = Set()
    _display_callbacks = Instance(CallbackDispatcher, ())
    _msg_callbacks = Instance(CallbackDispatcher, ())

    #-------------------------------------------------------------------------
    # (Con/de)structor
    #-------------------------------------------------------------------------
    def __init__(self, **kwargs):
        """Public constructor"""
        # Allow the caller to supply the model id, e.g. when re-attaching to
        # an existing front-end model.
        self._model_id = kwargs.pop('model_id', None)
        super(Widget, self).__init__(**kwargs)

        Widget._call_widget_constructed(self)
        self.open()

    def __del__(self):
        """Object disposal"""
        self.close()

    #-------------------------------------------------------------------------
    # Properties
    #-------------------------------------------------------------------------
    def open(self):
        """Open a comm to the frontend if one isn't already open."""
        if self.comm is None:
            state, buffer_paths, buffers = _remove_buffers(self.get_state())

            args = dict(target_name='jupyter.widget',
                        data={
                            'state': state,
                            'buffer_paths': buffer_paths
                        },
                        buffers=buffers,
                        metadata={'version': __protocol_version__})
            if self._model_id is not None:
                args['comm_id'] = self._model_id

            self.comm = Comm(**args)

    @observe('comm')
    def _comm_changed(self, change):
        """Called when the comm is changed."""
        if change['new'] is None:
            return
        self._model_id = self.model_id

        self.comm.on_msg(self._handle_msg)
        Widget.widgets[self.model_id] = self

    @property
    def model_id(self):
        """Gets the model id of this widget.

        If a Comm doesn't exist yet, a Comm will be created automagically."""
        return self.comm.comm_id

    #-------------------------------------------------------------------------
    # Methods
    #-------------------------------------------------------------------------
    def close(self):
        """Close method.

        Closes the underlying comm.
        When the comm is closed, all of the widget views are automatically
        removed from the front-end."""
        if self.comm is not None:
            Widget.widgets.pop(self.model_id, None)
            self.comm.close()
            self.comm = None
            # Prevent a closed widget from being displayed again.
            self._ipython_display_ = None

    def send_state(self, key=None):
        """Sends the widget state, or a piece of it, to the front-end.

        Parameters
        ----------
        key : unicode, or iterable (optional)
            A single property's name or iterable of property names to sync with the
            front-end.
        """
        state = self.get_state(key=key)
        # Split binary values out of the JSON state into separate buffers.
        state, buffer_paths, buffers = _remove_buffers(state)
        msg = {
            'method': 'update',
            'state': state,
            'buffer_paths': buffer_paths
        }
        self._send(msg, buffers=buffers)

    def get_state(self, key=None, drop_defaults=False):
        """Gets the widget state, or a piece of it.

        Parameters
        ----------
        key : unicode or iterable (optional)
            A single property's name or iterable of property names to get.
        drop_defaults : bool (optional)
            When True, omit traits whose value equals their default.

        Returns
        -------
        state : dict of states
        """
        if key is None:
            keys = self.keys
        elif isinstance(key, string_types):
            keys = [key]
        # NOTE: collections.Iterable is removed in Python 3.10
        # (collections.abc.Iterable is the modern spelling); kept as-is for
        # the py2/py3 compatibility this module targets (see PY3 below).
        elif isinstance(key, collections.Iterable):
            keys = key
        else:
            raise ValueError(
                "key must be a string, an iterable of keys, or None")
        state = {}
        traits = self.traits()
        for k in keys:
            # Per-trait serializer, falling back to the identity function.
            to_json = self.trait_metadata(k, 'to_json', self._trait_to_json)
            value = to_json(getattr(self, k), self)
            # On py2, raw bytes must travel as a buffer, not JSON.
            if not PY3 and isinstance(traits[k], Bytes) and isinstance(
                    value, bytes):
                value = memoryview(value)
            if not drop_defaults or not self._compare(value,
                                                      traits[k].default_value):
                state[k] = value
        return state

    def _is_numpy(self, x):
        """Duck-typed check for a numpy ndarray without importing numpy."""
        return x.__class__.__name__ == 'ndarray' and x.__class__.__module__ == 'numpy'

    def _compare(self, a, b):
        """Equality that also works for numpy arrays (elementwise)."""
        if self._is_numpy(a) or self._is_numpy(b):
            import numpy as np
            return np.array_equal(a, b)
        else:
            return a == b

    def set_state(self, sync_data):
        """Called when a state is received from the front-end."""
        # The order of these context managers is important. Properties must
        # be locked when the hold_trait_notification context manager is
        # released and notifications are fired.
        with self._lock_property(**sync_data), self.hold_trait_notifications():
            for name in sync_data:
                if name in self.keys:
                    from_json = self.trait_metadata(name, 'from_json',
                                                    self._trait_from_json)
                    # set_trait also works for read-only traits.
                    self.set_trait(name, from_json(sync_data[name], self))

    def send(self, content, buffers=None):
        """Sends a custom msg to the widget model in the front-end.

        Parameters
        ----------
        content : dict
            Content of the message to send.
        buffers : list of binary buffers
            Binary buffers to send with message
        """
        self._send({"method": "custom", "content": content}, buffers=buffers)

    def on_msg(self, callback, remove=False):
        """(Un)Register a custom msg receive callback.

        Parameters
        ----------
        callback: callable
            callback will be passed three arguments when a message arrives::

                callback(widget, content, buffers)

        remove: bool
            True if the callback should be unregistered."""
        self._msg_callbacks.register_callback(callback, remove=remove)

    def on_displayed(self, callback, remove=False):
        """(Un)Register a widget displayed callback.

        Parameters
        ----------
        callback: method handler
            Must have a signature of::

                callback(widget, **kwargs)

            kwargs from display are passed through without modification.
        remove: bool
            True if the callback should be unregistered."""
        self._display_callbacks.register_callback(callback, remove=remove)

    def add_traits(self, **traits):
        """Dynamically add trait attributes to the Widget."""
        super(Widget, self).add_traits(**traits)
        # Newly added sync traits join the synced set and are pushed to the
        # front-end immediately.
        for name, trait in traits.items():
            if trait.get_metadata('sync'):
                self.keys.append(name)
                self.send_state(name)

    def notify_change(self, change):
        """Called when a property has changed."""
        # Send the state to the frontend before the user-registered callbacks
        # are called.
        name = change['name']
        if self.comm is not None and self.comm.kernel is not None:
            # Make sure this isn't information that the front-end just sent us.
            if name in self.keys and self._should_send_property(
                    name, change['new']):
                # Send new state to front-end
                self.send_state(key=name)
        super(Widget, self).notify_change(change)

    def __repr__(self):
        return self._gen_repr_from_keys(self._repr_keys())

    #-------------------------------------------------------------------------
    # Support methods
    #-------------------------------------------------------------------------
    @contextmanager
    def _lock_property(self, **properties):
        """Lock a property-value pair.

        The value should be the JSON state of the property.

        NOTE: This, in addition to the single lock for all state changes, is
        flawed.  In the future we may want to look into buffering state changes
        back to the front-end."""
        self._property_lock = properties
        try:
            yield
        finally:
            self._property_lock = {}

    @contextmanager
    def hold_sync(self):
        """Hold syncing any state until the outermost context manager exits"""
        # Nested hold_sync() calls are no-ops; only the outermost one flushes.
        if self._holding_sync is True:
            yield
        else:
            try:
                self._holding_sync = True
                yield
            finally:
                self._holding_sync = False
                self.send_state(self._states_to_send)
                self._states_to_send.clear()

    def _should_send_property(self, key, value):
        """Check the property lock (property_lock)"""
        to_json = self.trait_metadata(key, 'to_json', self._trait_to_json)
        # A roundtrip conversion through json in the comparison takes care of
        # idiosyncracies of how python data structures map to json, for example
        # tuples get converted to lists.
        if (key in self._property_lock and jsonloads(
                jsondumps(to_json(value, self))) == self._property_lock[key]):
            return False
        elif self._holding_sync:
            self._states_to_send.add(key)
            return False
        else:
            return True

    # Event handlers
    @_show_traceback
    def _handle_msg(self, msg):
        """Called when a msg is received from the front-end"""
        data = msg['content']['data']
        method = data['method']

        if method == 'update':
            if 'state' in data:
                state = data['state']
                if 'buffer_paths' in data:
                    _put_buffers(state, data['buffer_paths'], msg['buffers'])
                self.set_state(state)

        # Handle a state request.
        elif method == 'request_state':
            self.send_state()

        # Handle a custom msg from the front-end.
        elif method == 'custom':
            if 'content' in data:
                self._handle_custom_msg(data['content'], msg['buffers'])

        # Catch remainder.
        else:
            self.log.error(
                'Unknown front-end to back-end widget msg with method "%s"' %
                method)

    def _handle_custom_msg(self, content, buffers):
        """Called when a custom msg is received."""
        self._msg_callbacks(self, content, buffers)

    def _handle_displayed(self, **kwargs):
        """Called when a view has been displayed for this widget instance"""
        self._display_callbacks(self, **kwargs)

    @staticmethod
    def _trait_to_json(x, self):
        """Convert a trait value to json."""
        return x

    @staticmethod
    def _trait_from_json(x, self):
        """Convert json values to objects."""
        return x

    def _ipython_display_(self, **kwargs):
        """Called when `IPython.display.display` is called on the widget."""
        if self._view_name is not None:
            # The 'application/vnd.jupyter.widget-view+json' mimetype has not been registered yet.
            # See the registration process and naming convention at
            # http://tools.ietf.org/html/rfc6838
            # and the currently registered mimetypes at
            # http://www.iana.org/assignments/media-types/media-types.xhtml.
            data = {
                'text/plain': "A Jupyter Widget",
                'application/vnd.jupyter.widget-view+json': {
                    'version_major': 2,
                    'version_minor': 0,
                    'model_id': self._model_id
                }
            }
            display(data, raw=True)
            self._handle_displayed(**kwargs)

    def _send(self, msg, buffers=None):
        """Sends a message to the model in the front-end."""
        # Silently drop messages when no kernel-connected comm exists.
        if self.comm is not None and self.comm.kernel is not None:
            self.comm.send(data=msg, buffers=buffers)

    def _repr_keys(self):
        """Yield the trait names worth showing in repr (non-default, public)."""
        traits = self.traits()
        for key in sorted(self.keys):
            # Exclude traits that start with an underscore
            if key[0] == '_':
                continue
            # Exclude traits who are equal to their default value
            value = getattr(self, key)
            trait = traits[key]
            if self._compare(value, trait.default_value):
                continue
            elif (isinstance(trait, (Container, Dict))
                  and trait.default_value == Undefined and len(value) == 0):
                # Empty container, and dynamic default will be empty
                continue
            yield key

    def _gen_repr_from_keys(self, keys):
        """Build a ``ClassName(key=value, ...)`` repr string from *keys*."""
        class_name = self.__class__.__name__
        signature = ', '.join('%s=%r' % (key, getattr(self, key))
                              for key in keys)
        return '%s(%s)' % (class_name, signature)
class EnterpriseGatewayApp(JupyterApp): """Application that provisions Jupyter kernels and proxies HTTP/Websocket traffic to the kernels. - reads command line and environment variable settings - initializes managers and routes - creates a Tornado HTTP server - starts the Tornado event loop """ name = 'jupyter-enterprise-gateway' version = __version__ description = """ Jupyter Enterprise Gateway Provisions remote Jupyter kernels and proxies HTTP/Websocket traffic to them. """ # Also include when generating help options classes = [FileKernelSessionManager, RemoteMappingKernelManager] # Enable some command line shortcuts aliases = aliases # Server IP / PORT binding port_env = 'EG_PORT' port_default_value = 8888 port = Integer(port_default_value, config=True, help='Port on which to listen (EG_PORT env var)') @default('port') def port_default(self): return int( os.getenv(self.port_env, os.getenv('KG_PORT', self.port_default_value))) port_retries_env = 'EG_PORT_RETRIES' port_retries_default_value = 50 port_retries = Integer( port_retries_default_value, config=True, help="""Number of ports to try if the specified port is not available (EG_PORT_RETRIES env var)""") @default('port_retries') def port_retries_default(self): return int( os.getenv( self.port_retries_env, os.getenv('KG_PORT_RETRIES', self.port_retries_default_value))) ip_env = 'EG_IP' ip_default_value = '127.0.0.1' ip = Unicode(ip_default_value, config=True, help='IP address on which to listen (EG_IP env var)') @default('ip') def ip_default(self): return os.getenv(self.ip_env, os.getenv('KG_IP', self.ip_default_value)) # Base URL base_url_env = 'EG_BASE_URL' base_url_default_value = '/' base_url = Unicode( base_url_default_value, config=True, help= 'The base path for mounting all API resources (EG_BASE_URL env var)') @default('base_url') def base_url_default(self): return os.getenv(self.base_url_env, os.getenv('KG_BASE_URL', self.base_url_default_value)) # Token authorization auth_token_env = 'EG_AUTH_TOKEN' 
auth_token = Unicode( config=True, help= 'Authorization token required for all requests (EG_AUTH_TOKEN env var)' ) @default('auth_token') def _auth_token_default(self): return os.getenv(self.auth_token_env, os.getenv('KG_AUTH_TOKEN', '')) # Begin CORS headers allow_credentials_env = 'EG_ALLOW_CREDENTIALS' allow_credentials = Unicode( config=True, help= 'Sets the Access-Control-Allow-Credentials header. (EG_ALLOW_CREDENTIALS env var)' ) @default('allow_credentials') def allow_credentials_default(self): return os.getenv(self.allow_credentials_env, os.getenv('KG_ALLOW_CREDENTIALS', '')) allow_headers_env = 'EG_ALLOW_HEADERS' allow_headers = Unicode( config=True, help= 'Sets the Access-Control-Allow-Headers header. (EG_ALLOW_HEADERS env var)' ) @default('allow_headers') def allow_headers_default(self): return os.getenv(self.allow_headers_env, os.getenv('KG_ALLOW_HEADERS', '')) allow_methods_env = 'EG_ALLOW_METHODS' allow_methods = Unicode( config=True, help= 'Sets the Access-Control-Allow-Methods header. (EG_ALLOW_METHODS env var)' ) @default('allow_methods') def allow_methods_default(self): return os.getenv(self.allow_methods_env, os.getenv('KG_ALLOW_METHODS', '')) allow_origin_env = 'EG_ALLOW_ORIGIN' allow_origin = Unicode( config=True, help= 'Sets the Access-Control-Allow-Origin header. (EG_ALLOW_ORIGIN env var)' ) @default('allow_origin') def allow_origin_default(self): return os.getenv(self.allow_origin_env, os.getenv('KG_ALLOW_ORIGIN', '')) expose_headers_env = 'EG_EXPOSE_HEADERS' expose_headers = Unicode( config=True, help= 'Sets the Access-Control-Expose-Headers header. (EG_EXPOSE_HEADERS env var)' ) @default('expose_headers') def expose_headers_default(self): return os.getenv(self.expose_headers_env, os.getenv('KG_EXPOSE_HEADERS', '')) trust_xheaders_env = 'EG_TRUST_XHEADERS' trust_xheaders = CBool( False, config=True, help="""Use x-* header values for overriding the remote-ip, useful when application is behing a proxy. 
(EG_TRUST_XHEADERS env var)"""
    )

    @default('trust_xheaders')
    def trust_xheaders_default(self):
        # Fall back to the legacy Kernel Gateway (KG_) env var before
        # defaulting to False.
        return strtobool(
            os.getenv(self.trust_xheaders_env,
                      os.getenv('KG_TRUST_XHEADERS', 'False')))

    # --- SSL/TLS configuration ---
    certfile_env = 'EG_CERTFILE'
    certfile = Unicode(
        None,
        config=True,
        allow_none=True,
        help=
        'The full path to an SSL/TLS certificate file. (EG_CERTFILE env var)')

    @default('certfile')
    def certfile_default(self):
        return os.getenv(self.certfile_env, os.getenv('KG_CERTFILE'))

    keyfile_env = 'EG_KEYFILE'
    keyfile = Unicode(
        None,
        config=True,
        allow_none=True,
        help=
        'The full path to a private key file for usage with SSL/TLS. (EG_KEYFILE env var)'
    )

    @default('keyfile')
    def keyfile_default(self):
        return os.getenv(self.keyfile_env, os.getenv('KG_KEYFILE'))

    client_ca_env = 'EG_CLIENT_CA'
    client_ca = Unicode(
        None,
        config=True,
        allow_none=True,
        help="""The full path to a certificate authority certificate for SSL/TLS client authentication. (EG_CLIENT_CA env var)""")

    @default('client_ca')
    def client_ca_default(self):
        return os.getenv(self.client_ca_env, os.getenv('KG_CLIENT_CA'))

    max_age_env = 'EG_MAX_AGE'
    max_age = Unicode(
        config=True,
        help='Sets the Access-Control-Max-Age header. (EG_MAX_AGE env var)')

    @default('max_age')
    def max_age_default(self):
        return os.getenv(self.max_age_env, os.getenv('KG_MAX_AGE', ''))

    # End CORS headers

    max_kernels_env = 'EG_MAX_KERNELS'
    max_kernels = Integer(
        None,
        config=True,
        allow_none=True,
        help=
        """Limits the number of kernel instances allowed to run by this gateway. Unbounded by default. (EG_MAX_KERNELS env var)""")

    @default('max_kernels')
    def max_kernels_default(self):
        # Env vars are strings; coerce to int only when actually set.
        val = os.getenv(self.max_kernels_env, os.getenv('KG_MAX_KERNELS'))
        return val if val is None else int(val)

    default_kernel_name_env = 'EG_DEFAULT_KERNEL_NAME'
    default_kernel_name = Unicode(
        config=True,
        help=
        'Default kernel name when spawning a kernel (EG_DEFAULT_KERNEL_NAME env var)'
    )

    @default('default_kernel_name')
    def default_kernel_name_default(self):
        # defaults to Jupyter's default kernel name on empty string
        return os.getenv(self.default_kernel_name_env,
                         os.getenv('KG_DEFAULT_KERNEL_NAME', ''))

    list_kernels_env = 'EG_LIST_KERNELS'
    list_kernels = Bool(
        config=True,
        help=
        """Permits listing of the running kernels using API endpoints /api/kernels and /api/sessions. (EG_LIST_KERNELS env var) Note: Jupyter Notebook allows this by default but Jupyter Enterprise Gateway does not."""
    )

    @default('list_kernels')
    def list_kernels_default(self):
        # Only the literal (case-insensitive) string 'true' enables listing.
        return os.getenv(self.list_kernels_env,
                         os.getenv('KG_LIST_KERNELS',
                                   'False')).lower() == 'true'

    env_whitelist_env = 'EG_ENV_WHITELIST'
    env_whitelist = List(
        config=True,
        help="""Environment variables allowed to be set when a client requests a new kernel. (EG_ENV_WHITELIST env var)""")

    @default('env_whitelist')
    def env_whitelist_default(self):
        return os.getenv(self.env_whitelist_env,
                         os.getenv('KG_ENV_WHITELIST', '')).split(',')

    env_process_whitelist_env = 'EG_ENV_PROCESS_WHITELIST'
    env_process_whitelist = List(
        config=True,
        help="""Environment variables allowed to be inherited from the spawning process by the kernel. (EG_ENV_PROCESS_WHITELIST env var)"""
    )

    @default('env_process_whitelist')
    def env_process_whitelist_default(self):
        return os.getenv(self.env_process_whitelist_env,
                         os.getenv('KG_ENV_PROCESS_WHITELIST', '')).split(',')

    # Remote hosts
    remote_hosts_env = 'EG_REMOTE_HOSTS'
    remote_hosts_default_value = 'localhost'
    remote_hosts = List(
        default_value=[remote_hosts_default_value],
        config=True,
        help=
        """Bracketed comma-separated list of hosts on which DistributedProcessProxy kernels will be launched e.g., ['host1','host2']. (EG_REMOTE_HOSTS env var - non-bracketed, just comma-separated)""")

    @default('remote_hosts')
    def remote_hosts_default(self):
        return os.getenv(self.remote_hosts_env,
                         self.remote_hosts_default_value).split(',')

    # Yarn endpoint
    yarn_endpoint_env = 'EG_YARN_ENDPOINT'
    yarn_endpoint = Unicode(
        None,
        config=True,
        allow_none=True,
        help=
        """The http url specifying the YARN Resource Manager. Note: If this value is NOT set, the YARN library will use the files within the local HADOOP_CONFIG_DIR to determine the active resource manager. (EG_YARN_ENDPOINT env var)"""
    )

    @default('yarn_endpoint')
    def yarn_endpoint_default(self):
        return os.getenv(self.yarn_endpoint_env)

    # Alt Yarn endpoint
    alt_yarn_endpoint_env = 'EG_ALT_YARN_ENDPOINT'
    alt_yarn_endpoint = Unicode(
        None,
        config=True,
        allow_none=True,
        help=
        """The http url specifying the alternate YARN Resource Manager. This value should be set when YARN Resource Managers are configured for high availability. Note: If both YARN endpoints are NOT set, the YARN library will use the files within the local HADOOP_CONFIG_DIR to determine the active resource manager. (EG_ALT_YARN_ENDPOINT env var)""")

    @default('alt_yarn_endpoint')
    def alt_yarn_endpoint_default(self):
        return os.getenv(self.alt_yarn_endpoint_env)

    yarn_endpoint_security_enabled_env = 'EG_YARN_ENDPOINT_SECURITY_ENABLED'
    yarn_endpoint_security_enabled_default_value = False
    yarn_endpoint_security_enabled = Bool(
        yarn_endpoint_security_enabled_default_value,
        config=True,
        help="""Is YARN Kerberos/SPNEGO Security enabled (True/False). (EG_YARN_ENDPOINT_SECURITY_ENABLED env var)"""
    )

    @default('yarn_endpoint_security_enabled')
    def yarn_endpoint_security_enabled_default(self):
        # NOTE(review): bool() of a non-empty string is always True, so any
        # non-empty env value (including 'false') enables this flag —
        # confirm whether string-to-bool parsing was intended here.
        return bool(
            os.getenv(self.yarn_endpoint_security_enabled_env,
                      self.yarn_endpoint_security_enabled_default_value))

    # Conductor endpoint
    conductor_endpoint_env = 'EG_CONDUCTOR_ENDPOINT'
    conductor_endpoint_default_value = None
    conductor_endpoint = Unicode(
        conductor_endpoint_default_value,
        config=True,
        help="""The http url for accessing the Conductor REST API. (EG_CONDUCTOR_ENDPOINT env var)""")

    @default('conductor_endpoint')
    def conductor_endpoint_default(self):
        return os.getenv(self.conductor_endpoint_env,
                         self.conductor_endpoint_default_value)

    _log_formatter_cls = LogFormatter  # traitlet default is LevelFormatter

    @default('log_format')
    def _default_log_format(self):
        """override default log format to include milliseconds"""
        return u"%(color)s[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s]%(end_color)s %(message)s"

    # Impersonation enabled
    impersonation_enabled_env = 'EG_IMPERSONATION_ENABLED'
    impersonation_enabled = Bool(
        False,
        config=True,
        help=
        """Indicates whether impersonation will be performed during kernel launch. (EG_IMPERSONATION_ENABLED env var)""")

    @default('impersonation_enabled')
    def impersonation_enabled_default(self):
        # Unlike the YARN security flag above, this one explicitly compares
        # against the string 'true'.
        return bool(
            os.getenv(self.impersonation_enabled_env,
                      'false').lower() == 'true')

    # Unauthorized users
    unauthorized_users_env = 'EG_UNAUTHORIZED_USERS'
    unauthorized_users_default_value = 'root'
    unauthorized_users = Set(
        default_value={unauthorized_users_default_value},
        config=True,
        help=
        """Comma-separated list of user names (e.g., ['root','admin']) against which KERNEL_USERNAME will be compared. Any match (case-sensitive) will prevent the kernel's launch and result in an HTTP 403 (Forbidden) error. (EG_UNAUTHORIZED_USERS env var - non-bracketed, just comma-separated)"""
    )

    @default('unauthorized_users')
    def unauthorized_users_default(self):
        return os.getenv(self.unauthorized_users_env,
                         self.unauthorized_users_default_value).split(',')

    # Authorized users
    authorized_users_env = 'EG_AUTHORIZED_USERS'
    authorized_users = Set(
        config=True,
        help=
        """Comma-separated list of user names (e.g., ['bob','alice']) against which KERNEL_USERNAME will be compared. Any match (case-sensitive) will allow the kernel's launch, otherwise an HTTP 403 (Forbidden) error will be raised. The set of unauthorized users takes precedence. This option should be used carefully as it can dramatically limit who can launch kernels. (EG_AUTHORIZED_USERS env var - non-bracketed, just comma-separated)""")

    @default('authorized_users')
    def authorized_users_default(self):
        # Empty list (rather than ['']) when the env var is unset, so the
        # "authorized users" restriction is disabled by default.
        au_env = os.getenv(self.authorized_users_env)
        return au_env.split(',') if au_env is not None else []

    # Port range
    port_range_env = 'EG_PORT_RANGE'
    port_range_default_value = "0..0"
    port_range = Unicode(
        port_range_default_value,
        config=True,
        help=
        """Specifies the lower and upper port numbers from which ports are created. The bounded values are separated by '..' (e.g., 33245..34245 specifies a range of 1000 ports to be randomly selected). A range of zero (e.g., 33245..33245 or 0..0) disables port-range enforcement. (EG_PORT_RANGE env var)""")

    @default('port_range')
    def port_range_default(self):
        return os.getenv(self.port_range_env, self.port_range_default_value)

    # Max Kernels per User
    max_kernels_per_user_env = 'EG_MAX_KERNELS_PER_USER'
    max_kernels_per_user_default_value = -1
    max_kernels_per_user = Integer(
        max_kernels_per_user_default_value,
        config=True,
        help="""Specifies the maximum number of kernels a user can have active simultaneously. A value of -1 disables enforcement. (EG_MAX_KERNELS_PER_USER env var)""")

    @default('max_kernels_per_user')
    def max_kernels_per_user_default(self):
        return int(
            os.getenv(self.max_kernels_per_user_env,
                      self.max_kernels_per_user_default_value))

    ws_ping_interval_env = 'EG_WS_PING_INTERVAL_SECS'
    ws_ping_interval_default_value = 30
    ws_ping_interval = Integer(
        ws_ping_interval_default_value,
        config=True,
        help=
        """Specifies the ping interval(in seconds) that should be used by zmq port associated withspawned kernels.Set this variable to 0 to disable ping mechanism. (EG_WS_PING_INTERVAL_SECS env var)""")

    @default('ws_ping_interval')
    def ws_ping_interval_default(self):
        return int(
            os.getenv(self.ws_ping_interval_env,
                      self.ws_ping_interval_default_value))

    # Instance populated by init_configurables(); None until then.
    kernel_spec_manager = Instance(KernelSpecManager, allow_none=True)

    kernel_spec_manager_class = Type(default_value=KernelSpecManager,
                                     config=True,
                                     help="""
        The kernel spec manager class to use. Must be a subclass of `jupyter_client.kernelspec.KernelSpecManager`.
        """)

    kernel_manager_class = Type(klass=MappingKernelManager,
                                default_value=RemoteMappingKernelManager,
                                config=True,
                                help="""
        The kernel manager class to use. Must be a subclass of `notebook.services.kernels.MappingKernelManager`.
        """)

    kernel_session_manager_class = Type(klass=KernelSessionManager,
                                        default_value=FileKernelSessionManager,
                                        config=True,
                                        help="""
        The kernel session manager class to use. Must be a subclass of `enterprise_gateway.services.sessions.KernelSessionManager`.
""") def initialize(self, argv=None): """Initializes the base class, configurable manager instances, the Tornado web app, and the tornado HTTP server. Parameters ---------- argv Command line arguments """ super(EnterpriseGatewayApp, self).initialize(argv) self.init_configurables() self.init_webapp() self.init_http_server() def init_configurables(self): """Initializes all configurable objects including a kernel manager, kernel spec manager, session manager, and personality. """ self.kernel_spec_manager = KernelSpecManager(parent=self) # Only pass a default kernel name when one is provided. Otherwise, # adopt whatever default the kernel manager wants to use. kwargs = {} if self.default_kernel_name: kwargs['default_kernel_name'] = self.default_kernel_name self.kernel_spec_manager = self.kernel_spec_manager_class( parent=self, ) self.kernel_manager = self.kernel_manager_class( parent=self, log=self.log, connection_dir=self.runtime_dir, kernel_spec_manager=self.kernel_spec_manager, **kwargs) self.session_manager = SessionManager( log=self.log, kernel_manager=self.kernel_manager) self.kernel_session_manager = self.kernel_session_manager_class( parent=self, log=self.log, kernel_manager=self.kernel_manager, config=self.config, # required to get command-line options visible **kwargs) # Attempt to start persisted sessions self.kernel_session_manager.start_sessions() self.contents_manager = None # Gateways don't use contents manager def _create_request_handlers(self): """Create default Jupyter handlers and redefine them off of the base_url path. Assumes init_configurables() has already been called. 
""" handlers = [] # append tuples for the standard kernel gateway endpoints for handler in (default_api_handlers + default_kernel_handlers + default_kernelspec_handlers + default_session_handlers + default_base_handlers): # Create a new handler pattern rooted at the base_url pattern = url_path_join('/', self.base_url, handler[0]) # Some handlers take args, so retain those in addition to the # handler class ref new_handler = tuple([pattern] + list(handler[1:])) handlers.append(new_handler) return handlers def init_webapp(self): """Initializes Tornado web application with uri handlers. Adds the various managers and web-front configuration values to the Tornado settings for reference by the handlers. """ # Enable the same pretty logging the notebook uses enable_pretty_logging() # Configure the tornado logging level too logging.getLogger().setLevel(self.log_level) handlers = self._create_request_handlers() self.web_app = web.Application( handlers=handlers, kernel_manager=self.kernel_manager, session_manager=self.session_manager, contents_manager=self.contents_manager, kernel_spec_manager=self.kernel_spec_manager, eg_auth_token=self.auth_token, eg_allow_credentials=self.allow_credentials, eg_allow_headers=self.allow_headers, eg_allow_methods=self.allow_methods, eg_allow_origin=self.allow_origin, eg_expose_headers=self.expose_headers, eg_max_age=self.max_age, eg_max_kernels=self.max_kernels, eg_env_process_whitelist=self.env_process_whitelist, eg_env_whitelist=self.env_whitelist, eg_list_kernels=self.list_kernels, # Also set the allow_origin setting used by notebook so that the # check_origin method used everywhere respects the value allow_origin=self.allow_origin, # Always allow remote access (has been limited to localhost >= notebook 5.6) allow_remote_access=True, # setting ws_ping_interval value that can allow it to be modified for the purpose of toggling ping mechanism # for zmq web-sockets or increasing/decreasing web socket ping interval/timeouts. 
ws_ping_interval=self.ws_ping_interval * 1000) def _build_ssl_options(self): """Build a dictionary of SSL options for the tornado HTTP server. Taken directly from jupyter/notebook code. """ ssl_options = {} if self.certfile: ssl_options['certfile'] = self.certfile if self.keyfile: ssl_options['keyfile'] = self.keyfile if self.client_ca: ssl_options['ca_certs'] = self.client_ca if not ssl_options: # None indicates no SSL config ssl_options = None else: # SSL may be missing, so only import it if it's to be used import ssl # PROTOCOL_TLS selects the highest ssl/tls protocol version that both the client and # server support. When PROTOCOL_TLS is not available use PROTOCOL_SSLv23. # PROTOCOL_TLS is new in version 2.7.13, 3.5.3 and 3.6 ssl_options.setdefault( 'ssl_version', getattr(ssl, 'PROTOCOL_TLS', ssl.PROTOCOL_SSLv23)) if ssl_options.get('ca_certs', False): ssl_options.setdefault('cert_reqs', ssl.CERT_REQUIRED) return ssl_options def init_http_server(self): """Initializes a HTTP server for the Tornado web application on the configured interface and port. Tries to find an open port if the one configured is not available using the same logic as the Jupyer Notebook server. """ ssl_options = self._build_ssl_options() self.http_server = httpserver.HTTPServer(self.web_app, xheaders=self.trust_xheaders, ssl_options=ssl_options) for port in random_ports(self.port, self.port_retries + 1): try: self.http_server.listen(port, self.ip) except socket.error as e: if e.errno == errno.EADDRINUSE: self.log.info( 'The port %i is already in use, trying another port.' % port) continue elif e.errno in (errno.EACCES, getattr(errno, 'WSAEACCES', errno.EACCES)): self.log.warning("Permission to listen on port %i denied" % port) continue else: raise else: self.port = port break else: self.log.critical( 'ERROR: the gateway server could not be started because ' 'no available port could be found.') self.exit(1) def start(self): """Starts an IO loop for the application. 
""" super(EnterpriseGatewayApp, self).start() self.log.info( 'Jupyter Enterprise Gateway {} is available at http{}://{}:{}'. format(EnterpriseGatewayApp.version, 's' if self.keyfile else '', self.ip, self.port)) # If impersonation is enabled, issue a warning message if the gateway user is not in unauthorized_users. if self.impersonation_enabled: gateway_user = getpass.getuser() if gateway_user.lower() not in self.unauthorized_users: self.log.warning( "Impersonation is enabled and gateway user '{}' is NOT specified in the set of " "unauthorized users! Kernels may execute as that user with elevated privileges." .format(gateway_user)) self.io_loop = ioloop.IOLoop.current() if sys.platform != 'win32': signal.signal(signal.SIGHUP, signal.SIG_IGN) signal.signal(signal.SIGTERM, self._signal_stop) try: self.io_loop.start() except KeyboardInterrupt: self.log.info("Interrupted...") # Ignore further interrupts (ctrl-c) signal.signal(signal.SIGINT, signal.SIG_IGN) finally: self.shutdown() def shutdown(self): """Shuts down all running kernels.""" kids = self.kernel_manager.list_kernel_ids() for kid in kids: self.kernel_manager.shutdown_kernel(kid, now=True) def stop(self): """ Stops the HTTP server and IO loop associated with the application. """ def _stop(): self.http_server.stop() self.io_loop.stop() self.io_loop.add_callback(_stop) def _signal_stop(self, sig, frame): self.log.info("Received signal to terminate Enterprise Gateway.") self.io_loop.add_callback_from_signal(self.io_loop.stop)
class Session(Configurable):
    """Object for handling serialization and sending of messages.

    The Session object handles building messages and sending them
    with ZMQ sockets or ZMQStream objects.  Objects can communicate with each
    other over the network via Session objects, and only need to work with the
    dict-based IPython message spec. The Session will handle
    serialization/deserialization, security, and metadata.

    Sessions support configurable serialization via packer/unpacker traits,
    and signing with HMAC digests via the key/keyfile traits.

    Parameters
    ----------

    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        importstrings for methods to serialize message parts.  If just
        'json' or 'pickle', predefined JSON and pickle packers will be used.
        Otherwise, the entire importstring must be used.

        The functions must accept at least valid JSON input, and output
        *bytes*.

        For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        You can also set the pack/unpack callables for serialization directly.
    session : bytes
        the ID of this Session object.  The default is to generate a new UUID.
    username : unicode
        username added to message headers.  The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature.  If unset, messages
        will not be signed or checked.
    keyfile : filepath
        The file containing a key.  If this is set, `key` will be initialized
        to the contents of the file.
    """

    # Print each message built/sent when True.
    debug = Bool(False, config=True, help="""Debug output in the Session""")

    check_pid = Bool(True, config=True,
        help="""Whether to check PID to protect against calls after fork.

        This check can be disabled if fork-safety is handled elsewhere.
        """)

    packer = DottedObjectName('json', config=True,
            help="""The name of the packer for serializing messages.
            Should be one of 'json', 'pickle', or an import name
            for a custom callable serializer.""")

    @observe('packer')
    def _packer_changed(self, change):
        # Keep the pack/unpack callables (and the paired unpacker name)
        # in sync with the configured packer name.
        new = change['new']
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.unpacker = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.unpacker = new
        else:
            # Custom serializer: resolve the dotted import string.
            self.pack = import_item(str(new))

    unpacker = DottedObjectName('json', config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""")

    @observe('unpacker')
    def _unpacker_changed(self, change):
        # Mirror of _packer_changed for the unpacker trait.
        new = change['new']
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.packer = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.packer = new
        else:
            self.unpack = import_item(str(new))

    session = CUnicode(u'', config=True,
        help="""The UUID identifying this session.""")

    def _session_default(self):
        # Also initializes bsession, the bytes form used on the wire.
        u = new_id()
        self.bsession = u.encode('ascii')
        return u

    @observe('session')
    def _session_changed(self, change):
        self.bsession = self.session.encode('ascii')

    # bsession is the session as bytes
    bsession = CBytes(b'')

    username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
        help="""Username for the Session. Default is your system username.""",
        config=True)

    metadata = Dict({}, config=True,
        help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")

    # if 0, no adapting to do.
    adapt_version = Integer(0)

    # message signature related traits:

    key = CBytes(config=True,
        help="""execution key, for signing messages.""")

    def _key_default(self):
        return new_id_bytes()

    @observe('key')
    def _key_changed(self, change):
        # Rebuild the HMAC object whenever the key changes.
        self._new_auth()

    signature_scheme = Unicode('hmac-sha256', config=True,
        help="""The digest scheme used to construct the message signatures.
        Must have the form 'hmac-HASH'.""")

    @observe('signature_scheme')
    def _signature_scheme_changed(self, change):
        # Validate the 'hmac-<hash>' form and resolve the hashlib digest.
        new = change['new']
        if not new.startswith('hmac-'):
            raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
        hash_name = new.split('-', 1)[1]
        try:
            self.digest_mod = getattr(hashlib, hash_name)
        except AttributeError:
            raise TraitError("hashlib has no such attribute: %s" % hash_name)
        self._new_auth()

    digest_mod = Any()

    def _digest_mod_default(self):
        return hashlib.sha256

    auth = Instance(hmac.HMAC, allow_none=True)

    def _new_auth(self):
        # auth is None when no key is set, which disables signing entirely.
        if self.key:
            self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
        else:
            self.auth = None

    digest_history = Set()
    digest_history_size = Integer(2**16, config=True,
        help="""The maximum number of digests to remember.

        The digest history will be culled when it exceeds this value.
        """)

    keyfile = Unicode('', config=True,
        help="""path to file containing execution key.""")

    @observe('keyfile')
    def _keyfile_changed(self, change):
        # Load the signing key from the file just assigned.
        with open(change['new'], 'rb') as f:
            self.key = f.read().strip()

    # for protecting against sends from forks
    pid = Integer()

    # serialization traits:

    pack = Any(default_packer)  # the actual packer function

    @observe('pack')
    def _pack_changed(self, change):
        new = change['new']
        if not callable(new):
            raise TypeError("packer must be callable, not %s"%type(new))

    unpack = Any(default_unpacker)  # the actual packer function

    @observe('unpack')
    def _unpack_changed(self, change):
        # unpacker is not checked - it is assumed to be
        new = change['new']
        if not callable(new):
            raise TypeError("unpacker must be callable, not %s"%type(new))

    # thresholds:
    copy_threshold = Integer(2**16, config=True,
        help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
    buffer_threshold = Integer(MAX_BYTES, config=True,
        help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
    item_threshold = Integer(MAX_ITEMS, config=True,
        help="""The maximum number of items for a container to be introspected for custom serialization.
        Containers larger than this are pickled outright.
        """)

    def __init__(self, **kwargs):
        """create a Session object

        Parameters
        ----------

        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts.  If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : unicode (must be ascii)
            the ID of this Session object.  The default is to generate a new
            UUID.
        bsession : bytes
            The session as bytes
        username : unicode
            username added to message headers.  The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature.  If unset, messages
            will not be signed or checked.
        signature_scheme : str
            The message digest scheme. Currently must be of the form
            'hmac-HASH', where 'HASH' is a hashing function available in
            Python's hashlib. The default is 'hmac-sha256'. This is ignored
            if 'key' is empty.
        keyfile : filepath
            The file containing a key.  If this is set, `key` will be
            initialized to the contents of the file.
        """
        super(Session, self).__init__(**kwargs)
        self._check_packers()
        self.none = self.pack({})
        # ensure self._session_default() if necessary, so bsession is defined:
        self.session
        self.pid = os.getpid()
        self._new_auth()
        if not self.key:
            get_logger().warning("Message signing is disabled.  This is insecure and not recommended!")

    def clone(self):
        """Create a copy of this Session

        Useful when connecting multiple times to a given kernel.
        This prevents a shared digest_history warning about duplicate digests
        due to multiple connections to IOPub in the same process.

        .. versionadded:: 5.1
        """
        # make a copy
        new_session = type(self)()
        for name in self.traits():
            setattr(new_session, name, getattr(self, name))
        # fork digest_history
        new_session.digest_history = set()
        new_session.digest_history.update(self.digest_history)
        return new_session

    @property
    def msg_id(self):
        """always return new uuid"""
        return new_id()

    def _check_packers(self):
        """check packers for datetime support."""
        pack = self.pack
        unpack = self.unpack

        # check simple serialization
        msg = dict(a=[1,'hi'])
        try:
            packed = pack(msg)
        except Exception as e:
            msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
            if self.packer == 'json':
                jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            else:
                jsonmsg = ""
            raise ValueError(
                msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
            )

        # ensure packed message is bytes
        if not isinstance(packed, bytes):
            raise ValueError("message packed to %r, but bytes are required"%type(packed))

        # check that unpack is pack's inverse
        try:
            unpacked = unpack(packed)
            assert unpacked == msg
        except Exception as e:
            msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
            if self.packer == 'json':
                jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            else:
                jsonmsg = ""
            raise ValueError(
                msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
            )

        # check datetime support
        msg = dict(t=utcnow())
        try:
            unpacked = unpack(pack(msg))
            if isinstance(unpacked['t'], datetime):
                raise ValueError("Shouldn't deserialize to datetime")
        except Exception:
            # packer can't handle datetimes directly; wrap it so dates are
            # squashed to strings before packing.
            self.pack = lambda o: pack(squash_dates(o))
            self.unpack = lambda s: unpack(s)

    def msg_header(self, msg_type):
        # Build a fresh header with a new msg_id for this session/user.
        return msg_header(self.msg_id, msg_type, self.username, self.session)

    def msg(self, msg_type, content=None, parent=None, header=None,
            metadata=None):
        """Return the nested message dict.

        This format is different from what is sent over the wire. The
        serialize/deserialize methods converts this nested message dict to the
        wire format, which is a list of message parts.
        """
        msg = {}
        header = self.msg_header(msg_type) if header is None else header
        msg['header'] = header
        msg['msg_id'] = header['msg_id']
        msg['msg_type'] = header['msg_type']
        msg['parent_header'] = {} if parent is None else extract_header(parent)
        msg['content'] = {} if content is None else content
        msg['metadata'] = self.metadata.copy()
        if metadata is not None:
            msg['metadata'].update(metadata)
        return msg

    def sign(self, msg_list):
        """Sign a message with HMAC digest. If no auth, return b''.

        Parameters
        ----------
        msg_list : list
            The [p_header,p_parent,p_content] part of the message list.
        """
        if self.auth is None:
            return b''
        h = self.auth.copy()
        for m in msg_list:
            h.update(m)
        return str_to_bytes(h.hexdigest())

    def serialize(self, msg, ident=None):
        """Serialize the message components to bytes.

        This is roughly the inverse of deserialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg : dict or Message
            The next message dict as returned by the self.msg method.

        Returns
        -------
        msg_list : list
            The list of bytes objects to be sent with the format::

                [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
                 p_metadata, p_content, buffer1, buffer2, ...]

            In this list, the ``p_*`` entities are the packed or serialized
            versions, so if JSON is used, these are utf8 encoded JSON strings.
        """
        content = msg.get('content', {})
        if content is None:
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, bytes):
            # content is already packed, as in a relayed message
            pass
        elif isinstance(content, unicode_type):
            # should be bytes, but JSON often spits out unicode
            content = content.encode('utf8')
        else:
            raise TypeError("Content incorrect type: %s"%type(content))

        real_message = [self.pack(msg['header']),
                        self.pack(msg['parent_header']),
                        self.pack(msg['metadata']),
                        content,
        ]

        to_send = []

        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)

        signature = self.sign(real_message)
        to_send.append(signature)

        to_send.extend(real_message)

        return to_send

    def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
             buffers=None, track=False, header=None, metadata=None):
        """Build and send a message via stream or socket.

        The message format used by this function internally is as follows:

        [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
         buffer1,buffer2,...]

        The serialize/deserialize methods convert the nested message dict into
        this format.

        Parameters
        ----------

        stream : zmq.Socket or ZMQStream
            The socket-like object used to send the data.
        msg_or_type : str or Message/dict
            Normally, msg_or_type will be a msg_type unless a message is being
            sent more than once. If a header is supplied, this can be set to
            None and the msg_type will be pulled from the header.

        content : dict or None
            The content of the message (ignored if msg_or_type is a message).
        header : dict or None
            The header dict for the message (ignored if msg_to_type is a
            message).
        parent : Message or dict or None
            The parent or parent header describing the parent of this message
            (ignored if msg_or_type is a message).
        ident : bytes or list of bytes
            The zmq.IDENTITY routing path.
        metadata : dict or None
            The metadata describing the message
        buffers : list or None
            The already-serialized buffers to be appended to the message.
        track : bool
            Whether to track.  Only for use with Sockets, because ZMQStream
            objects cannot track messages.

        Returns
        -------
        msg : dict
            The constructed message.
        """
        if not isinstance(stream, zmq.Socket):
            # ZMQStreams and dummy sockets do not support tracking.
            track = False

        if isinstance(msg_or_type, (Message, dict)):
            # We got a Message or message dict, not a msg_type so don't
            # build a new Message.
            msg = msg_or_type
            buffers = buffers or msg.get('buffers', [])
        else:
            msg = self.msg(msg_or_type, content=content, parent=parent,
                           header=header, metadata=metadata)
        if self.check_pid and not os.getpid() == self.pid:
            # Refuse to send from a forked child; the parent owns this session.
            get_logger().warning("WARNING: attempted to send message from fork\n%s",
                msg
            )
            return
        buffers = [] if buffers is None else buffers
        for idx, buf in enumerate(buffers):
            if isinstance(buf, memoryview):
                view = buf
            else:
                try:
                    # check to see if buf supports the buffer protocol.
                    view = memoryview(buf)
                except TypeError:
                    raise TypeError("Buffer objects must support the buffer protocol.")
            # memoryview.contiguous is new in 3.3,
            # just skip the check on Python 2
            if hasattr(view, 'contiguous') and not view.contiguous:
                # zmq requires memoryviews to be contiguous
                raise ValueError("Buffer %i (%r) is not contiguous" % (idx, buf))

        if self.adapt_version:
            msg = adapt(msg, self.adapt_version)
        to_send = self.serialize(msg, ident)
        to_send.extend(buffers)
        longest = max([ len(s) for s in to_send ])
        # Zero-copy only pays off for large frames; small ones are copied.
        copy = (longest < self.copy_threshold)

        if buffers and track and not copy:
            # only really track when we are doing zero-copy buffers
            tracker = stream.send_multipart(to_send, copy=False, track=True)
        else:
            # use dummy tracker, which will be done immediately
            tracker = DONE
            stream.send_multipart(to_send, copy=copy)

        if self.debug:
            pprint.pprint(msg)
            pprint.pprint(to_send)
            pprint.pprint(buffers)

        msg['tracker'] = tracker

        return msg

    def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
        """Send a raw message via ident path.

        This method is used to send a already serialized message.

        Parameters
        ----------
        stream : ZMQStream or Socket
            The ZMQ stream or socket to use for sending the message.
        msg_list : list
            The serialized list of messages to send. This only includes the
            [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
            the message.
        ident : ident or list
            A single ident or a list of idents to use in sending.
        """
        to_send = []
        if isinstance(ident, bytes):
            ident = [ident]
        if ident is not None:
            to_send.extend(ident)

        to_send.append(DELIM)
        # Don't include buffers in signature (per spec).
        to_send.append(self.sign(msg_list[0:4]))
        to_send.extend(msg_list)
        stream.send_multipart(to_send, flags, copy=copy)

    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
        """Receive and unpack a message.

        Parameters
        ----------
        socket : ZMQStream or Socket
            The socket or stream to use in receiving.

        Returns
        -------
        [idents], msg
            [idents] is a list of idents and msg is a nested message dict of
            same format as self.msg returns.
        """
        if isinstance(socket, ZMQStream):
            socket = socket.socket
        try:
            msg_list = socket.recv_multipart(mode, copy=copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_multipart won't return None.
                return None,None
            else:
                raise
        # split multipart message into identity list and message dict
        # invalid large messages can cause very expensive string comparisons
        idents, msg_list = self.feed_identities(msg_list, copy)
        try:
            return idents, self.deserialize(msg_list, content=content, copy=copy)
        except Exception as e:
            # TODO: handle it
            raise e

    def feed_identities(self, msg_list, copy=True):
        """Split the identities from the rest of the message.

        Feed until DELIM is reached, then return the prefix as idents and
        remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
        but that would be silly.

        Parameters
        ----------
        msg_list : a list of Message or bytes objects
            The message to be split.
        copy : bool
            flag determining whether the arguments are bytes or Messages

        Returns
        -------
        (idents, msg_list) : two lists
            idents will always be a list of bytes, each of which is a ZMQ
            identity. msg_list will be a list of bytes or zmq.Messages of the
            form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
            should be unpackable/unserializable via self.deserialize at this
            point.
        """
        if copy:
            idx = msg_list.index(DELIM)
            return msg_list[:idx], msg_list[idx+1:]
        else:
            # Non-copying recv gives zmq.Message objects; compare .bytes.
            failed = True
            for idx,m in enumerate(msg_list):
                if m.bytes == DELIM:
                    failed = False
                    break
            if failed:
                raise ValueError("DELIM not in msg_list")
            idents, msg_list = msg_list[:idx], msg_list[idx+1:]
            return [m.bytes for m in idents], msg_list

    def _add_digest(self, signature):
        """add a digest to history to protect against replay attacks"""
        if self.digest_history_size == 0:
            # no history, never add digests
            return

        self.digest_history.add(signature)
        if len(self.digest_history) > self.digest_history_size:
            # threshold reached, cull 10%
            self._cull_digest_history()

    def _cull_digest_history(self):
        """cull the digest history

        Removes a randomly selected 10% of the digest history
        """
        current = len(self.digest_history)
        n_to_cull = max(int(current // 10), current - self.digest_history_size)
        if n_to_cull >= current:
            self.digest_history = set()
            return
        to_cull = random.sample(self.digest_history, n_to_cull)
        self.digest_history.difference_update(to_cull)

    def deserialize(self, msg_list, content=True, copy=True):
        """Unserialize a msg_list to a nested message dict.

        This is roughly the inverse of serialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg_list : list of bytes or Message objects
            The list of message parts of the form [HMAC,p_header,p_parent,
            p_metadata,p_content,buffer1,buffer2,...].
        content : bool (True)
            Whether to unpack the content dict (True), or leave it packed
            (False).
        copy : bool (True)
            Whether msg_list contains bytes (True) or the non-copying Message
            objects in each place (False).

        Returns
        -------
        msg : dict
            The nested message dict with top-level keys [header, parent_header,
            content, buffers].  The buffers are returned as memoryviews.
        """
        minlen = 5
        message = {}
        if not copy:
            # pyzmq didn't copy the first parts of the message, so we'll do it
            for i in range(minlen):
                msg_list[i] = msg_list[i].bytes
        if self.auth is not None:
            signature = msg_list[0]
            if not signature:
                raise ValueError("Unsigned Message")
            if signature in self.digest_history:
                # Replay protection: a signature may be accepted only once.
                raise ValueError("Duplicate Signature: %r" % signature)
            if content:
                # Only store signature if we are unpacking content, don't store if just peeking.
                self._add_digest(signature)
            check = self.sign(msg_list[1:5])
            # compare_digest: constant-time comparison against timing attacks.
            if not compare_digest(signature, check):
                raise ValueError("Invalid Signature: %r" % signature)
        if not len(msg_list) >= minlen:
            raise TypeError("malformed message, must have at least %i elements"%minlen)
        header = self.unpack(msg_list[1])
        message['header'] = extract_dates(header)
        message['msg_id'] = header['msg_id']
        message['msg_type'] = header['msg_type']
        message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
        message['metadata'] = self.unpack(msg_list[3])
        if content:
            message['content'] = self.unpack(msg_list[4])
        else:
            message['content'] = msg_list[4]
        buffers = [memoryview(b) for b in msg_list[5:]]
        if buffers and buffers[0].shape is None:
            # force copy to workaround pyzmq #646
            buffers = [memoryview(b.bytes) for b in msg_list[5:]]
        message['buffers'] = buffers
        if self.debug:
            pprint.pprint(message)
        # adapt to the current version
        return adapt(message)

    def unserialize(self, *args, **kwargs):
        # Deprecated alias retained for backwards compatibility.
        warnings.warn(
            "Session.unserialize is deprecated. Use Session.deserialize.",
            DeprecationWarning,
        )
        return self.deserialize(*args, **kwargs)
class Kernel(SingletonConfigurable):
    """Base class for kernels: wires shell/control/stdin/iopub ZMQ streams
    to message-type handlers and manages the request/reply lifecycle."""

    #---------------------------------------------------------------------------
    # Kernel interface
    #---------------------------------------------------------------------------

    # attribute to override with a GUI
    eventloop = Any(None)

    @observe('eventloop')
    def _update_eventloop(self, change):
        """schedule call to eventloop from IOLoop"""
        loop = ioloop.IOLoop.current()
        if change.new is not None:
            loop.add_callback(self.enter_eventloop)

    session = Instance(Session, allow_none=True)
    profile_dir = Instance('IPython.core.profiledir.ProfileDir', allow_none=True)
    shell_stream = Instance(ZMQStream, allow_none=True)

    shell_streams = List(
        help="""Deprecated shell_streams alias. Use shell_stream

        .. versionchanged:: 6.0
            shell_streams is deprecated. Use shell_stream.
        """)

    @default("shell_streams")
    def _shell_streams_default(self):
        warnings.warn(
            "Kernel.shell_streams is deprecated in yapkernel 6.0. Use Kernel.shell_stream",
            DeprecationWarning,
            stacklevel=2,
        )
        if self.shell_stream is not None:
            return [self.shell_stream]
        else:
            return []

    @observe("shell_streams")
    def _shell_streams_changed(self, change):
        warnings.warn(
            "Kernel.shell_streams is deprecated in yapkernel 6.0. Use Kernel.shell_stream",
            DeprecationWarning,
            stacklevel=2,
        )
        if len(change.new) > 1:
            warnings.warn(
                "Kernel only supports one shell stream. Additional streams will be ignored.",
                RuntimeWarning,
                stacklevel=2,
            )
        if change.new:
            self.shell_stream = change.new[0]

    control_stream = Instance(ZMQStream, allow_none=True)

    debug_shell_socket = Any()

    control_thread = Any()
    iopub_socket = Any()
    iopub_thread = Any()
    stdin_socket = Any()
    log = Instance(logging.Logger, allow_none=True)

    # identities:
    int_id = Integer(-1)
    ident = Unicode()

    @default('ident')
    def _default_ident(self):
        return str(uuid.uuid4())

    # This should be overridden by wrapper kernels that implement any real
    # language.
    language_info = {
        'name': 'Prolog (YAP)',
        'mimetype': 'text/x-prolog',
        'file_extension': '.yap',
    }

    # any links that should go in the help menu
    help_links = List()

    # Private interface

    _darwin_app_nap = Bool(
        True,
        help="""Whether to use appnope for compatibility with OS X App Nap.

        Only affects OS X >= 10.9.
        """).tag(config=True)

    # track associations with current request
    _allow_stdin = Bool(False)
    _parents = Dict({"shell": {}, "control": {}})
    _parent_ident = Dict({'shell': b'', 'control': b''})

    @property
    def _parent_header(self):
        warnings.warn(
            "Kernel._parent_header is deprecated in yapkernel 6. Use .get_parent()",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.get_parent(channel="shell")

    # Time to sleep after flushing the stdout/err buffers in each execute
    # cycle.  While this introduces a hard limit on the minimal latency of the
    # execute cycle, it helps prevent output synchronization problems for
    # clients.
    # Units are in seconds.  The minimum zmq latency on local host is probably
    # ~150 microseconds, set this to 500us for now.  We may need to increase it
    # a little if it's not enough after more interactive testing.
    _execute_sleep = Float(0.0005).tag(config=True)

    # Frequency of the kernel's event loop.
    # Units are in seconds, kernel subclasses for GUI toolkits may need to
    # adapt to milliseconds.
    _poll_interval = Float(0.01).tag(config=True)

    stop_on_error_timeout = Float(
        0.0,
        config=True,
        help="""time (in seconds) to wait for messages to arrive
        when aborting queued requests after an error.

        Requests that arrive within this window after an error
        will be cancelled.

        Increase in the event of unusually slow network
        causing significant delays,
        which can manifest as e.g. "Run all" in a notebook
        aborting some, but not all, messages after an error.
        """)

    # If the shutdown was requested over the network, we leave here the
    # necessary reply message so it can be sent by our registered atexit
    # handler.  This ensures that the reply is only sent to clients truly at
    # the end of our shutdown process (which happens after the underlying
    # IPython shell's own shutdown).
    _shutdown_message = None

    # This is a dict of port number that the kernel is listening on. It is set
    # by record_ports and used by connect_request.
    _recorded_ports = Dict()

    # set of aborted msg_ids
    aborted = Set()

    # Track execution count here. For IPython, we override this to use the
    # execution count we store in the shell.
    execution_count = 0

    msg_types = [
        'execute_request', 'complete_request',
        'inspect_request', 'history_request',
        'comm_info_request', 'kernel_info_request',
        'connect_request', 'shutdown_request',
        'is_complete_request', 'interrupt_request',
        # deprecated:
        'apply_request',
    ]
    # add deprecated ipyparallel control messages
    control_msg_types = msg_types + [
        'clear_request', 'abort_request', 'debug_request'
    ]

    def __init__(self, **kwargs):
        super(Kernel, self).__init__(**kwargs)
        # Build dict of handlers for message types; every name listed in
        # msg_types / control_msg_types must exist as a method on this class.
        self.shell_handlers = {}
        for msg_type in self.msg_types:
            self.shell_handlers[msg_type] = getattr(self, msg_type)

        self.control_handlers = {}
        for msg_type in self.control_msg_types:
            self.control_handlers[msg_type] = getattr(self, msg_type)

        self.control_queue = Queue()

    def dispatch_control(self, msg):
        """Enqueue a raw control message for processing by poll_control_queue."""
        self.control_queue.put_nowait(msg)

    async def poll_control_queue(self):
        """Consume the control queue forever, processing one message at a time."""
        while True:
            msg = await self.control_queue.get()
            # handle tracers from _flush_control_queue
            if isinstance(msg, (concurrent.futures.Future, asyncio.Future)):
                msg.set_result(None)
                continue
            await self.process_control(msg)

    async def _flush_control_queue(self):
        """Flush the control queue, wait for processing of any pending messages"""
        if self.control_thread:
            control_loop = self.control_thread.io_loop
            # concurrent.futures.Futures are threadsafe
            # and can be used to await across threads
            tracer_future = concurrent.futures.Future()
            awaitable_future = asyncio.wrap_future(tracer_future)
        else:
            control_loop = self.io_loop
            tracer_future = awaitable_future = asyncio.Future()

        def _flush():
            # control_stream.flush puts messages on the queue
            self.control_stream.flush()
            # put Future on the queue after all of those,
            # so we can wait for all queued messages to be processed
            self.control_queue.put(tracer_future)

        control_loop.add_callback(_flush)
        return awaitable_future

    async def process_control(self, msg):
        """dispatch control requests"""
        idents, msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.deserialize(msg, content=True, copy=False)
        except Exception:
            self.log.error("Invalid Control Message", exc_info=True)
            return

        self.log.debug("Control received: %s", msg)

        # Set the parent message for side effects.
        self.set_parent(idents, msg, channel='control')
        self._publish_status('busy', 'control')

        header = msg['header']
        msg_type = header['msg_type']

        handler = self.control_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r", msg_type)
        else:
            try:
                result = handler(self.control_stream, idents, msg)
                if inspect.isawaitable(result):
                    await result
            except Exception:
                self.log.error("Exception in control handler:", exc_info=True)

        sys.stdout.flush()
        sys.stderr.flush()
        self._publish_status('idle', 'control')
        # flush to ensure reply is sent
        self.control_stream.flush(zmq.POLLOUT)

    def should_handle(self, stream, msg, idents):
        """Check whether a shell-channel message should be handled

        Allows subclasses to prevent handling of certain messages (e.g.
        aborted requests).
        """
        msg_id = msg['header']['msg_id']
        if msg_id in self.aborted:
            # is it safe to assume a msg_id will not be resubmitted?
            self.aborted.remove(msg_id)
            self._send_abort_reply(stream, msg, idents)
            return False
        return True

    async def dispatch_shell(self, msg):
        """dispatch shell requests"""

        # flush control queue before handling shell requests
        await self._flush_control_queue()

        idents, msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.deserialize(msg, content=True, copy=False)
        except Exception:
            self.log.error("Invalid Message", exc_info=True)
            return

        # Set the parent message for side effects.
        self.set_parent(idents, msg, channel='shell')
        self._publish_status('busy', 'shell')

        msg_type = msg['header']['msg_type']

        # Only abort execute requests
        if self._aborting and msg_type == 'execute_request':
            self._send_abort_reply(self.shell_stream, msg, idents)
            self._publish_status('idle', 'shell')
            # flush to ensure reply is sent before
            # handling the next request
            self.shell_stream.flush(zmq.POLLOUT)
            return

        # Print some info about this message and leave a '--->' marker, so it's
        # easier to trace visually the message chain when debugging.  Each
        # handler prints its message at the end.
        self.log.debug('\n*** MESSAGE TYPE:%s***', msg_type)
        self.log.debug('   Content: %s\n   --->\n   ', msg['content'])

        if not self.should_handle(self.shell_stream, msg, idents):
            return

        handler = self.shell_handlers.get(msg_type, None)
        if handler is None:
            self.log.warning("Unknown message type: %r", msg_type)
        else:
            self.log.debug("%s: %s", msg_type, msg)
            try:
                self.pre_handler_hook()
            except Exception:
                self.log.debug("Unable to signal in pre_handler_hook:", exc_info=True)
            try:
                result = handler(self.shell_stream, idents, msg)
                if inspect.isawaitable(result):
                    await result
            except Exception:
                self.log.error("Exception in message handler:", exc_info=True)
            except KeyboardInterrupt:
                # Ctrl-c shouldn't crash the kernel here.
                self.log.error("KeyboardInterrupt caught in kernel.")
            finally:
                try:
                    self.post_handler_hook()
                except Exception:
                    self.log.debug("Unable to signal in post_handler_hook:", exc_info=True)

        sys.stdout.flush()
        sys.stderr.flush()
        self._publish_status('idle', 'shell')
        # flush to ensure reply is sent before
        # handling the next request
        self.shell_stream.flush(zmq.POLLOUT)

    def pre_handler_hook(self):
        """Hook to execute before calling message handler"""
        # ensure default_int_handler during handler call
        self.saved_sigint_handler = signal(SIGINT, default_int_handler)

    def post_handler_hook(self):
        """Hook to execute after calling message handler"""
        signal(SIGINT, self.saved_sigint_handler)

    def enter_eventloop(self):
        """enter eventloop"""
        self.log.info("Entering eventloop %s", self.eventloop)
        # record handle, so we can check when this changes
        eventloop = self.eventloop
        if eventloop is None:
            self.log.info("Exiting as there is no eventloop")
            return

        def advance_eventloop():
            # check if eventloop changed:
            if self.eventloop is not eventloop:
                self.log.info("exiting eventloop %s", eventloop)
                return
            if self.msg_queue.qsize():
                self.log.debug("Delaying eventloop due to waiting messages")
                # still messages to process, make the eventloop wait
                schedule_next()
                return
            self.log.debug("Advancing eventloop %s", eventloop)
            try:
                eventloop(self)
            except KeyboardInterrupt:
                # Ctrl-C shouldn't crash the kernel
                self.log.error("KeyboardInterrupt caught in kernel")
                pass
            if self.eventloop is eventloop:
                # schedule advance again
                schedule_next()

        def schedule_next():
            """Schedule the next advance of the eventloop"""
            # flush the eventloop every so often,
            # giving us a chance to handle messages in the meantime
            self.log.debug("Scheduling eventloop advance")
            self.io_loop.call_later(0.001, advance_eventloop)

        # begin polling the eventloop
        schedule_next()

    async def do_one_iteration(self):
        """Process a single shell message

        Any pending control messages will be flushed as well

        .. versionchanged:: 5
            This is now a coroutine
        """
        # flush messages off of shell stream into the message queue
        self.shell_stream.flush()
        # process at most one shell message per iteration
        await self.process_one(wait=False)

    async def process_one(self, wait=True):
        """Process one request

        Returns None if no message was handled.
        """
        if wait:
            t, dispatch, args = await self.msg_queue.get()
        else:
            try:
                t, dispatch, args = self.msg_queue.get_nowait()
            except asyncio.QueueEmpty:
                return None
        await dispatch(*args)

    async def dispatch_queue(self):
        """Coroutine to preserve order of message handling

        Ensures that only one message is processing at a time,
        even when the handler is async
        """

        while True:
            try:
                await self.process_one()
            except Exception:
                self.log.exception("Error in message handler")

    _message_counter = Any(
        help="""Monotonic counter of messages
        """,
    )

    @default('_message_counter')
    def _message_counter_default(self):
        return itertools.count()

    def schedule_dispatch(self, dispatch, *args):
        """schedule a message for dispatch"""
        idx = next(self._message_counter)

        self.msg_queue.put_nowait(
            (
                idx,
                dispatch,
                args,
            )
        )
        # ensure the eventloop wakes up
        self.io_loop.add_callback(lambda: None)

    def start(self):
        """register dispatchers for streams"""
        self.io_loop = ioloop.IOLoop.current()
        self.msg_queue = Queue()
        self.io_loop.add_callback(self.dispatch_queue)

        self.control_stream.on_recv(self.dispatch_control, copy=False)

        if self.control_thread:
            control_loop = self.control_thread.io_loop
        else:
            control_loop = self.io_loop

        asyncio.run_coroutine_threadsafe(self.poll_control_queue(),
                                         control_loop.asyncio_loop)

        self.shell_stream.on_recv(
            partial(
                self.schedule_dispatch,
                self.dispatch_shell,
            ),
            copy=False,
        )

        # publish idle status
        self._publish_status('starting', 'shell')

    def record_ports(self, ports):
        """Record the ports that this kernel is using.

        The creator of the Kernel instance must call this methods if they
        want the :meth:`connect_request` method to return the port numbers.
        """
        self._recorded_ports = ports

    #---------------------------------------------------------------------------
    # Kernel request handlers
    #---------------------------------------------------------------------------

    def _publish_execute_input(self, code, parent, execution_count):
        """Publish the code request on the iopub stream."""

        self.session.send(self.iopub_socket, 'execute_input',
                          {'code': code, 'execution_count': execution_count},
                          parent=parent, ident=self._topic('execute_input'))

    def _publish_status(self, status, channel, parent=None):
        """send status (busy/idle) on IOPub"""
        self.session.send(
            self.iopub_socket,
            "status",
            {"execution_state": status},
            parent=parent or self.get_parent(channel),
            ident=self._topic("status"),
        )

    def _publish_debug_event(self, event):
        """Publish a debug_event message on IOPub, parented to the control channel."""
        self.session.send(
            self.iopub_socket,
            "debug_event",
            event,
            parent=self.get_parent("control"),
            ident=self._topic("debug_event"),
        )

    def set_parent(self, ident, parent, channel='shell'):
        """Set the current parent request

        Side effects (IOPub messages) and replies are associated with the
        request that caused them via the parent_header.

        The parent identity is used to route input_request messages
        on the stdin channel.
        """
        self._parent_ident[channel] = ident
        self._parents[channel] = parent

    def get_parent(self, channel="shell"):
        """Get the parent request associated with a channel.

        .. versionadded:: 6

        Parameters
        ----------
        channel : str
            the name of the channel ('shell' or 'control')

        Returns
        -------
        message : dict
            the parent message for the most recent request on the channel.
        """
        return self._parents.get(channel, {})

    def send_response(self, stream, msg_or_type, content=None, ident=None,
                      buffers=None, track=False, header=None, metadata=None,
                      channel='shell'):
        """Send a response to the message we're currently processing.

        This accepts all the parameters of
        :meth:`jupyter_client.session.Session.send` except ``parent``.

        This relies on :meth:`set_parent` having been called for the current
        message.
        """
        return self.session.send(
            stream,
            msg_or_type,
            content,
            self.get_parent(channel),
            ident,
            buffers,
            track,
            header,
            metadata,
        )

    def init_metadata(self, parent):
        """Initialize metadata.

        Run at the beginning of execution requests.
        """
        # FIXME: `started` is part of ipyparallel
        # Remove for yapkernel 5.0
        return {
            'started': now(),
        }

    def finish_metadata(self, parent, metadata, reply_content):
        """Finish populating metadata.

        Run after completing an execution request.
        """
        return metadata

    async def execute_request(self, stream, ident, parent):
        """handle an execute_request"""

        try:
            content = parent['content']
            code = content['code']
            silent = content['silent']
            store_history = content.get('store_history', not silent)
            user_expressions = content.get('user_expressions', {})
            allow_stdin = content.get('allow_stdin', False)
        except Exception:
            self.log.error("Got bad msg: ")
            self.log.error("%s", parent)
            return

        stop_on_error = content.get('stop_on_error', True)

        metadata = self.init_metadata(parent)

        # Re-broadcast our input for the benefit of listening clients, and
        # start computing output
        if not silent:
            self.execution_count += 1
            self._publish_execute_input(code, parent, self.execution_count)

        reply_content = self.do_execute(
            code, silent, store_history,
            user_expressions, allow_stdin,
        )
        if inspect.isawaitable(reply_content):
            reply_content = await reply_content

        # Flush output before sending the reply.
        sys.stdout.flush()
        sys.stderr.flush()
        # FIXME: on rare occasions, the flush doesn't seem to make it to the
        # clients... This seems to mitigate the problem, but we definitely need
        # to better understand what's going on.
        if self._execute_sleep:
            time.sleep(self._execute_sleep)

        # Send the reply.
        reply_content = json_clean(reply_content)
        metadata = self.finish_metadata(parent, metadata, reply_content)

        reply_msg = self.session.send(stream, 'execute_reply',
                                      reply_content, parent, metadata=metadata,
                                      ident=ident)

        self.log.debug("%s", reply_msg)

        if not silent and reply_msg['content'][
                'status'] == 'error' and stop_on_error:
            await self._abort_queues()

    def do_execute(self, code, silent, store_history=True,
                   user_expressions=None, allow_stdin=False):
        """Execute user code. Must be overridden by subclasses.
        """
        raise NotImplementedError

    async def complete_request(self, stream, ident, parent):
        """handle a complete_request"""
        content = parent['content']
        code = content['code']
        cursor_pos = content['cursor_pos']

        matches = self.do_complete(code, cursor_pos)
        if inspect.isawaitable(matches):
            matches = await matches

        matches = json_clean(matches)
        self.session.send(stream, "complete_reply", matches, parent, ident)

    def do_complete(self, code, cursor_pos):
        """Override in subclasses to find completions.
        """
        return {
            'matches': [],
            'cursor_end': cursor_pos,
            'cursor_start': cursor_pos,
            'metadata': {},
            'status': 'ok'
        }

    async def inspect_request(self, stream, ident, parent):
        """handle an inspect_request"""
        content = parent['content']

        reply_content = self.do_inspect(
            content['code'],
            content['cursor_pos'],
            content.get('detail_level', 0),
        )
        if inspect.isawaitable(reply_content):
            reply_content = await reply_content

        # Before we send this object over, we scrub it for JSON usage
        reply_content = json_clean(reply_content)
        msg = self.session.send(stream, 'inspect_reply',
                                reply_content, parent, ident)
        self.log.debug("%s", msg)

    def do_inspect(self, code, cursor_pos, detail_level=0):
        """Override in subclasses to allow introspection.
        """
        return {'status': 'ok', 'data': {}, 'metadata': {}, 'found': False}

    async def history_request(self, stream, ident, parent):
        """handle a history_request"""
        content = parent['content']

        reply_content = self.do_history(**content)
        if inspect.isawaitable(reply_content):
            reply_content = await reply_content

        reply_content = json_clean(reply_content)
        msg = self.session.send(stream, 'history_reply',
                                reply_content, parent, ident)
        self.log.debug("%s", msg)

    def do_history(self, hist_access_type, output, raw, session=None,
                   start=None, stop=None, n=None, pattern=None, unique=False):
        """Override in subclasses to access history.
        """
        return {'status': 'ok', 'history': []}

    async def connect_request(self, stream, ident, parent):
        """handle a connect_request: reply with the recorded port numbers"""
        if self._recorded_ports is not None:
            content = self._recorded_ports.copy()
        else:
            content = {}
        content['status'] = 'ok'
        msg = self.session.send(stream, 'connect_reply', content, parent, ident)
        self.log.debug("%s", msg)

    @property
    def kernel_info(self):
        return {
            'protocol_version': kernel_protocol_version,
            'implementation': self.implementation,
            'implementation_version': self.implementation_version,
            'language_info': self.language_info,
            'banner': self.banner,
            'help_links': self.help_links,
        }

    async def kernel_info_request(self, stream, ident, parent):
        """handle a kernel_info_request"""
        content = {'status': 'ok'}
        content.update(self.kernel_info)
        msg = self.session.send(stream, 'kernel_info_reply', content, parent, ident)
        self.log.debug("%s", msg)

    async def comm_info_request(self, stream, ident, parent):
        """handle a comm_info_request"""
        content = parent['content']
        target_name = content.get('target_name', None)

        # Should this be moved to ipkernel?
        if hasattr(self, 'comm_manager'):
            comms = {
                k: dict(target_name=v.target_name)
                for (k, v) in self.comm_manager.comms.items()
                if v.target_name == target_name or target_name is None
            }
        else:
            comms = {}
        reply_content = dict(comms=comms, status='ok')
        msg = self.session.send(stream, 'comm_info_reply',
                                reply_content, parent, ident)
        self.log.debug("%s", msg)

    async def interrupt_request(self, stream, ident, parent):
        """handle an interrupt_request by signalling our own process (group)"""
        if os.name == "nt":
            # os.getpgid/os.killpg do not exist on Windows; guard before
            # touching them, otherwise this handler dies with AttributeError
            # instead of logging the intended message.
            self.log.error("Interrupt message not supported on Windows")
        else:
            pid = os.getpid()
            pgid = os.getpgid(pid)
            # Prefer process-group over process
            if pgid and hasattr(os, "killpg"):
                try:
                    os.killpg(pgid, SIGINT)
                    return
                except OSError:
                    pass
            try:
                os.kill(pid, SIGINT)
            except OSError:
                pass

        content = parent['content']
        self.session.send(stream, 'interrupt_reply', content, parent, ident=ident)
        return

    async def shutdown_request(self, stream, ident, parent):
        """handle a shutdown_request: reply, then stop the control and shell loops"""
        content = self.do_shutdown(parent['content']['restart'])
        if inspect.isawaitable(content):
            content = await content
        self.session.send(stream, 'shutdown_reply', content, parent, ident=ident)
        # same content, but different msg_id for broadcasting on IOPub
        self._shutdown_message = self.session.msg('shutdown_reply', content, parent)

        self._at_shutdown()

        self.log.debug('Stopping control ioloop')
        control_io_loop = self.control_stream.io_loop
        control_io_loop.add_callback(control_io_loop.stop)

        self.log.debug('Stopping shell ioloop')
        shell_io_loop = self.shell_stream.io_loop
        shell_io_loop.add_callback(shell_io_loop.stop)

    def do_shutdown(self, restart):
        """Override in subclasses to do things when the frontend shuts down the
        kernel.
        """
        return {'status': 'ok', 'restart': restart}

    async def is_complete_request(self, stream, ident, parent):
        """handle an is_complete_request"""
        content = parent['content']
        code = content['code']

        reply_content = self.do_is_complete(code)
        if inspect.isawaitable(reply_content):
            reply_content = await reply_content
        reply_content = json_clean(reply_content)
        reply_msg = self.session.send(stream, 'is_complete_reply',
                                      reply_content, parent, ident)
        self.log.debug("%s", reply_msg)

    def do_is_complete(self, code):
        """Override in subclasses to find completions.
        """
        return {'status': 'unknown'}

    async def debug_request(self, stream, ident, parent):
        """handle a debug_request"""
        content = parent['content']

        reply_content = self.do_debug_request(content)
        if inspect.isawaitable(reply_content):
            reply_content = await reply_content
        reply_content = json_clean(reply_content)
        reply_msg = self.session.send(stream, 'debug_reply', reply_content,
                                      parent, ident)
        self.log.debug("%s", reply_msg)

    async def do_debug_request(self, msg):
        raise NotImplementedError

    #---------------------------------------------------------------------------
    # Engine methods (DEPRECATED)
    #---------------------------------------------------------------------------

    async def apply_request(self, stream, ident, parent):
        """handle an apply_request (DEPRECATED, ipyparallel)"""
        self.log.warning(
            "apply_request is deprecated in kernel_base, moving to ipyparallel."
        )
        try:
            content = parent['content']
            bufs = parent['buffers']
            msg_id = parent['header']['msg_id']
        except Exception:
            self.log.error("Got bad msg: %s", parent, exc_info=True)
            return

        md = self.init_metadata(parent)

        reply_content, result_buf = self.do_apply(content, bufs, msg_id, md)

        # flush i/o
        sys.stdout.flush()
        sys.stderr.flush()

        md = self.finish_metadata(parent, md, reply_content)

        self.session.send(stream, 'apply_reply', reply_content,
                          parent=parent, ident=ident,
                          buffers=result_buf, metadata=md)

    def do_apply(self, content, bufs, msg_id, reply_metadata):
        """DEPRECATED"""
        raise NotImplementedError

    #---------------------------------------------------------------------------
    # Control messages (DEPRECATED)
    #---------------------------------------------------------------------------

    async def abort_request(self, stream, ident, parent):
        """abort a specific msg by id"""
        self.log.warning(
            "abort_request is deprecated in kernel_base. It is only part of IPython parallel"
        )
        msg_ids = parent['content'].get('msg_ids', None)
        if isinstance(msg_ids, str):
            msg_ids = [msg_ids]
        if not msg_ids:
            # _abort_queues is a coroutine: it must be awaited or the abort
            # never actually runs (the bare call only created the coroutine).
            await self._abort_queues()
        for mid in msg_ids:
            self.aborted.add(str(mid))

        content = dict(status='ok')
        reply_msg = self.session.send(stream, 'abort_reply', content=content,
                                      parent=parent, ident=ident)
        self.log.debug("%s", reply_msg)

    async def clear_request(self, stream, idents, parent):
        """Clear our namespace."""
        self.log.warning(
            "clear_request is deprecated in kernel_base. It is only part of IPython parallel"
        )
        content = self.do_clear()
        self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
                          content=content)

    def do_clear(self):
        """DEPRECATED since 4.0.3"""
        raise NotImplementedError

    #---------------------------------------------------------------------------
    # Protected interface
    #---------------------------------------------------------------------------

    def _topic(self, topic):
        """prefixed topic for IOPub messages"""
        base = "kernel.%s" % self.ident

        return ("%s.%s" % (base, topic)).encode()

    _aborting = Bool(False)

    async def _abort_queues(self):
        """Start aborting queued execute_requests for stop_on_error_timeout seconds."""
        self.shell_stream.flush()
        self._aborting = True

        def stop_aborting():
            self.log.info("Finishing abort")
            self._aborting = False

        # We are inside a coroutine, so a loop is guaranteed to be running;
        # get_running_loop avoids the deprecated get_event_loop pattern.
        asyncio.get_running_loop().call_later(self.stop_on_error_timeout, stop_aborting)

    def _send_abort_reply(self, stream, msg, idents):
        """Send a reply to an aborted request"""
        self.log.info(
            f"Aborting {msg['header']['msg_id']}: {msg['header']['msg_type']}")
        reply_type = msg["header"]["msg_type"].rsplit("_", 1)[0] + "_reply"
        status = {"status": "aborted"}
        md = self.init_metadata(msg)
        md = self.finish_metadata(msg, md, status)
        md.update(status)

        self.session.send(
            stream,
            reply_type,
            metadata=md,
            content=status,
            parent=msg,
            ident=idents,
        )

    def _no_raw_input(self):
        """Raise StdinNotImplementedError if active frontend doesn't support
        stdin."""
        raise StdinNotImplementedError("raw_input was called, but this "
                                       "frontend does not support stdin.")

    def getpass(self, prompt='', stream=None):
        """Forward getpass to frontends

        Raises
        ------
        StdinNotImplementedError if active frontend doesn't support stdin.
        """
        if not self._allow_stdin:
            raise StdinNotImplementedError(
                "getpass was called, but this frontend does not support input requests."
            )
        if stream is not None:
            import warnings
            warnings.warn(
                "The `stream` parameter of `getpass.getpass` will have no effect when using yapkernel",
                UserWarning,
                stacklevel=2,
            )
        return self._input_request(
            prompt,
            self._parent_ident["shell"],
            self.get_parent("shell"),
            password=True,
        )

    def raw_input(self, prompt=''):
        """Forward raw_input to frontends

        Raises
        ------
        StdinNotImplementedError if active frontend doesn't support stdin.
        """
        if not self._allow_stdin:
            raise StdinNotImplementedError(
                "raw_input was called, but this frontend does not support input requests."
            )
        return self._input_request(
            str(prompt),
            self._parent_ident["shell"],
            self.get_parent("shell"),
            password=False,
        )

    def _input_request(self, prompt, ident, parent, password=False):
        # Flush output before making the request.
        sys.stderr.flush()
        sys.stdout.flush()

        # flush the stdin socket, to purge stale replies
        while True:
            try:
                self.stdin_socket.recv_multipart(zmq.NOBLOCK)
            except zmq.ZMQError as e:
                if e.errno == zmq.EAGAIN:
                    break
                else:
                    raise

        # Send the input request.
        content = json_clean(dict(prompt=prompt, password=password))
        self.session.send(self.stdin_socket, 'input_request', content, parent,
                          ident=ident)

        # Await a response.
        while True:
            try:
                # Use polling with select() so KeyboardInterrupts can get
                # through; doing a blocking recv() means stdin reads are
                # uninterruptible on Windows. We need a timeout because
                # zmq.select() is also uninterruptible, but at least this
                # way reads get noticed immediately and KeyboardInterrupts
                # get noticed fairly quickly by human response time standards.
                rlist, _, xlist = zmq.select([self.stdin_socket], [],
                                             [self.stdin_socket], 0.01)
                if rlist or xlist:
                    ident, reply = self.session.recv(self.stdin_socket)
                    if (ident, reply) != (None, None):
                        break
            except KeyboardInterrupt:
                # re-raise KeyboardInterrupt, to truncate traceback
                raise KeyboardInterrupt("Interrupted by user") from None
            except Exception:
                self.log.warning("Invalid Message:", exc_info=True)

        try:
            value = reply["content"]["value"]
        except Exception:
            self.log.error("Bad input_reply: %s", parent)
            value = ''
        if value == '\x04':
            # EOF
            raise EOFError
        return value

    def _at_shutdown(self):
        """Actions taken at shutdown by the kernel, called by python's atexit.
        """
        if self._shutdown_message is not None:
            self.session.send(self.iopub_socket, self._shutdown_message,
                              ident=self._topic('shutdown'))
            self.log.debug("%s", self._shutdown_message)
        self.control_stream.flush(zmq.POLLOUT)
class GlobusOAuthenticator(OAuthenticator):
    """The Globus OAuthenticator handles both authorization and passing
    transfer tokens to the spawner."""

    login_service = 'Globus'
    logout_handler = GlobusLogoutHandler

    # --- Globus Auth endpoint defaults -----------------------------------

    @default("userdata_url")
    def _userdata_url_default(self):
        return "https://auth.globus.org/v2/oauth2/userinfo"

    @default("authorize_url")
    def _authorize_url_default(self):
        return "https://auth.globus.org/v2/oauth2/authorize"

    @default("revocation_url")
    def _revocation_url_default(self):
        return "https://auth.globus.org/v2/oauth2/token/revoke"

    revocation_url = Unicode(help="Globus URL to revoke live tokens.").tag(config=True)

    @default("token_url")
    def _token_url_default(self):
        return "https://auth.globus.org/v2/oauth2/token"

    globus_groups_url = Unicode(help="Globus URL to get list of user's Groups.").tag(
        config=True
    )

    @default("globus_groups_url")
    def _globus_groups_url_default(self):
        return "https://groups.api.globus.org/v2/groups/my_groups"

    identity_provider = Unicode(
        help="""Restrict which institution a user can use to login (GlobusID,
        University of Hogwarts, etc.). This should be set in the app at
        developers.globus.org, but this acts as an additional check to prevent
        unnecessary account creation."""
    ).tag(config=True)

    # Old-style traitlets dynamic default: a method named _<trait>_default is
    # picked up without the @default decorator.
    def _identity_provider_default(self):
        return os.getenv('IDENTITY_PROVIDER', '')

    username_from_email = Bool(
        help="""Create username from email address, not preferred username. If
        an identity provider is specified, email address must be from the same
        domain. Email scope will be set automatically."""
    ).tag(config=True)

    @default("username_from_email")
    def _username_from_email_default(self):
        return False

    exclude_tokens = List(
        help="""Exclude tokens from being passed into user environments when
        they start notebooks, Terminals, etc."""
    ).tag(config=True)

    def _exclude_tokens_default(self):
        return ['auth.globus.org', 'groups.api.globus.org']

    def _scope_default(self):
        # Transfer scope is always requested; the Groups scope is added only
        # when group-based access control is configured, and the email scope
        # only when usernames are derived from email addresses.
        scopes = [
            'openid',
            'profile',
            'urn:globus:auth:scope:transfer.api.globus.org:all',
        ]
        if self.allowed_globus_groups or self.admin_globus_groups:
            scopes.append(
                'urn:globus:auth:scope:groups.api.globus.org:view_my_groups_and_memberships'
            )
        if self.username_from_email:
            scopes.append('email')
        return scopes

    globus_local_endpoint = Unicode(
        help="""If Jupyterhub is also a Globus endpoint, its endpoint id can be
        specified here."""
    ).tag(config=True)

    def _globus_local_endpoint_default(self):
        return os.getenv('GLOBUS_LOCAL_ENDPOINT', '')

    revoke_tokens_on_logout = Bool(
        help="""Revoke tokens so they cannot be used again. Single-user servers
        MUST be restarted after logout in order to get a fresh working set of
        tokens."""
    ).tag(config=True)

    def _revoke_tokens_on_logout_default(self):
        return False

    allowed_globus_groups = Set(
        help="""Allow members of defined Globus Groups to access JupyterHub.
        Users in an admin Globus Group are also automatically allowed. Groups
        are specified with their UUIDs. Setting this will add the Globus
        Groups scope."""
    ).tag(config=True)

    admin_globus_groups = Set(
        help="""Set members of defined Globus Groups as JupyterHub admin users.
        These users are automatically allowed to login to JupyterHub. Groups
        are specified with their UUIDs. Setting this will add the Globus
        Groups scope."""
    ).tag(config=True)

    async def pre_spawn_start(self, user, spawner):
        """Add tokens to the spawner whenever the spawner starts a notebook.
        This will allow users to create a transfer client:
        globus-sdk-python.readthedocs.io/en/stable/tutorial/#tutorial-step4
        """
        spawner.environment['GLOBUS_LOCAL_ENDPOINT'] = self.globus_local_endpoint
        state = await user.get_auth_state()
        if state:
            # The auth-state dict (tokens keyed by resource server) is pickled
            # and base64-encoded into a single environment variable for the
            # single-user server to unpack.
            globus_data = base64.b64encode(pickle.dumps(state))
            spawner.environment['GLOBUS_DATA'] = globus_data.decode('utf-8')

    async def authenticate(self, handler, data=None):
        """
        Authenticate with globus.org. Usernames (and therefore Jupyterhub
        accounts) will correspond to a Globus User ID, so [email protected]
        will have the 'foouser' account in Jupyterhub.
        """
        # Complete login and exchange the code for tokens.
        params = dict(
            redirect_uri=self.get_callback_url(handler),
            code=handler.get_argument("code"),
            grant_type='authorization_code',
        )
        req = HTTPRequest(
            self.token_url,
            method="POST",
            headers=self.get_client_credential_headers(),
            body=urllib.parse.urlencode(params),
        )
        token_json = await self.fetch(req)

        # Fetch user info at Globus's oauth2/userinfo/ HTTP endpoint to get the username
        user_headers = self.get_default_headers()
        user_headers['Authorization'] = 'Bearer {}'.format(token_json['access_token'])
        req = HTTPRequest(self.userdata_url, method='GET', headers=user_headers)
        user_resp = await self.fetch(req)
        username = self.get_username(user_resp)

        # Each token should have these attributes. Resource server is optional,
        # and likely won't be present.
        token_attrs = [
            'expires_in',
            'resource_server',
            'scope',
            'token_type',
            'refresh_token',
            'access_token',
        ]
        # The Auth Token is a bit special, it comes back at the top level with the
        # id token. The id token has some useful information in it, but nothing that
        # can't be retrieved with an Auth token.
        # Repackage the Auth token into a dict that looks like the other tokens
        auth_token_dict = {
            attr_name: token_json.get(attr_name) for attr_name in token_attrs
        }
        # Make sure only the essentials make it into tokens. Other items, such
        # as 'state' are not needed after authentication and can be discarded.
        other_tokens = [
            {attr_name: token_dict.get(attr_name) for attr_name in token_attrs}
            for token_dict in token_json['other_tokens']
        ]
        tokens = other_tokens + [auth_token_dict]
        # historically, tokens have been organized by resource server for convenience.
        # If multiple scopes are requested from the same resource server, they will be
        # combined into a single token from Globus Auth.
        by_resource_server = {
            token_dict['resource_server']: token_dict
            for token_dict in tokens
            if token_dict['resource_server'] not in self.exclude_tokens
        }
        user_info = {
            'name': username,
            'auth_state': {
                'client_id': self.client_id,
                'tokens': by_resource_server,
            },
        }
        use_globus_groups = False
        user_allowed = False
        if self.allowed_globus_groups or self.admin_globus_groups:
            # If any of these configurations are set, user must be in the allowed or admin Globus Group
            use_globus_groups = True
            user_group_ids = set()
            # Get Groups access token, may not be in dict headed to auth state
            # NOTE(review): this relies on a groups.api.globus.org token being
            # present (ensured by _scope_default adding the Groups scope when
            # these group sets are configured); otherwise `groups_token` below
            # would be unbound — confirm against custom scope configs.
            for token_dict in tokens:
                if token_dict['resource_server'] == 'groups.api.globus.org':
                    groups_token = token_dict['access_token']
            # Get list of user's Groups
            groups_headers = self.get_default_headers()
            groups_headers['Authorization'] = 'Bearer {}'.format(groups_token)
            req = HTTPRequest(
                self.globus_groups_url, method='GET', headers=groups_headers
            )
            groups_resp = await self.fetch(req)
            # Build set of Group IDs
            for group in groups_resp:
                user_group_ids.add(group['id'])
            if user_group_ids & self.allowed_globus_groups:
                user_allowed = True
            if self.admin_globus_groups:
                # Admin users are being managed via Globus Groups
                # Default to False
                user_info['admin'] = False
                if user_group_ids & self.admin_globus_groups:
                    # User is an admin, admins allowed by default
                    user_allowed = user_info['admin'] = True
        if user_allowed or not use_globus_groups:
            return user_info
        else:
            self.log.warning('{} not in an allowed Globus Group'.format(username))
            return None

    def get_username(self, user_data):
        """Extract the JupyterHub username (local part of the identity) from
        the userinfo response, enforcing ``identity_provider`` if set.

        Raises
        ------
        HTTPError
            403 when the identity's domain does not match identity_provider.
        """
        # It's possible for identity provider domains to be namespaced
        # https://docs.globus.org/api/auth/specification/#identity_provider_namespaces # noqa
        username_field = 'preferred_username'
        if self.username_from_email:
            username_field = 'email'
        username, domain = user_data.get(username_field).split('@', 1)
        if self.identity_provider and domain != self.identity_provider:
            raise HTTPError(
                403,
                'This site is restricted to {} accounts. Please link your {}'
                ' account at {}.'.format(
                    self.identity_provider,
                    self.identity_provider,
                    'globus.org/app/account',
                ),
            )
        return username

    def get_default_headers(self):
        # Common headers for every Globus API request made by this class.
        return {"Accept": "application/json", "User-Agent": "JupyterHub"}

    def get_client_credential_headers(self):
        # HTTP Basic auth with the OAuth client id/secret, used for the token
        # exchange and revocation endpoints.
        headers = self.get_default_headers()
        b64key = base64.b64encode(
            bytes("{}:{}".format(self.client_id, self.client_secret), "utf8")
        )
        headers["Authorization"] = "Basic {}".format(b64key.decode("utf8"))
        return headers

    async def revoke_service_tokens(self, services):
        """Revoke live Globus access and refresh tokens. Revoking inert or
        non-existent tokens does nothing. Services are defined by dicts
        returned by tokens.by_resource_server, for example:
        services = { 'transfer.api.globus.org': {'access_token': 'token'}, ...
            <Additional services>...
        }
        """
        access_tokens = [
            token_dict.get('access_token') for token_dict in services.values()
        ]
        refresh_tokens = [
            token_dict.get('refresh_token') for token_dict in services.values()
        ]
        # Missing tokens come back as None from .get(); filter them out.
        all_tokens = [tok for tok in access_tokens + refresh_tokens if tok is not None]
        for token in all_tokens:
            req = HTTPRequest(
                self.revocation_url,
                method="POST",
                headers=self.get_client_credential_headers(),
                body=urllib.parse.urlencode({'token': token}),
            )
            await self.fetch(req)
class BinderHub(Application):
    """An Application for starting a builder."""

    @default('log_level')
    def _log_level(self):
        return logging.INFO

    # Command-line aliases/flags (traitlets Application conventions).
    aliases = {
        'log-level': 'Application.log_level',
        'f': 'BinderHub.config_file',
        'config': 'BinderHub.config_file',
        'port': 'BinderHub.port',
    }

    flags = {
        'debug': ({
            'BinderHub': {
                'debug': True
            }
        }, "Enable debug HTTP serving & debug logging")
    }

    config_file = Unicode('binderhub_config.py', help="""
        Config file to load.

        If a relative path is provided, it is taken relative to current directory
        """, config=True)

    google_analytics_code = Unicode(None, allow_none=True, help="""
        The Google Analytics code to use on the main page.

        Note that we'll respect Do Not Track settings, despite the fact that GA
        does not. We will not load the GA scripts on browsers with DNT enabled.
        """, config=True)

    google_analytics_domain = Unicode('auto', help="""
        The Google Analytics domain to use on the main page.

        By default this is set to 'auto', which sets it up for current domain and all
        subdomains. This can be set to a more restrictive domain here for better privacy
        """, config=True)

    about_message = Unicode('', help="""
        Additional message to display on the about page.

        Will be directly inserted into the about page's source so you can use
        raw HTML.
        """, config=True)

    banner_message = Unicode('', help="""
        Message to display in a banner on all pages.

        The value will be inserted "as is" into a HTML <div> element
        with grey background, located at the top of the BinderHub pages. Raw
        HTML is supported.
        """, config=True)

    extra_footer_scripts = Dict({}, help="""
        Extra bits of JavaScript that should be loaded in footer of each page.

        Only the values are set up as scripts. Keys are used only
        for sorting.

        Omit the <script> tag. This should be primarily used for
        analytics code.
        """, config=True)

    base_url = Unicode('/',
                       help="The base URL of the entire application",
                       config=True)

    @validate('base_url')
    def _valid_base_url(self, proposal):
        # Normalize to a leading AND trailing slash.
        if not proposal.value.startswith('/'):
            proposal.value = '/' + proposal.value
        if not proposal.value.endswith('/'):
            proposal.value = proposal.value + '/'
        return proposal.value

    badge_base_url = Union(trait_types=[Unicode(), Callable()], help="""
        Base URL to use when generating launch badges.
        Can also be a function that is passed the current handler and returns
        the badge base URL, or "" for the default.

        For example, you could get the badge_base_url from a custom HTTP
        header, the Referer header, or from a request parameter
        """, config=True)

    @default('badge_base_url')
    def _badge_base_url_default(self):
        return ''

    @validate('badge_base_url')
    def _valid_badge_base_url(self, proposal):
        if callable(proposal.value):
            return proposal.value
        # add a trailing slash only when a value is set
        if proposal.value and not proposal.value.endswith('/'):
            proposal.value = proposal.value + '/'
        return proposal.value

    cors_allow_origin = Unicode("", help="""
        Origins that can access the BinderHub API.

        Sets the Access-Control-Allow-Origin header in the spawned
        notebooks. Set to '*' to allow any origin to access spawned
        notebook servers.

        See also BinderSpawner.cors_allow_origin in the binderhub spawner
        mixin for setting this property on the spawned notebooks.
        """, config=True)

    auth_enabled = Bool(False,
                        help="""If JupyterHub authentication enabled,
        require user to login (don't create temporary users during launch) and
        start the new server for the logged in user.""",
                        config=True)

    port = Integer(8585, help="""
        Port for the builder to listen on.
        """, config=True)

    appendix = Unicode(
        help="""
        Appendix to pass to repo2docker

        A multi-line string of Docker directives to run.
        Since the build context cannot be affected,
        ADD will typically not be useful.

        This should be a Python string template.
        It will be formatted with at least the following names available:

        - binder_url: the shareable URL for the current image
          (e.g. for sharing links to the current Binder)
        - repo_url: the repository URL used to build the image
        """,
        config=True,
    )

    sticky_builds = Bool(
        False,
        help="""
        Attempt to assign builds for the same repository to the same node.

        In order to speed up re-builds of a repository all its builds will
        be assigned to the same node in the cluster.

        Note: This feature only works if you also enable docker-in-docker support.
        """,
        config=True,
    )

    use_registry = Bool(True, help="""
        Set to true to push images to a registry & check for images in registry.

        Set to false to use only local docker images. Useful when running
        in a single node.
        """, config=True)

    build_class = Type(Build, help="""
        The class used to build repo2docker images.

        Must inherit from binderhub.build.Build
        """, config=True)

    registry_class = Type(DockerRegistry, help="""
        The class used to Query a Docker registry.

        Must inherit from binderhub.registry.DockerRegistry
        """, config=True)

    per_repo_quota = Integer(
        0,
        help="""
        Maximum number of concurrent users running from a given repo.

        Limits the amount of Binder that can be consumed by a single repo.

        0 (default) means no quotas.
        """,
        config=True,
    )

    pod_quota = Integer(
        None,
        help="""
        The number of concurrent pods this hub has been designed to support.

        This quota is used as an indication for how much above or below the
        design capacity a hub is running.

        Attempts to launch new pods once the quota has been reached will fail.

        The default corresponds to no quota, 0 means the hub can't accept pods
        (maybe because it is in maintenance mode), and any positive integer
        sets the quota.
        """,
        allow_none=True,
        config=True,
    )

    per_repo_quota_higher = Integer(
        0,
        help="""
        Maximum number of concurrent users running from a higher-quota repo.

        Limits the amount of Binder that can be consumed by a single repo.
        This quota is a second limit for repos with special status. See the
        `high_quota_specs` parameter of RepoProvider classes for usage.

        0 (default) means no quotas.
        """,
        config=True,
    )

    log_tail_lines = Integer(
        100,
        help="""
        Limit number of log lines to show when connecting to an already running build.
        """,
        config=True,
    )

    push_secret = Unicode('binder-build-docker-config',
                          allow_none=True,
                          help="""
        A kubernetes secret object that provides credentials for pushing built images.
        """,
                          config=True)

    image_prefix = Unicode("", help="""
        Prefix for all built docker images.

        If you are pushing to gcr.io, this would start with:
            gcr.io/<your-project-name>/

        Set according to whatever registry you are pushing to.

        Defaults to "", which is probably not what you want :)
        """, config=True)

    build_memory_request = ByteSpecification(
        0,
        help="""
        Amount of memory to request when scheduling a build

        0 reserves no memory.

        This is used as the request for the pod that is spawned to do the
        building, even though the pod itself will not be using that much memory
        since the docker building is happening outside the pod. However, it
        makes kubernetes aware of the resources being used, and lets it
        schedule more intelligently.
        """,
        config=True,
    )

    build_memory_limit = ByteSpecification(
        0,
        help="""
        Max amount of memory allocated for each image build process.

        0 sets no limit.

        This is applied to the docker build itself via repo2docker, though it
        is also applied to our pod that submits the build, even though that
        pod will rarely consume much memory. Still, it makes it easier to see
        the resource limits in place via kubernetes.
        """,
        config=True,
    )

    debug = Bool(False, help="""
        Turn on debugging.
        """, config=True)

    build_docker_host = Unicode("/var/run/docker.sock",
                                config=True,
                                help="""
        The docker URL repo2docker should use to build the images.

        Currently, only paths are supported, and they are expected to be
        available on all the hosts.
        """)

    @validate('build_docker_host')
    def docker_build_host_validate(self, proposal):
        # Only unix sockets are supported; reject tcp:// and the like.
        parts = urlparse(proposal.value)
        if parts.scheme != 'unix' or parts.netloc != '':
            raise TraitError(
                "Only unix domain sockets on same node are supported for build_docker_host"
            )
        return proposal.value

    build_docker_config = Dict(None,
                               allow_none=True,
                               help="""
        A dict which will be merged into the .docker/config.json of the build container (repo2docker)

        Here, you could for example pass proxy settings as described here:
        https://docs.docker.com/network/proxy/#configure-the-docker-client

        Note: if you provide your own push_secret, this values wont have an
        effect, as the push_secrets will overwrite
        .docker/config.json
        In this case, make sure that you include your config in your push_secret
        """,
                               config=True)

    hub_api_token = Unicode(
        help="""API token for talking to the JupyterHub API""",
        config=True,
    )

    @default('hub_api_token')
    def _default_hub_token(self):
        return os.environ.get('JUPYTERHUB_API_TOKEN', '')

    hub_url = Unicode(
        help="""
        The base URL of the JupyterHub instance where users will run.

        e.g. https://hub.mybinder.org/
        """,
        config=True,
    )

    hub_url_local = Unicode(
        help="""
        The base URL of the JupyterHub instance for local/internal traffic

        If local/internal network connections from the BinderHub process
        should access JupyterHub using a different URL than public/external
        traffic set this, default is hub_url
        """,
        config=True,
    )

    @default('hub_url_local')
    def _default_hub_url_local(self):
        return self.hub_url

    @validate('hub_url', 'hub_url_local')
    def _add_slash(self, proposal):
        """trait validator to ensure hub_url ends with a trailing slash"""
        if proposal.value is not None and not proposal.value.endswith('/'):
            return proposal.value + '/'
        return proposal.value

    build_namespace = Unicode(help="""
        Kubernetes namespace to spawn build pods in.

        Note that the push_secret must refer to a secret in this namespace.
        """, config=True)

    @default('build_namespace')
    def _default_build_namespace(self):
        return os.environ.get('BUILD_NAMESPACE', 'default')

    build_image = Unicode('quay.io/jupyterhub/repo2docker:2021.08.0',
                          help="""
        The repo2docker image to be used for doing builds
        """,
                          config=True)

    build_node_selector = Dict({}, config=True, help="""
        Select the node where build pod runs on.
        """)

    repo_providers = Dict(
        {
            'gh': GitHubRepoProvider,
            'gist': GistRepoProvider,
            'git': GitRepoProvider,
            'gl': GitLabRepoProvider,
            'zenodo': ZenodoProvider,
            'figshare': FigshareProvider,
            'hydroshare': HydroshareProvider,
            'dataverse': DataverseProvider,
        },
        config=True,
        help="""
        List of Repo Providers to register and try
        """)

    @validate('repo_providers')
    def _validate_repo_providers(self, proposal):
        """trait validator to ensure there is at least one repo provider"""
        if not proposal.value:
            raise TraitError("Please provide at least one repo provider")
        if any([
                not issubclass(provider, RepoProvider)
                for provider in proposal.value.values()
        ]):
            raise TraitError(
                "Repository providers should inherit from 'binderhub.RepoProvider'"
            )
        return proposal.value

    concurrent_build_limit = Integer(
        32, config=True, help="""The number of concurrent builds to allow.""")

    executor_threads = Integer(
        5,
        config=True,
        help="""The number of threads to use for blocking calls

        Should generally be a small number because we don't care about high concurrency here,
        just not blocking the webserver. This executor is not used for long-running tasks (e.g. builds).
        """,
    )

    build_cleanup_interval = Integer(
        60,
        config=True,
        help=
        """Interval (in seconds) for how often stopped build pods will be deleted."""
    )

    build_max_age = Integer(3600 * 4,
                            config=True,
                            help="""Maximum age of builds

        Builds that are still running longer than this will be killed.
        """)

    build_token_check_origin = Bool(
        True,
        config=True,
        help="""Whether to validate build token origin.

        False disables the origin check.
        """)

    build_token_expires_seconds = Integer(
        300,
        config=True,
        help="""Expiry (in seconds) of build tokens

        These are generally only used to authenticate a single request
        from a page, so should be short-lived.
        """,
    )

    build_token_secret = Union(
        [Unicode(), Bytes()],
        config=True,
        help="""Secret used to sign build tokens

        Lightweight validation of same-origin requests
        """,
    )

    @validate("build_token_secret")
    def _validate_build_token_secret(self, proposal):
        if isinstance(proposal.value, str):
            # allow hex string for text-only input formats
            return a2b_hex(proposal.value)
        return proposal.value

    @default("build_token_secret")
    def _default_build_token_secret(self):
        if os.environ.get("BINDERHUB_BUILD_TOKEN_SECRET"):
            return a2b_hex(os.environ["BINDERHUB_BUILD_TOKEN_SECRET"])
        # No configured secret: generate a random per-process one (tokens will
        # not survive a restart).
        app_log.warning(
            "Generating random build token secret."
            " Set BinderHub.build_token_secret to avoid this warning.")
        return secrets.token_bytes(32)

    # FIXME: Come up with a better name for it?
    builder_required = Bool(True, config=True, help="""
        If binderhub should try to continue to run without a working build infrastructure.

        Build infrastructure is kubernetes cluster + docker. This is useful for pure HTML/CSS/JS local development.
        """)

    ban_networks = Dict(
        config=True,
        help="""
        Dict of networks from which requests should be rejected with 403

        Keys are CIDR notation (e.g. '1.2.3.4/32'),
        values are a label used in log / error messages.
        CIDR strings will be parsed with `ipaddress.ip_network()`.
        """,
    )

    @validate("ban_networks")
    def _cast_ban_networks(self, proposal):
        """Cast CIDR strings to IPv[4|6]Network objects"""
        networks = {}
        for cidr, message in proposal.value.items():
            networks[ipaddress.ip_network(cidr)] = message
        return networks

    ban_networks_min_prefix_len = Integer(
        1,
        help="The shortest prefix in ban_networks",
    )

    @observe("ban_networks")
    def _update_prefix_len(self, change):
        # Kept in sync with ban_networks so request checks can truncate
        # addresses to the shortest banned prefix.
        if not change.new:
            min_len = 1
        else:
            min_len = min(net.prefixlen for net in change.new)
        self.ban_networks_min_prefix_len = min_len or 1

    tornado_settings = Dict(config=True, help="""
        additional settings to pass through to tornado.

        can include things like additional headers, etc.
        """)

    template_variables = Dict(
        config=True,
        help="Extra variables to supply to jinja templates when rendering.",
    )

    template_path = Unicode(
        help=
        "Path to search for custom jinja templates, before using the default templates.",
        config=True,
    )

    @default('template_path')
    def _template_path_default(self):
        return os.path.join(HERE, 'templates')

    extra_static_path = Unicode(
        help='Path to search for extra static files.',
        config=True,
    )

    extra_static_url_prefix = Unicode(
        '/extra_static/',
        help='Url prefix to serve extra static files.',
        config=True,
    )

    normalized_origin = Unicode(
        '',
        config=True,
        help=
        'Origin to use when emitting events. Defaults to hostname of request when empty'
    )

    allowed_metrics_ips = Set(
        help=
        'List of IPs or networks allowed to GET /metrics. Defaults to all.',
        config=True)

    @staticmethod
    def add_url_prefix(prefix, handlers):
        """add a url prefix to handlers"""
        for i, tup in enumerate(handlers):
            lis = list(tup)
            lis[0] = url_path_join(prefix, tup[0])
            handlers[i] = tuple(lis)
        return handlers

    def init_pycurl(self):
        # Prefer the curl-based client when pycurl is installed; it scales
        # better for many concurrent requests.
        try:
            AsyncHTTPClient.configure(
                "tornado.curl_httpclient.CurlAsyncHTTPClient")
        except ImportError as e:
            self.log.debug(
                "Could not load pycurl: %s\npycurl is recommended if you have a large number of users.",
                e)
        # set max verbosity of curl_httpclient at INFO
        # because debug-logging from curl_httpclient
        # includes every full request and response
        if self.log_level < logging.INFO:
            curl_log = logging.getLogger('tornado.curl_httpclient')
            curl_log.setLevel(logging.INFO)

    def initialize(self, *args, **kwargs):
        """Load configuration settings."""
        super().initialize(*args, **kwargs)
        self.load_config_file(self.config_file)
        # hook up tornado logging
        if self.debug:
            self.log_level = logging.DEBUG
        tornado.options.options.logging = logging.getLevelName(self.log_level)
        tornado.log.enable_pretty_logging()
        self.log = tornado.log.app_log

        self.init_pycurl()

        # initialize kubernetes config
        if self.builder_required:
            # In-cluster config works inside a pod; fall back to local
            # kubeconfig for out-of-cluster development.
            try:
                kubernetes.config.load_incluster_config()
            except kubernetes.config.ConfigException:
                kubernetes.config.load_kube_config()
            self.tornado_settings[
                "kubernetes_client"] = self.kube_client = kubernetes.client.CoreV1Api(
                )

        # times 2 for log + build threads
        self.build_pool = ThreadPoolExecutor(self.concurrent_build_limit * 2)
        # default executor for asyncifying blocking calls (e.g. to kubernetes, docker).
        # this should not be used for long-running requests
        self.executor = ThreadPoolExecutor(self.executor_threads)

        jinja_options = dict(autoescape=True, )
        template_paths = [self.template_path]
        base_template_path = self._template_path_default()
        if base_template_path not in template_paths:
            # add base templates to the end, so they are looked up at last after custom templates
            template_paths.append(base_template_path)
        loader = ChoiceLoader([
            # first load base templates with prefix
            PrefixLoader({'templates': FileSystemLoader([base_template_path])},
                         '/'),
            # load all templates
            FileSystemLoader(template_paths)
        ])
        jinja_env = Environment(loader=loader, **jinja_options)

        if self.use_registry:
            registry = self.registry_class(parent=self)
        else:
            registry = None

        self.launcher = Launcher(
            parent=self,
            hub_url=self.hub_url,
            hub_url_local=self.hub_url_local,
            hub_api_token=self.hub_api_token,
            create_user=not self.auth_enabled,
        )

        self.event_log = EventLog(parent=self)
        # Register all bundled event schemas so emitted events validate.
        for schema_file in glob(os.path.join(HERE, 'event-schemas', '*.json')):
            with open(schema_file) as f:
                self.event_log.register_schema(json.load(f))

        self.tornado_settings.update({
            "log_function": log_request,
            "push_secret": self.push_secret,
            "image_prefix": self.image_prefix,
            "debug": self.debug,
            "hub_url": self.hub_url,
            "launcher": self.launcher,
            "appendix": self.appendix,
            "ban_networks": self.ban_networks,
            "ban_networks_min_prefix_len": self.ban_networks_min_prefix_len,
            "build_namespace": self.build_namespace,
            "build_image": self.build_image,
            "build_node_selector": self.build_node_selector,
            "build_pool": self.build_pool,
            "build_token_check_origin": self.build_token_check_origin,
            "build_token_secret": self.build_token_secret,
            "build_token_expires_seconds": self.build_token_expires_seconds,
            "sticky_builds": self.sticky_builds,
            "log_tail_lines": self.log_tail_lines,
            "pod_quota": self.pod_quota,
            "per_repo_quota": self.per_repo_quota,
            "per_repo_quota_higher": self.per_repo_quota_higher,
            "repo_providers": self.repo_providers,
            "rate_limiter": RateLimiter(parent=self),
            "use_registry": self.use_registry,
            "build_class": self.build_class,
            "registry": registry,
            "traitlets_config": self.config,
            "google_analytics_code": self.google_analytics_code,
            "google_analytics_domain": self.google_analytics_domain,
            "about_message": self.about_message,
            "banner_message": self.banner_message,
            "extra_footer_scripts": self.extra_footer_scripts,
            "jinja2_env": jinja_env,
            "build_memory_limit": self.build_memory_limit,
            "build_memory_request": self.build_memory_request,
            "build_docker_host": self.build_docker_host,
            "build_docker_config": self.build_docker_config,
            "base_url": self.base_url,
            "badge_base_url": self.badge_base_url,
            "static_path": os.path.join(HERE, "static"),
            "static_url_prefix": url_path_join(self.base_url, "static/"),
            "template_variables": self.template_variables,
            "executor": self.executor,
            "auth_enabled": self.auth_enabled,
            "event_log": self.event_log,
            "normalized_origin": self.normalized_origin,
            "allowed_metrics_ips":
            set(map(ipaddress.ip_network, self.allowed_metrics_ips))
        })

        if self.auth_enabled:
            self.tornado_settings['cookie_secret'] = os.urandom(32)
        if self.cors_allow_origin:
            self.tornado_settings.setdefault(
                'headers', {})['Access-Control-Allow-Origin'] = self.cors_allow_origin

        handlers = [
            (r'/metrics', MetricsHandler),
            (r'/versions', VersionHandler),
            (r"/build/([^/]+)/(.+)", BuildHandler),
            (r"/v2/([^/]+)/(.+)", ParameterizedMainHandler),
            (r"/repo/([^/]+)/([^/]+)(/.*)?", LegacyRedirectHandler),
            (r'/~([^/]+/.*)', UserRedirectHandler),
            # for backward-compatible mybinder.org badge URLs
            # /assets/images/badge.svg
            (r'/assets/(images/badge\.svg)', tornado.web.StaticFileHandler, {
                'path': self.tornado_settings['static_path']
            }),
            # /badge.svg
            (r'/(badge\.svg)', tornado.web.StaticFileHandler, {
                'path':
                os.path.join(self.tornado_settings['static_path'], 'images')
            }),
            # /badge_logo.svg
            (r'/(badge\_logo\.svg)', tornado.web.StaticFileHandler, {
                'path':
                os.path.join(self.tornado_settings['static_path'], 'images')
            }),
            # /logo_social.png
            (r'/(logo\_social\.png)', tornado.web.StaticFileHandler, {
                'path':
                os.path.join(self.tornado_settings['static_path'], 'images')
            }),
            # /favicon_XXX.ico
            (r'/(favicon\_fail\.ico)', tornado.web.StaticFileHandler, {
                'path':
                os.path.join(self.tornado_settings['static_path'], 'images')
            }),
            (r'/(favicon\_success\.ico)', tornado.web.StaticFileHandler, {
                'path':
                os.path.join(self.tornado_settings['static_path'], 'images')
            }),
            (r'/(favicon\_building\.ico)', tornado.web.StaticFileHandler, {
                'path':
                os.path.join(self.tornado_settings['static_path'], 'images')
            }),
            (r'/about', AboutHandler),
            (r'/health', HealthHandler, {
                'hub_url': self.hub_url_local
            }),
            (r'/_config', ConfigHandler),
            (r'/', MainHandler),
            (r'.*', Custom404),
        ]
        handlers = self.add_url_prefix(self.base_url, handlers)
        if self.extra_static_path:
            # Insert before the catch-all 404 handler.
            handlers.insert(-1, (re.escape(
                url_path_join(self.base_url, self.extra_static_url_prefix)) +
                                 r"(.*)", tornado.web.StaticFileHandler, {
                                     'path': self.extra_static_path
                                 }))
        if self.auth_enabled:
            oauth_redirect_uri = os.getenv('JUPYTERHUB_OAUTH_CALLBACK_URL') or \
                url_path_join(self.base_url, 'oauth_callback')
            oauth_redirect_uri = urlparse(oauth_redirect_uri).path
            handlers.insert(
                -1, (re.escape(oauth_redirect_uri), HubOAuthCallbackHandler))
        self.tornado_app = tornado.web.Application(handlers,
                                                   **self.tornado_settings)

    def stop(self):
        self.http_server.stop()
        self.build_pool.shutdown()

    async def watch_build_pods(self):
        """Watch build pods

        Every build_cleanup_interval:
        - delete stopped build pods
        - delete running build pods older than build_max_age
        """
        while True:
            try:
                # Cleanup runs in the executor because the kubernetes client
                # calls are blocking.
                await asyncio.wrap_future(
                    self.executor.submit(lambda: Build.cleanup_builds(
                        self.kube_client,
                        self.build_namespace,
                        self.build_max_age,
                    )))
            except Exception:
                app_log.exception("Failed to cleanup build pods")
            await asyncio.sleep(self.build_cleanup_interval)

    def start(self, run_loop=True):
        self.log.info("BinderHub starting on port %i", self.port)
        self.http_server = HTTPServer(
            self.tornado_app,
            xheaders=True,
        )
        self.http_server.listen(self.port)
        if self.builder_required:
            asyncio.ensure_future(self.watch_build_pods())
        if run_loop:
            tornado.ioloop.IOLoop.current().start()
class EnterpriseGatewayConfigMixin(Configurable):
    """Configurable traits shared by Enterprise Gateway applications.

    Each trait can be set through normal traitlets configuration or via an
    ``EG_*`` environment variable.  Every ``@default`` method also consults the
    corresponding legacy ``KG_*`` (Kernel Gateway) variable as a fallback for
    backward compatibility, and only then the hard-coded class default.
    """

    # Server IP / PORT binding
    port_env = 'EG_PORT'
    port_default_value = 8888
    port = Integer(port_default_value, config=True,
                   help='Port on which to listen (EG_PORT env var)')

    @default('port')
    def port_default(self):
        # Precedence: EG_PORT, then legacy KG_PORT, then the class default.
        return int(os.getenv(self.port_env, os.getenv('KG_PORT', self.port_default_value)))

    port_retries_env = 'EG_PORT_RETRIES'
    port_retries_default_value = 50
    port_retries = Integer(port_retries_default_value, config=True,
                           help="""Number of ports to try if the specified port is not available
                           (EG_PORT_RETRIES env var)""")

    @default('port_retries')
    def port_retries_default(self):
        return int(os.getenv(self.port_retries_env,
                             os.getenv('KG_PORT_RETRIES', self.port_retries_default_value)))

    ip_env = 'EG_IP'
    ip_default_value = '127.0.0.1'
    ip = Unicode(ip_default_value, config=True,
                 help='IP address on which to listen (EG_IP env var)')

    @default('ip')
    def ip_default(self):
        return os.getenv(self.ip_env, os.getenv('KG_IP', self.ip_default_value))

    # Base URL
    base_url_env = 'EG_BASE_URL'
    base_url_default_value = '/'
    base_url = Unicode(base_url_default_value, config=True,
                       help='The base path for mounting all API resources (EG_BASE_URL env var)')

    @default('base_url')
    def base_url_default(self):
        return os.getenv(self.base_url_env, os.getenv('KG_BASE_URL', self.base_url_default_value))

    # Token authorization
    auth_token_env = 'EG_AUTH_TOKEN'
    auth_token = Unicode(config=True,
                         help='Authorization token required for all requests (EG_AUTH_TOKEN env var)')

    @default('auth_token')
    def _auth_token_default(self):
        return os.getenv(self.auth_token_env, os.getenv('KG_AUTH_TOKEN', ''))

    # Begin CORS headers
    allow_credentials_env = 'EG_ALLOW_CREDENTIALS'
    allow_credentials = Unicode(config=True,
                                help='Sets the Access-Control-Allow-Credentials header. (EG_ALLOW_CREDENTIALS env var)')

    @default('allow_credentials')
    def allow_credentials_default(self):
        return os.getenv(self.allow_credentials_env, os.getenv('KG_ALLOW_CREDENTIALS', ''))

    allow_headers_env = 'EG_ALLOW_HEADERS'
    allow_headers = Unicode(config=True,
                            help='Sets the Access-Control-Allow-Headers header. (EG_ALLOW_HEADERS env var)')

    @default('allow_headers')
    def allow_headers_default(self):
        return os.getenv(self.allow_headers_env, os.getenv('KG_ALLOW_HEADERS', ''))

    allow_methods_env = 'EG_ALLOW_METHODS'
    allow_methods = Unicode(config=True,
                            help='Sets the Access-Control-Allow-Methods header. (EG_ALLOW_METHODS env var)')

    @default('allow_methods')
    def allow_methods_default(self):
        return os.getenv(self.allow_methods_env, os.getenv('KG_ALLOW_METHODS', ''))

    allow_origin_env = 'EG_ALLOW_ORIGIN'
    allow_origin = Unicode(config=True,
                           help='Sets the Access-Control-Allow-Origin header. (EG_ALLOW_ORIGIN env var)')

    @default('allow_origin')
    def allow_origin_default(self):
        return os.getenv(self.allow_origin_env, os.getenv('KG_ALLOW_ORIGIN', ''))

    expose_headers_env = 'EG_EXPOSE_HEADERS'
    expose_headers = Unicode(config=True,
                             help='Sets the Access-Control-Expose-Headers header. (EG_EXPOSE_HEADERS env var)')

    @default('expose_headers')
    def expose_headers_default(self):
        return os.getenv(self.expose_headers_env, os.getenv('KG_EXPOSE_HEADERS', ''))

    trust_xheaders_env = 'EG_TRUST_XHEADERS'
    trust_xheaders = CBool(False, config=True,
                           help="""Use x-* header values for overriding the remote-ip, useful when
                           application is behing a proxy. (EG_TRUST_XHEADERS env var)""")

    @default('trust_xheaders')
    def trust_xheaders_default(self):
        # strtobool accepts 'true'/'false'/'1'/'0'/etc. (case-insensitive).
        return strtobool(os.getenv(self.trust_xheaders_env, os.getenv('KG_TRUST_XHEADERS', 'False')))

    certfile_env = 'EG_CERTFILE'
    certfile = Unicode(None, config=True, allow_none=True,
                       help='The full path to an SSL/TLS certificate file. (EG_CERTFILE env var)')

    @default('certfile')
    def certfile_default(self):
        # No final fallback value: remains None when neither env var is set.
        return os.getenv(self.certfile_env, os.getenv('KG_CERTFILE'))

    keyfile_env = 'EG_KEYFILE'
    keyfile = Unicode(None, config=True, allow_none=True,
                      help='The full path to a private key file for usage with SSL/TLS. (EG_KEYFILE env var)')

    @default('keyfile')
    def keyfile_default(self):
        return os.getenv(self.keyfile_env, os.getenv('KG_KEYFILE'))

    client_ca_env = 'EG_CLIENT_CA'
    client_ca = Unicode(None, config=True, allow_none=True,
                        help="""The full path to a certificate authority certificate for SSL/TLS
                        client authentication. (EG_CLIENT_CA env var)""")

    @default('client_ca')
    def client_ca_default(self):
        return os.getenv(self.client_ca_env, os.getenv('KG_CLIENT_CA'))

    ssl_version_env = 'EG_SSL_VERSION'
    # NOTE(review): ssl_version_default_value is declared but never referenced by
    # ssl_version_default below — the default stays None unless an env var is set.
    # Confirm whether the TLSv1.2 constant is consumed elsewhere.
    ssl_version_default_value = ssl.PROTOCOL_TLSv1_2
    ssl_version = Integer(None, config=True, allow_none=True,
                          help="""Sets the SSL version to use for the web socket
                          connection. (EG_SSL_VERSION env var)""")

    @default('ssl_version')
    def ssl_version_default(self):
        ssl_from_env = os.getenv(self.ssl_version_env, os.getenv('KG_SSL_VERSION'))
        return ssl_from_env if ssl_from_env is None else int(ssl_from_env)

    max_age_env = 'EG_MAX_AGE'
    max_age = Unicode(config=True,
                      help='Sets the Access-Control-Max-Age header. (EG_MAX_AGE env var)')

    @default('max_age')
    def max_age_default(self):
        return os.getenv(self.max_age_env, os.getenv('KG_MAX_AGE', ''))
    # End CORS headers

    max_kernels_env = 'EG_MAX_KERNELS'
    max_kernels = Integer(None, config=True, allow_none=True,
                          help="""Limits the number of kernel instances allowed to run by this gateway.
                          Unbounded by default. (EG_MAX_KERNELS env var)""")

    @default('max_kernels')
    def max_kernels_default(self):
        val = os.getenv(self.max_kernels_env, os.getenv('KG_MAX_KERNELS'))
        return val if val is None else int(val)

    default_kernel_name_env = 'EG_DEFAULT_KERNEL_NAME'
    default_kernel_name = Unicode(config=True,
                                  help='Default kernel name when spawning a kernel (EG_DEFAULT_KERNEL_NAME env var)')

    @default('default_kernel_name')
    def default_kernel_name_default(self):
        # defaults to Jupyter's default kernel name on empty string
        return os.getenv(self.default_kernel_name_env, os.getenv('KG_DEFAULT_KERNEL_NAME', ''))

    list_kernels_env = 'EG_LIST_KERNELS'
    list_kernels = Bool(config=True,
                        help="""Permits listing of the running kernels using API endpoints /api/kernels
                        and /api/sessions. (EG_LIST_KERNELS env var) Note: Jupyter Notebook
                        allows this by default but Jupyter Enterprise Gateway does not.""")

    @default('list_kernels')
    def list_kernels_default(self):
        # Only the literal string 'true' (any case) enables listing.
        return os.getenv(self.list_kernels_env, os.getenv('KG_LIST_KERNELS', 'False')).lower() == 'true'

    env_whitelist_env = 'EG_ENV_WHITELIST'
    env_whitelist = List(config=True,
                         help="""Environment variables allowed to be set when a client requests a
                         new kernel. (EG_ENV_WHITELIST env var)""")

    @default('env_whitelist')
    def env_whitelist_default(self):
        # NOTE(review): ''.split(',') yields [''] when the env var is unset —
        # confirm downstream treats an empty-string entry as "nothing allowed".
        return os.getenv(self.env_whitelist_env, os.getenv('KG_ENV_WHITELIST', '')).split(',')

    env_process_whitelist_env = 'EG_ENV_PROCESS_WHITELIST'
    env_process_whitelist = List(config=True,
                                 help="""Environment variables allowed to be inherited
                                 from the spawning process by the kernel. (EG_ENV_PROCESS_WHITELIST env var)""")

    @default('env_process_whitelist')
    def env_process_whitelist_default(self):
        return os.getenv(self.env_process_whitelist_env, os.getenv('KG_ENV_PROCESS_WHITELIST', '')).split(',')

    kernel_headers_env = 'EG_KERNEL_HEADERS'
    kernel_headers = List(config=True,
                          help="""Request headers to make available to kernel launch framework.
                          (EG_KERNEL_HEADERS env var)""")

    @default('kernel_headers')
    def kernel_headers_default(self):
        # Unlike the whitelists above, this returns [] (not ['']) when unset.
        default_headers = os.getenv(self.kernel_headers_env)
        return default_headers.split(',') if default_headers else []

    # Remote hosts
    remote_hosts_env = 'EG_REMOTE_HOSTS'
    remote_hosts_default_value = 'localhost'
    remote_hosts = List(default_value=[remote_hosts_default_value], config=True,
                        help="""Bracketed comma-separated list of hosts on which DistributedProcessProxy
                        kernels will be launched e.g., ['host1','host2'].
                        (EG_REMOTE_HOSTS env var - non-bracketed, just comma-separated)""")

    @default('remote_hosts')
    def remote_hosts_default(self):
        return os.getenv(self.remote_hosts_env, self.remote_hosts_default_value).split(',')

    # Yarn endpoint
    yarn_endpoint_env = 'EG_YARN_ENDPOINT'
    yarn_endpoint = Unicode(None, config=True, allow_none=True,
                            help="""The http url specifying the YARN Resource Manager. Note: If this value is NOT set,
                            the YARN library will use the files within the local HADOOP_CONFIG_DIR to determine the
                            active resource manager. (EG_YARN_ENDPOINT env var)""")

    @default('yarn_endpoint')
    def yarn_endpoint_default(self):
        # No KG_ fallback for YARN settings; these did not exist in Kernel Gateway.
        return os.getenv(self.yarn_endpoint_env)

    # Alt Yarn endpoint
    alt_yarn_endpoint_env = 'EG_ALT_YARN_ENDPOINT'
    alt_yarn_endpoint = Unicode(None, config=True, allow_none=True,
                                help="""The http url specifying the alternate YARN Resource Manager.  This value should
                                be set when YARN Resource Managers are configured for high availability.  Note: If both
                                YARN endpoints are NOT set, the YARN library will use the files within the local
                                HADOOP_CONFIG_DIR to determine the active resource manager.
                                (EG_ALT_YARN_ENDPOINT env var)""")

    @default('alt_yarn_endpoint')
    def alt_yarn_endpoint_default(self):
        return os.getenv(self.alt_yarn_endpoint_env)

    yarn_endpoint_security_enabled_env = 'EG_YARN_ENDPOINT_SECURITY_ENABLED'
    yarn_endpoint_security_enabled_default_value = False
    yarn_endpoint_security_enabled = Bool(yarn_endpoint_security_enabled_default_value, config=True,
                                          help="""Is YARN Kerberos/SPNEGO Security enabled (True/False).
                                          (EG_YARN_ENDPOINT_SECURITY_ENABLED env var)""")

    @default('yarn_endpoint_security_enabled')
    def yarn_endpoint_security_enabled_default(self):
        # NOTE(review): bool() on a non-empty env string is always True, even
        # for the string 'False' — confirm this matches intended semantics.
        return bool(os.getenv(self.yarn_endpoint_security_enabled_env,
                              self.yarn_endpoint_security_enabled_default_value))

    # Conductor endpoint
    conductor_endpoint_env = 'EG_CONDUCTOR_ENDPOINT'
    conductor_endpoint_default_value = None
    conductor_endpoint = Unicode(conductor_endpoint_default_value, allow_none=True, config=True,
                                 help="""The http url for accessing the Conductor REST API.
                                 (EG_CONDUCTOR_ENDPOINT env var)""")

    @default('conductor_endpoint')
    def conductor_endpoint_default(self):
        return os.getenv(self.conductor_endpoint_env, self.conductor_endpoint_default_value)

    _log_formatter_cls = LogFormatter  # traitlet default is LevelFormatter

    @default('log_format')
    def _default_log_format(self):
        """override default log format to include milliseconds"""
        return u"%(color)s[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s]%(end_color)s %(message)s"

    # Impersonation enabled
    impersonation_enabled_env = 'EG_IMPERSONATION_ENABLED'
    impersonation_enabled = Bool(False, config=True,
                                 help="""Indicates whether impersonation will be performed during kernel launch.
                                 (EG_IMPERSONATION_ENABLED env var)""")

    @default('impersonation_enabled')
    def impersonation_enabled_default(self):
        return bool(os.getenv(self.impersonation_enabled_env, 'false').lower() == 'true')

    # Unauthorized users
    unauthorized_users_env = 'EG_UNAUTHORIZED_USERS'
    unauthorized_users_default_value = 'root'
    unauthorized_users = Set(default_value={unauthorized_users_default_value}, config=True,
                             help="""Comma-separated list of user names (e.g., ['root','admin']) against which
                             KERNEL_USERNAME will be compared.  Any match (case-sensitive) will prevent the
                             kernel's launch and result in an HTTP 403 (Forbidden) error.
                             (EG_UNAUTHORIZED_USERS env var - non-bracketed, just comma-separated)""")

    @default('unauthorized_users')
    def unauthorized_users_default(self):
        # split() returns a list; the Set trait coerces it on assignment.
        return os.getenv(self.unauthorized_users_env, self.unauthorized_users_default_value).split(',')

    # Authorized users
    authorized_users_env = 'EG_AUTHORIZED_USERS'
    authorized_users = Set(config=True,
                           help="""Comma-separated list of user names (e.g., ['bob','alice']) against which
                           KERNEL_USERNAME will be compared.  Any match (case-sensitive) will allow the kernel's
                           launch, otherwise an HTTP 403 (Forbidden) error will be raised.  The set of unauthorized
                           users takes precedence. This option should be used carefully as it can dramatically limit
                           who can launch kernels.  (EG_AUTHORIZED_USERS env var - non-bracketed,
                           just comma-separated)""")

    @default('authorized_users')
    def authorized_users_default(self):
        au_env = os.getenv(self.authorized_users_env)
        return au_env.split(',') if au_env is not None else []

    # Port range
    port_range_env = 'EG_PORT_RANGE'
    port_range_default_value = "0..0"
    port_range = Unicode(port_range_default_value, config=True,
                         help="""Specifies the lower and upper port numbers from which ports are created.
                         The bounded values are separated by '..' (e.g., 33245..34245 specifies a range of 1000 ports
                         to be randomly selected). A range of zero (e.g., 33245..33245 or 0..0) disables port-range
                         enforcement.  (EG_PORT_RANGE env var)""")

    @default('port_range')
    def port_range_default(self):
        return os.getenv(self.port_range_env, self.port_range_default_value)

    # Max Kernels per User
    max_kernels_per_user_env = 'EG_MAX_KERNELS_PER_USER'
    max_kernels_per_user_default_value = -1
    max_kernels_per_user = Integer(max_kernels_per_user_default_value, config=True,
                                   help="""Specifies the maximum number of kernels a user can have active
                                   simultaneously.  A value of -1 disables enforcement.
                                   (EG_MAX_KERNELS_PER_USER env var)""")

    @default('max_kernels_per_user')
    def max_kernels_per_user_default(self):
        return int(os.getenv(self.max_kernels_per_user_env, self.max_kernels_per_user_default_value))

    ws_ping_interval_env = 'EG_WS_PING_INTERVAL_SECS'
    ws_ping_interval_default_value = 30
    ws_ping_interval = Integer(ws_ping_interval_default_value, config=True,
                               help="""Specifies the ping interval(in seconds) that should be used by zmq port
                               associated withspawned kernels.Set this variable to 0 to disable ping mechanism.
                               (EG_WS_PING_INTERVAL_SECS env var)""")

    @default('ws_ping_interval')
    def ws_ping_interval_default(self):
        return int(os.getenv(self.ws_ping_interval_env, self.ws_ping_interval_default_value))

    # Dynamic Update Interval
    dynamic_config_interval_env = 'EG_DYNAMIC_CONFIG_INTERVAL'
    dynamic_config_interval_default_value = 0
    dynamic_config_interval = Integer(dynamic_config_interval_default_value, min=0, config=True,
                                      help="""Specifies the number of seconds configuration files are polled for
                                      changes.  A value of 0 or less disables dynamic config updates.
                                      (EG_DYNAMIC_CONFIG_INTERVAL env var)""")

    @default('dynamic_config_interval')
    def dynamic_config_interval_default(self):
        return int(os.getenv(self.dynamic_config_interval_env, self.dynamic_config_interval_default_value))

    @observe('dynamic_config_interval')
    def dynamic_config_interval_changed(self, event):
        """React to a change in the polling interval: stop any active poller
        and, when the change is a live reconfiguration, restart it.
        """
        prev_val = event['old']
        self.dynamic_config_interval = event['new']
        if self.dynamic_config_interval != prev_val:
            # Values are different.  Stop the current poller.  If new value is > 0, start a poller.
            if self.dynamic_config_poller:
                self.dynamic_config_poller.stop()
                self.dynamic_config_poller = None

            if self.dynamic_config_interval <= 0:
                self.log.warning("Dynamic configuration updates have been disabled and cannot be re-enabled "
                                 "without restarting Enterprise Gateway!")
            # The interval has been changed, but still positive
            elif prev_val > 0 and hasattr(self, "init_dynamic_configs"):
                self.init_dynamic_configs()  # Restart the poller

    # Active poller instance (set by init_dynamic_configs when polling is enabled).
    dynamic_config_poller = None

    # Lazily-populated kernel spec manager instance.
    kernel_spec_manager = Instance("jupyter_client.kernelspec.KernelSpecManager", allow_none=True)

    kernel_spec_manager_class = Type(
        default_value="jupyter_client.kernelspec.KernelSpecManager",
        config=True,
        help="""
        The kernel spec manager class to use. Must be a subclass
        of `jupyter_client.kernelspec.KernelSpecManager`.
        """
    )

    kernel_manager_class = Type(
        klass="enterprise_gateway.services.kernels.remotemanager.RemoteMappingKernelManager",
        default_value="enterprise_gateway.services.kernels.remotemanager.RemoteMappingKernelManager",
        config=True,
        help="""
        The kernel manager class to use. Must be a subclass
        of `enterprise_gateway.services.kernels.RemoteMappingKernelManager`.
        """
    )

    kernel_session_manager_class = Type(
        klass="enterprise_gateway.services.sessions.kernelsessionmanager.KernelSessionManager",
        default_value="enterprise_gateway.services.sessions.kernelsessionmanager.FileKernelSessionManager",
        config=True,
        help="""
        The kernel session manager class to use. Must be a subclass
        of `enterprise_gateway.services.sessions.KernelSessionManager`.
        """
    )
class View(HasTraits):
    """Base View class for more convenient apply(f,*args,**kwargs) syntax via attributes.

    Don't use this class, use subclasses.

    Methods
    -------
    spin
        flushes incoming results and registration state changes
        control methods spin, and requesting `ids` also ensures up to date
    wait
        wait on one or more msg_ids
    execution methods
        apply
        legacy: execute, run
    data movement
        push, pull, scatter, gather
    query methods
        get_result, queue_status, purge_results, result_status
    control methods
        abort, shutdown
    """
    # flags
    block = Bool(False)      # wait for results when True
    track = Bool(False)      # create MessageTrackers for non-copying sends
    targets = Any()          # engine id(s) this view addresses

    history = List()         # msg_ids of submissions, in order
    outstanding = Set()      # msg_ids submitted via this view and not yet done
    results = Dict()         # shared with self.client.results (see __init__)
    client = Instance('ipyparallel.Client', allow_none=True)

    _socket = Instance('zmq.Socket', allow_none=True)
    _flag_names = List(['targets', 'block', 'track'])
    _in_sync_results = Bool(False)
    _targets = Any()
    _idents = Any()

    def __init__(self, client=None, socket=None, **flags):
        super(View, self).__init__(client=client, _socket=socket)
        # share the client's results dict so completed results are visible here
        self.results = client.results
        self.block = client.block
        self.executor = ViewExecutor(self)

        self.set_flags(**flags)

        assert not self.__class__ is View, "Don't use base View objects, use subclasses"

    def __repr__(self):
        strtargets = str(self.targets)
        if len(strtargets) > 16:
            strtargets = strtargets[:12] + '...]'
        return "<%s %s>" % (self.__class__.__name__, strtargets)

    def __len__(self):
        # number of engines this view addresses
        if isinstance(self.targets, list):
            return len(self.targets)
        elif isinstance(self.targets, int):
            return 1
        else:
            return len(self.client)

    def set_flags(self, **kwargs):
        """set my attribute flags by keyword.

        Views determine behavior with a few attributes (`block`, `track`, etc.).
        These attributes can be set all at once by name with this method.

        Parameters
        ----------
        block : bool
            whether to wait for results
        track : bool
            whether to create a MessageTracker to allow the user to
            safely edit after arrays and buffers during non-copying
            sends.

        Raises
        ------
        KeyError
            if a keyword is not one of the names in `_flag_names`.
        """
        for name, value in iteritems(kwargs):
            if name not in self._flag_names:
                raise KeyError("Invalid name: %r" % name)
            else:
                setattr(self, name, value)

    @contextmanager
    def temp_flags(self, **kwargs):
        """temporarily set flags, for use in `with` statements.

        See set_flags for permanent setting of flags

        Examples
        --------
        >>> view.track=False
        ...
        >>> with view.temp_flags(track=True):
        ...    ar = view.apply(dostuff, my_big_array)
        ...    ar.tracker.wait() # wait for send to finish
        >>> view.track
        False
        """
        # preflight: save flags, and set temporaries
        saved_flags = {}
        for f in self._flag_names:
            saved_flags[f] = getattr(self, f)
        self.set_flags(**kwargs)
        # yield to the with-statement block
        try:
            yield
        finally:
            # postflight: restore saved flags
            self.set_flags(**saved_flags)

    #----------------------------------------------------------------
    # apply
    #----------------------------------------------------------------

    def _sync_results(self):
        """to be called by @sync_results decorator

        after submitting any tasks.
        """
        # drop from `outstanding` any msg_ids the client no longer considers outstanding
        delta = self.outstanding.difference(self.client.outstanding)
        completed = self.outstanding.intersection(delta)
        self.outstanding = self.outstanding.difference(completed)

    @sync_results
    @save_ids
    def _really_apply(self, f, args, kwargs, block=None, **options):
        """wrapper for client.send_apply_request"""
        raise NotImplementedError("Implement in subclasses")

    def apply(self, f, *args, **kwargs):
        """calls ``f(*args, **kwargs)`` on remote engines, returning the result.

        This method sets all apply flags via this View's attributes.

        Returns :class:`~ipyparallel.client.asyncresult.AsyncResult`
        instance if ``self.block`` is False, otherwise the return value of
        ``f(*args, **kwargs)``.
        """
        return self._really_apply(f, args, kwargs)

    def apply_async(self, f, *args, **kwargs):
        """calls ``f(*args, **kwargs)`` on remote engines in a nonblocking manner.

        Returns :class:`~ipyparallel.client.asyncresult.AsyncResult` instance.
        """
        return self._really_apply(f, args, kwargs, block=False)

    def apply_sync(self, f, *args, **kwargs):
        """calls ``f(*args, **kwargs)`` on remote engines in a blocking manner,
        returning the result.
        """
        return self._really_apply(f, args, kwargs, block=True)

    #----------------------------------------------------------------
    # wrappers for client and control methods
    #----------------------------------------------------------------

    @sync_results
    def spin(self):
        """spin the client, and sync"""
        self.client.spin()

    @sync_results
    def wait(self, jobs=None, timeout=-1):
        """waits on one or more `jobs`, for up to `timeout` seconds.

        Parameters
        ----------
        jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
            ints are indices to self.history
            strs are msg_ids
            default: wait on all outstanding messages
        timeout : float
            a time in seconds, after which to give up.
            default is -1, which means no timeout

        Returns
        -------
        True : when all msg_ids are done
        False : timeout reached, some msg_ids still outstanding
        """
        if jobs is None:
            jobs = self.history
        return self.client.wait(jobs, timeout)

    def abort(self, jobs=None, targets=None, block=None):
        """Abort jobs on my engines.

        Parameters
        ----------
        jobs : None, str, list of strs, optional
            if None: abort all jobs.
            else: abort specific msg_id(s).
        """
        # fall back to the view's own flags/targets when not given explicitly
        block = block if block is not None else self.block
        targets = targets if targets is not None else self.targets
        jobs = jobs if jobs is not None else list(self.outstanding)

        return self.client.abort(jobs=jobs, targets=targets, block=block)

    def queue_status(self, targets=None, verbose=False):
        """Fetch the Queue status of my engines"""
        targets = targets if targets is not None else self.targets
        return self.client.queue_status(targets=targets, verbose=verbose)

    def purge_results(self, jobs=[], targets=[]):
        """Instruct the controller to forget specific results."""
        # NOTE(review): mutable default args ([]); harmless here since they are
        # never mutated, but passing None/'all' substitutes the view's targets.
        if targets is None or targets == 'all':
            targets = self.targets
        return self.client.purge_results(jobs=jobs, targets=targets)

    def shutdown(self, targets=None, restart=False, hub=False, block=None):
        """Terminates one or more engine processes, optionally including the hub.
        """
        block = self.block if block is None else block
        if targets is None or targets == 'all':
            targets = self.targets
        return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)

    def get_result(self, indices_or_msg_ids=None, block=None, owner=False):
        """return one or more results, specified by history index or msg_id.

        See :meth:`ipyparallel.client.client.Client.get_result` for details.
        """
        if indices_or_msg_ids is None:
            # default: most recent submission
            indices_or_msg_ids = -1
        # translate history indices into msg_ids before delegating
        if isinstance(indices_or_msg_ids, int):
            indices_or_msg_ids = self.history[indices_or_msg_ids]
        elif isinstance(indices_or_msg_ids, (list, tuple, set)):
            indices_or_msg_ids = list(indices_or_msg_ids)
            for i, index in enumerate(indices_or_msg_ids):
                if isinstance(index, int):
                    indices_or_msg_ids[i] = self.history[index]
        return self.client.get_result(indices_or_msg_ids, block=block, owner=owner)

    #-------------------------------------------------------------------
    # Map
    #-------------------------------------------------------------------

    @sync_results
    def map(self, f, *sequences, **kwargs):
        """override in subclasses"""
        raise NotImplementedError

    def map_async(self, f, *sequences, **kwargs):
        """Parallel version of builtin :func:`python:map`, using this view's engines.

        This is equivalent to ``map(...block=False)``.

        See `self.map` for details.
        """
        if 'block' in kwargs:
            raise TypeError("map_async doesn't take a `block` keyword argument.")
        kwargs['block'] = False
        return self.map(f, *sequences, **kwargs)

    def map_sync(self, f, *sequences, **kwargs):
        """Parallel version of builtin :func:`python:map`, using this view's engines.

        This is equivalent to ``map(...block=True)``.

        See `self.map` for details.
        """
        if 'block' in kwargs:
            raise TypeError("map_sync doesn't take a `block` keyword argument.")
        kwargs['block'] = True
        return self.map(f, *sequences, **kwargs)

    def imap(self, f, *sequences, **kwargs):
        """Parallel version of :func:`itertools.imap`.

        See `self.map` for details.
        """
        return iter(self.map_async(f, *sequences, **kwargs))

    #-------------------------------------------------------------------
    # Decorators
    #-------------------------------------------------------------------

    def remote(self, block=None, **flags):
        """Decorator for making a RemoteFunction"""
        block = self.block if block is None else block
        return remote(self, block=block, **flags)

    def parallel(self, dist='b', block=None, **flags):
        """Decorator for making a ParallelFunction"""
        block = self.block if block is None else block
        return parallel(self, dist=dist, block=block, **flags)
class Kernel(SingletonConfigurable): #--------------------------------------------------------------------------- # Kernel interface #--------------------------------------------------------------------------- # attribute to override with a GUI eventloop = Any(None) def _eventloop_changed(self, name, old, new): """schedule call to eventloop from IOLoop""" loop = ioloop.IOLoop.instance() loop.add_callback(self.enter_eventloop) session = Instance(Session, allow_none=True) profile_dir = Instance('IPython.core.profiledir.ProfileDir', allow_none=True) shell_streams = List() control_stream = Instance(ZMQStream, allow_none=True) iopub_socket = Instance(zmq.Socket, allow_none=True) stdin_socket = Instance(zmq.Socket, allow_none=True) log = Instance(logging.Logger, allow_none=True) # identities: int_id = Integer(-1) ident = Unicode() def _ident_default(self): return unicode_type(uuid.uuid4()) # This should be overridden by wrapper kernels that implement any real # language. language_info = {} # any links that should go in the help menu help_links = List() # Private interface _darwin_app_nap = Bool(True, config=True, help="""Whether to use appnope for compatiblity with OS X App Nap. Only affects OS X >= 10.9. """ ) # track associations with current request _allow_stdin = Bool(False) _parent_header = Dict() _parent_ident = Any(b'') # Time to sleep after flushing the stdout/err buffers in each execute # cycle. While this introduces a hard limit on the minimal latency of the # execute cycle, it helps prevent output synchronization problems for # clients. # Units are in seconds. The minimum zmq latency on local host is probably # ~150 microseconds, set this to 500us for now. We may need to increase it # a little if it's not enough after more interactive testing. _execute_sleep = Float(0.0005, config=True) # Frequency of the kernel's event loop. # Units are in seconds, kernel subclasses for GUI toolkits may need to # adapt to milliseconds. 
_poll_interval = Float(0.05, config=True) # If the shutdown was requested over the network, we leave here the # necessary reply message so it can be sent by our registered atexit # handler. This ensures that the reply is only sent to clients truly at # the end of our shutdown process (which happens after the underlying # IPython shell's own shutdown). _shutdown_message = None # This is a dict of port number that the kernel is listening on. It is set # by record_ports and used by connect_request. _recorded_ports = Dict() # set of aborted msg_ids aborted = Set() # Track execution count here. For IPython, we override this to use the # execution count we store in the shell. execution_count = 0 def __init__(self, **kwargs): super(Kernel, self).__init__(**kwargs) # Build dict of handlers for message types msg_types = [ 'execute_request', 'complete_request', 'inspect_request', 'history_request', 'comm_info_request', 'kernel_info_request', 'connect_request', 'shutdown_request', 'apply_request', 'is_complete_request', ] self.shell_handlers = {} for msg_type in msg_types: self.shell_handlers[msg_type] = getattr(self, msg_type) control_msg_types = msg_types + [ 'clear_request', 'abort_request' ] self.control_handlers = {} for msg_type in control_msg_types: self.control_handlers[msg_type] = getattr(self, msg_type) def dispatch_control(self, msg): """dispatch control requests""" idents,msg = self.session.feed_identities(msg, copy=False) try: msg = self.session.deserialize(msg, content=True, copy=False) except: self.log.error("Invalid Control Message", exc_info=True) return self.log.debug("Control received: %s", msg) # Set the parent message for side effects. 
self.set_parent(idents, msg) self._publish_status(u'busy') header = msg['header'] msg_type = header['msg_type'] handler = self.control_handlers.get(msg_type, None) if handler is None: self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r", msg_type) else: try: handler(self.control_stream, idents, msg) except Exception: self.log.error("Exception in control handler:", exc_info=True) sys.stdout.flush() sys.stderr.flush() self._publish_status(u'idle') def dispatch_shell(self, stream, msg): """dispatch shell requests""" # flush control requests first if self.control_stream: self.control_stream.flush() idents,msg = self.session.feed_identities(msg, copy=False) try: msg = self.session.deserialize(msg, content=True, copy=False) except: self.log.error("Invalid Message", exc_info=True) return # Set the parent message for side effects. self.set_parent(idents, msg) self._publish_status(u'busy') header = msg['header'] msg_id = header['msg_id'] msg_type = msg['header']['msg_type'] # Print some info about this message and leave a '--->' marker, so it's # easier to trace visually the message chain when debugging. Each # handler prints its message at the end. self.log.debug('\n*** MESSAGE TYPE:%s***', msg_type) self.log.debug(' Content: %s\n --->\n ', msg['content']) if msg_id in self.aborted: self.aborted.remove(msg_id) # is it safe to assume a msg_id will not be resubmitted? 
            # NOTE(review): tail of a dispatch method whose head is outside this
            # view — presumably the "message was aborted" branch: send an
            # '<type>_reply' with status 'aborted' and stop. Confirm against the
            # full dispatch_shell body.
            reply_type = msg_type.split('_')[0] + '_reply'
            status = {'status' : 'aborted'}
            md = {'engine' : self.ident}
            md.update(status)
            self.session.send(stream, reply_type, metadata=md,
                        content=status, parent=msg, ident=idents)
            return

        # Route the message to its registered handler, if any.
        handler = self.shell_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("UNKNOWN MESSAGE TYPE: %r", msg_type)
        else:
            self.log.debug("%s: %s", msg_type, msg)
            self.pre_handler_hook()
            try:
                handler(stream, idents, msg)
            except Exception:
                # Handler errors must not kill the dispatch loop; log and go on.
                self.log.error("Exception in message handler:", exc_info=True)
            finally:
                self.post_handler_hook()

        # Flush output and signal idle after every request, success or failure.
        sys.stdout.flush()
        sys.stderr.flush()
        self._publish_status(u'idle')

    def pre_handler_hook(self):
        """Hook to execute before calling message handler"""
        # ensure default_int_handler during handler call
        self.saved_sigint_handler = signal(SIGINT, default_int_handler)

    def post_handler_hook(self):
        """Hook to execute after calling message handler"""
        # restore whatever SIGINT handler was active before the handler ran
        signal(SIGINT, self.saved_sigint_handler)

    def enter_eventloop(self):
        """Enter the configured GUI eventloop, re-entering on Ctrl-C."""
        self.log.info("entering eventloop %s", self.eventloop)
        for stream in self.shell_streams:
            # flush any pending replies,
            # which may be skipped by entering the eventloop
            stream.flush(zmq.POLLOUT)
        # restore default_int_handler
        signal(SIGINT, default_int_handler)
        while self.eventloop is not None:
            try:
                self.eventloop(self)
            except KeyboardInterrupt:
                # Ctrl-C shouldn't crash the kernel
                self.log.error("KeyboardInterrupt caught in kernel")
                continue
            else:
                # eventloop exited cleanly, this means we should stop (right?)
                self.eventloop = None
                break
        self.log.info("exiting eventloop")

    def start(self):
        """register dispatchers for streams"""
        if self.control_stream:
            self.control_stream.on_recv(self.dispatch_control, copy=False)

        def make_dispatcher(stream):
            # closure factory so each shell stream dispatches with itself bound
            def dispatcher(msg):
                return self.dispatch_shell(stream, msg)
            return dispatcher

        for s in self.shell_streams:
            s.on_recv(make_dispatcher(s), copy=False)

        # publish idle status
        self._publish_status('starting')

    def do_one_iteration(self):
        """step eventloop just once"""
        if self.control_stream:
            self.control_stream.flush()
        for stream in self.shell_streams:
            # handle at most one request per iteration
            stream.flush(zmq.POLLIN, 1)
            stream.flush(zmq.POLLOUT)

    def record_ports(self, ports):
        """Record the ports that this kernel is using.

        The creator of the Kernel instance must call this methods if they
        want the :meth:`connect_request` method to return the port numbers.
        """
        self._recorded_ports = ports

    #---------------------------------------------------------------------------
    # Kernel request handlers
    #---------------------------------------------------------------------------

    def _make_metadata(self, other=None):
        """init metadata dict, for execute/apply_reply

        Parameters
        ----------
        other : dict, optional
            extra entries merged over the defaults.
        """
        new_md = {
            'dependencies_met' : True,
            'engine' : self.ident,
            'started': datetime.now(),
        }
        if other:
            new_md.update(other)
        return new_md

    def _publish_execute_input(self, code, parent, execution_count):
        """Publish the code request on the iopub stream."""
        self.session.send(self.iopub_socket, u'execute_input',
                          {u'code':code, u'execution_count': execution_count},
                          parent=parent, ident=self._topic('execute_input')
        )

    def _publish_status(self, status, parent=None):
        """send status (busy/idle) on IOPub"""
        self.session.send(self.iopub_socket,
                          u'status',
                          {u'execution_state': status},
                          parent=parent or self._parent_header,
                          ident=self._topic('status'),
                          )

    def set_parent(self, ident, parent):
        """Set the current parent_header

        Side effects (IOPub messages) and replies are associated with
        the request that caused them via the parent_header.

        The parent identity is used to route input_request messages
        on the stdin channel.
        """
        self._parent_ident = ident
        self._parent_header = parent

    def send_response(self, stream, msg_or_type, content=None, ident=None,
             buffers=None, track=False, header=None, metadata=None):
        """Send a response to the message we're currently processing.

        This accepts all the parameters of :meth:`jupyter_client.session.Session.send`
        except ``parent``.

        This relies on :meth:`set_parent` having been called for the current
        message.
        """
        return self.session.send(stream, msg_or_type, content, self._parent_header,
                                 ident, buffers, track, header, metadata)

    def execute_request(self, stream, ident, parent):
        """handle an execute_request"""
        try:
            content = parent[u'content']
            code = py3compat.cast_unicode_py2(content[u'code'])
            silent = content[u'silent']
            store_history = content.get(u'store_history', not silent)
            user_expressions = content.get('user_expressions', {})
            allow_stdin = content.get('allow_stdin', False)
        except:
            # malformed request: log and drop rather than crash the kernel
            self.log.error("Got bad msg: ")
            self.log.error("%s", parent)
            return

        stop_on_error = content.get('stop_on_error', True)

        md = self._make_metadata(parent['metadata'])

        # Re-broadcast our input for the benefit of listening clients, and
        # start computing output
        if not silent:
            self.execution_count += 1
            self._publish_execute_input(code, parent, self.execution_count)

        reply_content = self.do_execute(code, silent, store_history,
                                        user_expressions, allow_stdin)

        # Flush output before sending the reply.
        sys.stdout.flush()
        sys.stderr.flush()
        # FIXME: on rare occasions, the flush doesn't seem to make it to the
        # clients... This seems to mitigate the problem, but we definitely need
        # to better understand what's going on.
        if self._execute_sleep:
            time.sleep(self._execute_sleep)

        # Send the reply.
        reply_content = json_clean(reply_content)

        md['status'] = reply_content['status']
        if reply_content['status'] == 'error' and \
                reply_content['ename'] == 'UnmetDependency':
            md['dependencies_met'] = False

        reply_msg = self.session.send(stream, u'execute_reply',
                                      reply_content, parent, metadata=md,
                                      ident=ident)

        self.log.debug("%s", reply_msg)

        if not silent and reply_msg['content']['status'] == u'error' and stop_on_error:
            self._abort_queues()

    def do_execute(self, code, silent, store_history=True,
                   user_expressions=None, allow_stdin=False):
        """Execute user code. Must be overridden by subclasses.
        """
        raise NotImplementedError

    def complete_request(self, stream, ident, parent):
        """handle a complete_request: delegate to do_complete and reply"""
        content = parent['content']
        code = content['code']
        cursor_pos = content['cursor_pos']

        matches = self.do_complete(code, cursor_pos)
        matches = json_clean(matches)
        completion_msg = self.session.send(stream, 'complete_reply',
                                           matches, parent, ident)
        self.log.debug("%s", completion_msg)

    def do_complete(self, code, cursor_pos):
        """Override in subclasses to find completions.
        """
        return {'matches' : [],
                'cursor_end' : cursor_pos,
                'cursor_start' : cursor_pos,
                'metadata' : {},
                'status' : 'ok'}

    def inspect_request(self, stream, ident, parent):
        """handle an inspect_request: delegate to do_inspect and reply"""
        content = parent['content']

        reply_content = self.do_inspect(content['code'], content['cursor_pos'],
                                        content.get('detail_level', 0))
        # Before we send this object over, we scrub it for JSON usage
        reply_content = json_clean(reply_content)
        msg = self.session.send(stream, 'inspect_reply',
                                reply_content, parent, ident)
        self.log.debug("%s", msg)

    def do_inspect(self, code, cursor_pos, detail_level=0):
        """Override in subclasses to allow introspection.
        """
        return {'status': 'ok', 'data': {}, 'metadata': {}, 'found': False}

    def history_request(self, stream, ident, parent):
        """handle a history_request: delegate to do_history and reply"""
        content = parent['content']

        reply_content = self.do_history(**content)

        reply_content = json_clean(reply_content)
        msg = self.session.send(stream, 'history_reply',
                                reply_content, parent, ident)
        self.log.debug("%s", msg)

    def do_history(self, hist_access_type, output, raw, session=None, start=None,
                   stop=None, n=None, pattern=None, unique=False):
        """Override in subclasses to access history.
        """
        return {'history': []}

    def connect_request(self, stream, ident, parent):
        """reply with the port numbers recorded via record_ports (if any)"""
        if self._recorded_ports is not None:
            content = self._recorded_ports.copy()
        else:
            content = {}
        msg = self.session.send(stream, 'connect_reply',
                                content, parent, ident)
        self.log.debug("%s", msg)

    @property
    def kernel_info(self):
        # static description of this kernel for kernel_info_reply
        return {
            'protocol_version': kernel_protocol_version,
            'implementation': self.implementation,
            'implementation_version': self.implementation_version,
            'language_info': self.language_info,
            'banner': self.banner,
            'help_links': self.help_links,
        }

    def kernel_info_request(self, stream, ident, parent):
        """handle a kernel_info_request"""
        msg = self.session.send(stream, 'kernel_info_reply',
                                self.kernel_info, parent, ident)
        self.log.debug("%s", msg)

    def comm_info_request(self, stream, ident, parent):
        """reply with open comms, optionally filtered by target_name"""
        content = parent['content']
        target_name = content.get('target_name', None)

        # Should this be moved to ipkernel?
        if hasattr(self, 'comm_manager'):
            comms = {
                k: dict(target_name=v.target_name)
                for (k, v) in self.comm_manager.comms.items()
                if v.target_name == target_name or target_name is None
            }
        else:
            comms = {}
        reply_content = dict(comms=comms)
        msg = self.session.send(stream, 'comm_info_reply',
                                reply_content, parent, ident)
        self.log.debug("%s", msg)

    def shutdown_request(self, stream, ident, parent):
        """handle a shutdown_request: reply, broadcast on IOPub, stop the loop"""
        content = self.do_shutdown(parent['content']['restart'])
        self.session.send(stream, u'shutdown_reply', content, parent, ident=ident)
        # same content, but different msg_id for broadcasting on IOPub
        self._shutdown_message = self.session.msg(u'shutdown_reply',
                                                  content, parent
        )

        self._at_shutdown()
        # call sys.exit after a short delay
        loop = ioloop.IOLoop.instance()
        loop.add_timeout(time.time()+0.1, loop.stop)

    def do_shutdown(self, restart):
        """Override in subclasses to do things when the frontend shuts down the
        kernel.
        """
        return {'status': 'ok', 'restart': restart}

    def is_complete_request(self, stream, ident, parent):
        """handle an is_complete_request: delegate to do_is_complete and reply"""
        content = parent['content']
        code = content['code']

        reply_content = self.do_is_complete(code)
        reply_content = json_clean(reply_content)
        reply_msg = self.session.send(stream, 'is_complete_reply',
                                      reply_content, parent, ident)
        self.log.debug("%s", reply_msg)

    def do_is_complete(self, code):
        """Override in subclasses to find completions.
        """
        return {'status' : 'unknown',
                }

    #---------------------------------------------------------------------------
    # Engine methods
    #---------------------------------------------------------------------------

    def apply_request(self, stream, ident, parent):
        """handle an apply_request (IPython parallel): delegate to do_apply"""
        try:
            content = parent[u'content']
            bufs = parent[u'buffers']
            msg_id = parent['header']['msg_id']
        except:
            self.log.error("Got bad msg: %s", parent, exc_info=True)
            return

        md = self._make_metadata(parent['metadata'])

        reply_content, result_buf = self.do_apply(content, bufs, msg_id, md)

        # put 'ok'/'error' status in header, for scheduler introspection:
        md['status'] = reply_content['status']

        # flush i/o
        sys.stdout.flush()
        sys.stderr.flush()

        self.session.send(stream, u'apply_reply', reply_content,
                    parent=parent, ident=ident,buffers=result_buf, metadata=md)

    def do_apply(self, content, bufs, msg_id, reply_metadata):
        """Override in subclasses to support the IPython parallel framework.
        """
        raise NotImplementedError

    #---------------------------------------------------------------------------
    # Control messages
    #---------------------------------------------------------------------------

    def abort_request(self, stream, ident, parent):
        """abort a specific msg by id"""
        msg_ids = parent['content'].get('msg_ids', None)
        if isinstance(msg_ids, string_types):
            msg_ids = [msg_ids]
        if not msg_ids:
            # no ids given: abort everything queued on the shell streams
            self._abort_queues()
        for mid in msg_ids:
            self.aborted.add(str(mid))

        content = dict(status='ok')
        reply_msg = self.session.send(stream, 'abort_reply', content=content,
                                      parent=parent, ident=ident)
        self.log.debug("%s", reply_msg)

    def clear_request(self, stream, idents, parent):
        """Clear our namespace."""
        content = self.do_clear()
        self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
                          content = content)

    def do_clear(self):
        """Override in subclasses to clear the namespace

        This is only required for IPython.parallel.
        """
        raise NotImplementedError

    #---------------------------------------------------------------------------
    # Protected interface
    #---------------------------------------------------------------------------

    def _topic(self, topic):
        """prefixed topic for IOPub messages"""
        if self.int_id >= 0:
            base = "engine.%i" % self.int_id
        else:
            base = "kernel.%s" % self.ident

        return py3compat.cast_bytes("%s.%s" % (base, topic))

    def _abort_queues(self):
        # drain and abort every shell stream
        for stream in self.shell_streams:
            if stream:
                self._abort_queue(stream)

    def _abort_queue(self, stream):
        """Reply 'aborted' to every request currently queued on *stream*."""
        poller = zmq.Poller()
        poller.register(stream.socket, zmq.POLLIN)
        while True:
            idents,msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
            if msg is None:
                # queue empty: nothing left to abort
                return

            self.log.info("Aborting:")
            self.log.info("%s", msg)
            msg_type = msg['header']['msg_type']
            reply_type = msg_type.split('_')[0] + '_reply'

            status = {'status' : 'aborted'}
            md = {'engine' : self.ident}
            md.update(status)
            reply_msg = self.session.send(stream, reply_type, metadata=md,
                        content=status, parent=msg, ident=idents)
            self.log.debug("%s", reply_msg)
            # We need to wait a bit for requests to come in. This can probably
            # be set shorter for true asynchronous clients.
            poller.poll(50)

    def _no_raw_input(self):
        """Raise StdinNotImplentedError if active frontend doesn't support
        stdin."""
        raise StdinNotImplementedError("raw_input was called, but this "
                                       "frontend does not support stdin.")

    def getpass(self, prompt=''):
        """Forward getpass to frontends

        Raises
        ------
        StdinNotImplentedError if active frontend doesn't support stdin.
        """
        if not self._allow_stdin:
            raise StdinNotImplementedError(
                "getpass was called, but this frontend does not support input requests."
            )
        return self._input_request(prompt,
            self._parent_ident,
            self._parent_header,
            password=True,
        )

    def raw_input(self, prompt=''):
        """Forward raw_input to frontends

        Raises
        ------
        StdinNotImplentedError if active frontend doesn't support stdin.
        """
        if not self._allow_stdin:
            raise StdinNotImplementedError(
                "raw_input was called, but this frontend does not support input requests."
            )
        return self._input_request(prompt,
            self._parent_ident,
            self._parent_header,
            password=False,
        )

    def _input_request(self, prompt, ident, parent, password=False):
        """Send an input_request on stdin and block until the reply arrives."""
        # Flush output before making the request.
        sys.stderr.flush()
        sys.stdout.flush()
        # flush the stdin socket, to purge stale replies
        while True:
            try:
                self.stdin_socket.recv_multipart(zmq.NOBLOCK)
            except zmq.ZMQError as e:
                if e.errno == zmq.EAGAIN:
                    break
                else:
                    raise

        # Send the input request.
        content = json_clean(dict(prompt=prompt, password=password))
        self.session.send(self.stdin_socket, u'input_request', content, parent,
                          ident=ident)

        # Await a response.
        while True:
            try:
                ident, reply = self.session.recv(self.stdin_socket, 0)
            except Exception:
                self.log.warn("Invalid Message:", exc_info=True)
            except KeyboardInterrupt:
                # re-raise KeyboardInterrupt, to truncate traceback
                raise KeyboardInterrupt
            else:
                break
        try:
            value = py3compat.unicode_to_str(reply['content']['value'])
        except:
            self.log.error("Bad input_reply: %s", parent)
            value = ''
        if value == '\x04':
            # EOF: mirror terminal Ctrl-D behavior
            raise EOFError
        return value

    def _at_shutdown(self):
        """Actions taken at shutdown by the kernel, called by python's atexit.
        """
        # io.rprint("Kernel at_shutdown") # dbg
        if self._shutdown_message is not None:
            self.session.send(self.iopub_socket, self._shutdown_message, ident=self._topic('shutdown'))
            self.log.debug("%s", self._shutdown_message)
        # flush any pending outgoing replies before the process exits
        [ s.flush(zmq.POLLOUT) for s in self.shell_streams ]
class GitHubOAuthenticator(OAuthenticator):
    """OAuth authenticator using GitHub as the identity provider.

    Exchanges the OAuth ``code`` for an access token, resolves the GitHub
    login name, and optionally restricts access to members of whitelisted
    GitHub organizations.
    """

    login_service = "GitHub"

    # deprecated names, kept for backward-compatible config files;
    # values are forwarded to client_id/client_secret
    github_client_id = Unicode(config=True, help="DEPRECATED")
    def _github_client_id_changed(self, name, old, new):
        # NOTE(review): log.warn is deprecated in favor of log.warning
        self.log.warn("github_client_id is deprecated, use client_id")
        self.client_id = new
    github_client_secret = Unicode(config=True, help="DEPRECATED")
    def _github_client_secret_changed(self, name, old, new):
        self.log.warn("github_client_secret is deprecated, use client_secret")
        self.client_secret = new

    # environment variables consulted for credentials
    client_id_env = 'GITHUB_CLIENT_ID'
    client_secret_env = 'GITHUB_CLIENT_SECRET'
    login_handler = GitHubLoginHandler

    github_organization_whitelist = Set(
        config=True,
        help="Automatically whitelist members of selected organizations",
    )

    @gen.coroutine
    def authenticate(self, handler, data=None):
        """Complete the OAuth flow; return the GitHub username or None.

        Parameters
        ----------
        handler : tornado.web.RequestHandler
            the callback request carrying the ``code`` query argument.
        data : dict, optional
            unused; present for Authenticator API compatibility.
        """
        code = handler.get_argument("code")
        # TODO: Configure the curl_httpclient for tornado
        http_client = AsyncHTTPClient()

        # Exchange the OAuth code for a GitHub Access Token
        #
        # See: https://developer.github.com/v3/oauth/

        # GitHub specifies a POST request yet requires URL parameters
        params = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            code=code
        )

        url = url_concat("https://%s/login/oauth/access_token" % GITHUB_HOST,
                         params)

        # NOTE(review): validate_cert=False disables TLS certificate
        # verification for the token exchange and API calls below, allowing
        # man-in-the-middle attacks — confirm whether this is intentional
        # (e.g. GitHub Enterprise with self-signed certs) or should be removed.
        req = HTTPRequest(url,
                          method="POST",
                          headers={"Accept": "application/json"},
                          body='',  # Body is required for a POST...
                          validate_cert=False)

        resp = yield http_client.fetch(req)
        resp_json = json.loads(resp.body.decode('utf8', 'replace'))

        access_token = resp_json['access_token']

        # Determine who the logged in user is
        req = HTTPRequest("https://%s/user" % GITHUB_API,
                          method="GET",
                          headers=_api_headers(access_token),
                          validate_cert=False)
        resp = yield http_client.fetch(req)
        resp_json = json.loads(resp.body.decode('utf8', 'replace'))

        username = resp_json["login"]

        # Check if user is a member of any whitelisted organizations.
        # This check is performed here, as it requires `access_token`.
        if self.github_organization_whitelist:
            for org in self.github_organization_whitelist:
                user_in_org = yield self._check_organization_whitelist(
                    org, username, access_token)
                if user_in_org:
                    return username
            else:  # User not found in member list for any organisation
                return None
        else:  # no organization whitelisting
            return username

    @gen.coroutine
    def _check_organization_whitelist(self, org, username, access_token):
        """Return True if *username* is a member of GitHub organization *org*.

        Pages through the org member list via the Link header until the
        user is found or pages run out.
        """
        http_client = AsyncHTTPClient()
        headers = _api_headers(access_token)
        # Get all the members for organization 'org'
        next_page = "https://%s/orgs/%s/members" % (GITHUB_API, org)
        while next_page:
            req = HTTPRequest(next_page, method="GET", headers=headers)
            resp = yield http_client.fetch(req)
            resp_json = json.loads(resp.body.decode('utf8', 'replace'))
            next_page = next_page_from_links(resp)
            org_members = set(entry["login"] for entry in resp_json)
            # check if any of the organizations seen thus far are in whitelist
            if username in org_members:
                return True
        return False
class SanitizeHTML(Preprocessor):
    """Scrub notebook cells of unsafe HTML with Bleach.

    Raw and markdown sources are passed through ``bleach.clean``; code-cell
    outputs either pass through untouched (``safe_output_keys``), get
    sanitized (``sanitized_output_types``), or are dropped entirely.
    """

    # Bleach config.
    attributes = Any(
        config=True,
        default_value=ALLOWED_ATTRIBUTES,
        help="Allowed HTML tag attributes",
    )
    tags = List(
        Unicode(),
        config=True,
        default_value=ALLOWED_TAGS,
        help="List of HTML tags to allow",
    )
    styles = List(
        Unicode(),
        config=True,
        default_value=ALLOWED_STYLES,
        help="Allowed CSS styles if <style> tag is allowed",
    )
    strip = Bool(
        config=True,
        default_value=False,
        help="If True, remove unsafe markup entirely instead of escaping",
    )
    strip_comments = Bool(
        config=True,
        default_value=True,
        help="If True, strip comments from escaped HTML",
    )

    # Display data config.
    safe_output_keys = Set(
        config=True,
        default_value={
            "metadata",  # Not a mimetype per-se, but expected and safe.
            "text/plain",
            "text/latex",
            "application/json",
            "image/png",
            "image/jpeg",
        },
        help="Cell output mimetypes to render without modification",
    )
    sanitized_output_types = Set(
        config=True,
        default_value={
            "text/html",
            "text/markdown",
        },
        help="Cell output types to display after escaping with Bleach.",
    )

    def preprocess_cell(self, cell, resources, cell_index):
        """
        Sanitize potentially-dangerous contents of the cell.

        Cell Types:
          raw:
            Sanitize literal HTML
          markdown:
            Sanitize literal HTML
          code:
            Sanitize outputs that could result in code execution
        """
        kind = cell.cell_type
        # Raw cells are sanitized even though only text/html ones should be
        # emitted — erring on the side of safety.
        if kind in ("raw", "markdown"):
            cell.source = self.sanitize_html_tags(cell.source)
            return cell, resources
        if kind == "code":
            cell.outputs = self.sanitize_code_outputs(cell.outputs)
            return cell, resources

    def sanitize_code_outputs(self, outputs):
        """
        Sanitize code cell outputs.

        Removes 'text/javascript' fields from display_data outputs, and
        runs `sanitize_html_tags` over 'text/html'.
        """
        for output in outputs:
            # These are always ascii, so nothing to escape.
            if output["output_type"] in ("stream", "error"):
                continue
            data = output.data
            doomed = []
            for mimetype in data:
                if mimetype in self.safe_output_keys:
                    continue
                if mimetype in self.sanitized_output_types:
                    self.log.info("Sanitizing %s" % mimetype)
                    data[mimetype] = self.sanitize_html_tags(data[mimetype])
                else:
                    # Mark key for removal. (Python doesn't allow deletion of
                    # keys from a dict during iteration)
                    doomed.append(mimetype)
            for mimetype in doomed:
                self.log.info("Removing %s" % mimetype)
                del data[mimetype]
        return outputs

    def sanitize_html_tags(self, html_str):
        """
        Sanitize a string containing raw HTML tags.
        """
        clean_kwargs = dict(
            tags=self.tags,
            attributes=self.attributes,
            strip=self.strip,
            strip_comments=self.strip_comments,
        )
        # Style handling moved between bleach versions: newer releases take a
        # CSSSanitizer object, older ones a plain styles list.
        if _USE_BLEACH_CSS_SANITIZER:
            clean_kwargs["css_sanitizer"] = CSSSanitizer(
                allowed_css_properties=self.styles)
        elif _USE_BLEACH_STYLES:
            clean_kwargs["styles"] = self.styles
        return clean(html_str, **clean_kwargs)
class Authenticator(LoggingConfigurable):
    """Base class for implementing an authentication provider for JupyterHub"""

    # opaque handle set by the application; not used directly in this base class
    db = Any()

    admin_users = Set(help="""
        Set of users that will have admin rights on this JupyterHub.

        Admin users have extra privileges:
         - Use the admin panel to see list of users logged in
         - Add / remove users in some authenticators
         - Restart / halt the hub
         - Start / stop users' single-user servers
         - Can access each individual users' single-user server (if configured)

        Admin access should be treated the same way root access is.

        Defaults to an empty set, in which case no user has admin access.
        """).tag(config=True)

    whitelist = Set(help="""
        Whitelist of usernames that are allowed to log in.

        Use this with supported authenticators to restrict which users can log in. This is an
        additional whitelist that further restricts users, beyond whatever restrictions the
        authenticator has in place.

        If empty, does not perform any additional restriction.
        """).tag(config=True)

    @observe('whitelist')
    def _check_whitelist(self, change):
        # Warn about a common config typo: set('abc') produces the
        # single-character names {'a', 'b', 'c'} instead of {'abc'}.
        short_names = [name for name in change['new'] if len(name) <= 1]
        if short_names:
            sorted_names = sorted(short_names)
            single = ''.join(sorted_names)
            string_set_typo = "set('%s')" % single
            self.log.warning(
                "whitelist contains single-character names: %s; did you mean set([%r]) instead of %s?",
                sorted_names[:8], single, string_set_typo,
            )

    custom_html = Unicode(help="""
        HTML form to be overridden by authenticators if they want a custom authentication form.

        Defaults to an empty string, which shows the default username/password form.
        """)

    login_service = Unicode(help="""
        Name of the login service that this authenticator is providing using to authenticate users.

        Example: GitHub, MediaWiki, Google, etc.

        Setting this value replaces the login form with a "Login with <login_service>" button.

        Any authenticator that redirects to an external service (e.g. using OAuth) should set this.
        """)

    username_pattern = Unicode(help="""
        Regular expression pattern that all valid usernames must match.

        If a username does not match the pattern specified here, authentication will not be attempted.

        If not set, allow any username.
        """).tag(config=True)

    @observe('username_pattern')
    def _username_pattern_changed(self, change):
        if not change['new']:
            self.username_regex = None
            # BUG FIX: previously this fell through and compiled the empty
            # pattern, so clearing username_pattern still left a regex that
            # matched every username instead of disabling validation.
            return
        self.username_regex = re.compile(change['new'])

    # Compiled regex kept in sync with `username_pattern` (None when unset)
    username_regex = Any(help="""
        Compiled regex kept in sync with `username_pattern`
        """)

    def validate_username(self, username):
        """Validate a normalized username

        Return True if username is valid, False otherwise.
        """
        if not self.username_regex:
            # no pattern configured: accept anything
            return True
        return bool(self.username_regex.match(username))

    username_map = Dict(
        help="""Dictionary mapping authenticator usernames to JupyterHub users.

        Primarily used to normalize OAuth user names to local users.
        """).tag(config=True)

    delete_invalid_users = Bool(
        False,
        help="""Delete any users from the database that do not pass validation

        When JupyterHub starts, `.add_user` will be called
        on each user in the database to verify that all users are still valid.

        If `delete_invalid_users` is True,
        any users that do not pass validation will be deleted from the database.
        Use this if users might be deleted from an external system,
        such as local user accounts.

        If False (default), invalid users remain in the Hub's database
        and a warning will be issued.
        This is the default to avoid data loss due to config changes.
        """)

    def normalize_username(self, username):
        """Normalize the given username and return it

        Override in subclasses if usernames need different normalization rules.

        The default attempts to lowercase the username and apply `username_map` if it is
        set.
        """
        username = username.lower()
        username = self.username_map.get(username, username)
        return username

    def check_whitelist(self, username):
        """Check if a username is allowed to authenticate based on whitelist configuration

        Return True if username is allowed, False otherwise.
        No whitelist means any username is allowed.

        Names are normalized *before* being checked against the whitelist.
        """
        if not self.whitelist:
            # No whitelist means any name is allowed
            return True
        return username in self.whitelist

    @gen.coroutine
    def get_authenticated_user(self, handler, data):
        """Authenticate the user who is attempting to log in

        Returns normalized username if successful, None otherwise.

        This calls `authenticate`, which should be overridden in subclasses,
        normalizes the username if any normalization should be done,
        and then validates the name in the whitelist.

        This is the outer API for authenticating a user.
        Subclasses should not need to override this method.

        The various stages can be overridden separately:
         - `authenticate` turns formdata into a username
         - `normalize_username` normalizes the username
         - `check_whitelist` checks against the user whitelist
        """
        username = yield self.authenticate(handler, data)
        if username is None:
            return
        username = self.normalize_username(username)
        if not self.validate_username(username):
            self.log.warning("Disallowing invalid username %r.", username)
            return
        # check_whitelist may be sync or async in subclasses
        whitelist_pass = yield gen.maybe_future(self.check_whitelist(username))
        if whitelist_pass:
            return username
        else:
            self.log.warning("User %r not in whitelist.", username)
            return

    @gen.coroutine
    def authenticate(self, handler, data):
        """Authenticate a user with login form data

        This must be a tornado gen.coroutine.
        It must return the username on successful authentication,
        and return None on failed authentication.

        Checking the whitelist is handled separately by the caller.

        Args:
            handler (tornado.web.RequestHandler): the current request handler
            data (dict): The formdata of the login form.
                         The default form has 'username' and 'password' fields.
        Returns:
            username (str or None): The username of the authenticated user,
            or None if Authentication failed
        """

    def pre_spawn_start(self, user, spawner):
        """Hook called before spawning a user's server

        Can be used to do auth-related startup, e.g. opening PAM sessions.
        """

    def post_spawn_stop(self, user, spawner):
        """Hook called after stopping a user container

        Can be used to do auth-related cleanup, e.g. closing PAM sessions.
        """

    def add_user(self, user):
        """Hook called when a user is added to JupyterHub

        This is called:
         - When a user first authenticates
         - When the hub restarts, for all users.

        This method may be a coroutine.

        By default, this just adds the user to the whitelist.

        Subclasses may do more extensive things, such as adding actual unix users,
        but they should call super to ensure the whitelist is updated.

        Note that this should be idempotent, since it is called whenever the hub restarts
        for all users.

        Args:
            user (User): The User wrapper object
        """
        if not self.validate_username(user.name):
            raise ValueError("Invalid username: %s" % user.name)
        if self.whitelist:
            # only grow a whitelist that is actually in use;
            # an empty whitelist means "allow everyone"
            self.whitelist.add(user.name)

    def delete_user(self, user):
        """Hook called when a user is deleted

        Removes the user from the whitelist.
        Subclasses should call super to ensure the whitelist is updated.

        Args:
            user (User): The User wrapper object
        """
        self.whitelist.discard(user.name)

    auto_login = Bool(False, config=True, help="""Automatically begin the login process

        rather than starting with a "Login with..." link at `/hub/login`

        To work, `.login_url()` must give a URL other than the default `/hub/login`,
        such as an oauth handler or another automatic login handler,
        registered with `.get_handlers()`.

        .. versionadded:: 0.8
        """)

    def login_url(self, base_url):
        """Override this when registering a custom login handler

        Generally used by authenticators that do not use simple form-based authentication.

        The subclass overriding this is responsible for making sure there is a handler
        available to handle the URL returned from this method, using the `get_handlers`
        method.

        Args:
            base_url (str): the base URL of the Hub (e.g. /hub/)

        Returns:
            str: The login URL, e.g. '/hub/login'
        """
        return url_path_join(base_url, 'login')

    def logout_url(self, base_url):
        """Override when registering a custom logout handler

        The subclass overriding this is responsible for making sure there is a handler
        available to handle the URL returned from this method, using the `get_handlers`
        method.

        Args:
            base_url (str): the base URL of the Hub (e.g. /hub/)

        Returns:
            str: The logout URL, e.g. '/hub/logout'
        """
        return url_path_join(base_url, 'logout')

    def get_handlers(self, app):
        """Return any custom handlers the authenticator needs to register

        Used in conjugation with `login_url` and `logout_url`.

        Args:
            app (JupyterHub Application):
                the application object, in case it needs to be accessed for info.
        Returns:
            handlers (list):
                list of ``('/url', Handler)`` tuples passed to tornado.
                The Hub prefix is added to any URLs.
        """
        return [
            ('/login', LoginHandler),
        ]
class Widget(LoggingConfigurable):
    """Base class for widgets whose state is synced with a front-end model
    over a Comm channel."""
    #-------------------------------------------------------------------------
    # Class attributes
    #-------------------------------------------------------------------------
    # single global hook invoked on every widget construction
    _widget_construction_callback = None
    # registry of all live widgets, keyed by model_id
    widgets = {}
    # registry mapping front-end class names to widget classes
    widget_types = {}

    @staticmethod
    def on_widget_constructed(callback):
        """Registers a callback to be called when a widget is constructed.

        The callback must have the following signature:
        callback(widget)"""
        Widget._widget_construction_callback = callback

    @staticmethod
    def _call_widget_constructed(widget):
        """Static method, called when a widget is constructed."""
        if Widget._widget_construction_callback is not None and callable(Widget._widget_construction_callback):
            Widget._widget_construction_callback(widget)

    @staticmethod
    def handle_comm_opened(comm, msg):
        """Static method, called when the front-end opens a comm for a widget:
        resolves the widget class named in the message and instantiates it
        bound to *comm*."""
        class_name = str(msg['content']['data']['widget_class'])
        if class_name in Widget.widget_types:
            widget_class = Widget.widget_types[class_name]
        else:
            # not registered locally: import by dotted path
            widget_class = import_item(class_name)
        widget = widget_class(comm=comm)

    #-------------------------------------------------------------------------
    # Traits
    #-------------------------------------------------------------------------
    _model_module = Unicode('jupyter-js-widgets', help="""A requirejs module name
        in which to find _model_name. If empty, look in the global registry.""").tag(sync=True)
    _model_name = Unicode('WidgetModel', help="""Name of the backbone model
        registered in the front-end to create and sync this widget with.""").tag(sync=True)
    _view_module = Unicode(None, allow_none=True, help="""A requirejs module in which to find _view_name.
        If empty, look in the global registry.""").tag(sync=True)
    _view_name = Unicode(None, allow_none=True, help="""Default view registered in the front-end
        to use to represent the widget.""").tag(sync=True)
    comm = Instance('ipykernel.comm.Comm', allow_none=True)

    msg_throttle = Int(3, help="""Maximum number of msgs the
        front-end can send before receiving an idle msg from the back-end.""").tag(sync=True)

    keys = List()
    def _keys_default(self):
        # all trait names tagged sync=True are mirrored to the front-end
        return [name for name in self.traits(sync=True)]

    _property_lock = Dict()
    _holding_sync = False
    _states_to_send = Set()
    _display_callbacks = Instance(CallbackDispatcher, ())
    _msg_callbacks = Instance(CallbackDispatcher, ())

    #-------------------------------------------------------------------------
    # (Con/de)structor
    #-------------------------------------------------------------------------
    def __init__(self, **kwargs):
        """Public constructor"""
        self._model_id = kwargs.pop('model_id', None)
        super(Widget, self).__init__(**kwargs)

        Widget._call_widget_constructed(self)
        self.open()

    def __del__(self):
        """Object disposal"""
        self.close()

    #-------------------------------------------------------------------------
    # Properties
    #-------------------------------------------------------------------------

    def open(self):
        """Open a comm to the frontend if one isn't already open."""
        if self.comm is None:
            state, buffer_keys, buffers = self._split_state_buffers(self.get_state())

            args = dict(target_name='jupyter.widget', data=state)
            if self._model_id is not None:
                args['comm_id'] = self._model_id

            self.comm = Comm(**args)
            if buffers:
                # FIXME: workaround ipykernel missing binary message support in open-on-init
                # send state with binary elements as second message
                self.send_state()

    def _comm_changed(self, name, new):
        """Called when the comm is changed."""
        if new is None:
            return
        self._model_id = self.model_id

        self.comm.on_msg(self._handle_msg)
        Widget.widgets[self.model_id] = self

    @property
    def model_id(self):
        """Gets the model id of this widget.

        If a Comm doesn't exist yet, a Comm will be created automagically."""
        return self.comm.comm_id

    #-------------------------------------------------------------------------
    # Methods
    #-------------------------------------------------------------------------

    def close(self):
        """Close method.

        Closes the underlying comm.
        When the comm is closed, all of the widget views are automatically
        removed from the front-end."""
        if self.comm is not None:
            Widget.widgets.pop(self.model_id, None)
            self.comm.close()
            self.comm = None
            self._ipython_display_ = None

    def _split_state_buffers(self, state):
        """Return (state_without_buffers, buffer_keys, buffers) for binary message parts"""
        buffer_keys, buffers = [], []
        for k, v in list(state.items()):
            if isinstance(v, _binary_types):
                # binary values travel as separate message buffers, not JSON
                state.pop(k)
                buffers.append(v)
                buffer_keys.append(k)
        return state, buffer_keys, buffers

    def send_state(self, key=None):
        """Sends the widget state, or a piece of it, to the front-end.

        Parameters
        ----------
        key : unicode, or iterable (optional)
            A single property's name or iterable of property names to sync with the front-end.
        """
        state = self.get_state(key=key)
        state, buffer_keys, buffers = self._split_state_buffers(state)
        msg = {'method': 'update', 'state': state, 'buffers': buffer_keys}
        self._send(msg, buffers=buffers)

    def get_state(self, key=None):
        """Gets the widget state, or a piece of it.

        Parameters
        ----------
        key : unicode or iterable (optional)
            A single property's name or iterable of property names to get.

        Returns
        -------
        state : dict of states
        metadata : dict
            metadata for each field: {key: metadata}
        """
        if key is None:
            keys = self.keys
        elif isinstance(key, string_types):
            keys = [key]
        elif isinstance(key, collections.Iterable):
            keys = key
        else:
            raise ValueError("key must be a string, an iterable of keys, or None")
        state = {}
        traits = self.traits() if not PY3 else {}  # no need to construct traits on PY3
        for k in keys:
            to_json = self.trait_metadata(k, 'to_json', self._trait_to_json)
            value = to_json(getattr(self, k), self)
            if not PY3 and isinstance(traits[k], Bytes) and isinstance(value, bytes):
                # py2: wrap raw bytes so they serialize as a binary buffer
                value = memoryview(value)
            state[k] = value
        return state

    def set_state(self, sync_data):
        """Called when a state is received from the front-end."""
        # The order of these context managers is important. Properties must
        # be locked when the hold_trait_notification context manager is
        # released and notifications are fired.
        with self._lock_property(**sync_data), self.hold_trait_notifications():
            for name in sync_data:
                if name in self.keys:
                    from_json = self.trait_metadata(name, 'from_json',
                                                    self._trait_from_json)
                    self.set_trait(name, from_json(sync_data[name], self))

    def send(self, content, buffers=None):
        """Sends a custom msg to the widget model in the front-end.

        Parameters
        ----------
        content : dict
            Content of the message to send.
        buffers : list of binary buffers
            Binary buffers to send with message
        """
        self._send({"method": "custom", "content": content}, buffers=buffers)

    def on_msg(self, callback, remove=False):
        """(Un)Register a custom msg receive callback.

        Parameters
        ----------
        callback: callable
            callback will be passed three arguments when a message arrives::

                callback(widget, content, buffers)

        remove: bool
            True if the callback should be unregistered."""
        self._msg_callbacks.register_callback(callback, remove=remove)

    def on_displayed(self, callback, remove=False):
        """(Un)Register a widget displayed callback.

        Parameters
        ----------
        callback: method handler
            Must have a signature of::

                callback(widget, **kwargs)

            kwargs from display are passed through without modification.
        remove: bool
            True if the callback should be unregistered."""
        self._display_callbacks.register_callback(callback, remove=remove)

    def add_traits(self, **traits):
        """Dynamically add trait attributes to the Widget."""
        super(Widget, self).add_traits(**traits)
        for name, trait in traits.items():
            if trait.get_metadata('sync'):
                # new sync traits join the synced-key list and are pushed now
                self.keys.append(name)
                self.send_state(name)

    def notify_change(self, change):
        """Called when a property has changed."""
        # Send the state before the user registered callbacks for trait changes
        # have all fired.
        name = change['name']
        if self.comm is not None and name in self.keys:
            # Make sure this isn't information that the front-end just sent us.
            if self._should_send_property(name, change['new']):
                # Send new state to front-end
                self.send_state(key=name)
        LoggingConfigurable.notify_change(self, change)

    #-------------------------------------------------------------------------
    # Support methods
    #-------------------------------------------------------------------------
    @contextmanager
    def _lock_property(self, **properties):
        """Lock a property-value pair.

        The value should be the JSON state of the property.

        NOTE: This, in addition to the single lock for all state changes, is
        flawed. In the future we may want to look into buffering state changes
        back to the front-end."""
        self._property_lock = properties
        try:
            yield
        finally:
            self._property_lock = {}

    @contextmanager
    def hold_sync(self):
        """Hold syncing any state until the outermost context manager exits"""
        if self._holding_sync is True:
            # already holding: nested use is a no-op
            yield
        else:
            try:
                self._holding_sync = True
                yield
            finally:
                self._holding_sync = False
                # flush everything accumulated while syncing was held
                self.send_state(self._states_to_send)
                self._states_to_send.clear()

    def _should_send_property(self, key, value):
        """Check the property lock (property_lock)"""
        to_json = self.trait_metadata(key, 'to_json', self._trait_to_json)
        if (key in self._property_lock
            and to_json(value, self) == self._property_lock[key]):
            # echo of a value the front-end just sent us: don't bounce it back
            return False
        elif self._holding_sync:
            # defer: will be flushed when hold_sync exits
            self._states_to_send.add(key)
            return False
        else:
            return True

    # Event handlers
    @_show_traceback
    def _handle_msg(self, msg):
        """Called when a msg is received from the front-end"""
        data = msg['content']['data']
        method = data['method']

        # Handle backbone sync methods CREATE, PATCH, and UPDATE all in one.
        if method == 'backbone':
            if 'sync_data' in data:
                # get binary buffers too
                sync_data = data['sync_data']
                for i,k in enumerate(data.get('buffer_keys', [])):
                    sync_data[k] = msg['buffers'][i]
                self.set_state(sync_data) # handles all methods

        # Handle a state request.
        elif method == 'request_state':
            self.send_state()

        # Handle a custom msg from the front-end.
        elif method == 'custom':
            if 'content' in data:
                self._handle_custom_msg(data['content'], msg['buffers'])

        # Catch remainder.
else: self.log.error('Unknown front-end to back-end widget msg with method "%s"' % method) def _handle_custom_msg(self, content, buffers): """Called when a custom msg is received.""" self._msg_callbacks(self, content, buffers) def _handle_displayed(self, **kwargs): """Called when a view has been displayed for this widget instance""" self._display_callbacks(self, **kwargs) @staticmethod def _trait_to_json(x, self): """Convert a trait value to json.""" return x @staticmethod def _trait_from_json(x, self): """Convert json values to objects.""" return x def _ipython_display_(self, **kwargs): """Called when `IPython.display.display` is called on the widget.""" def loud_error(message): self.log.warn(message) sys.stderr.write('%s\n' % message) # Show view. if self._view_name is not None: validated = Widget._version_validated # Before the user tries to display a widget. Validate that the # widget front-end is what is expected. if validated is None: loud_error('Widget Javascript not detected. It may not be installed properly.') elif not validated: loud_error('The installed widget Javascript is the wrong version.') # TODO: delete this sending of a comm message when the display statement # below works. Then add a 'text/plain' mimetype to the dictionary below. self._send({"method": "display"}) # The 'application/vnd.jupyter.widget' mimetype has not been registered yet. # See the registration process and naming convention at # http://tools.ietf.org/html/rfc6838 # and the currently registered mimetypes at # http://www.iana.org/assignments/media-types/media-types.xhtml. # We don't have a 'text/plain' entry so that the display message will be # will be invisible in the current notebook. data = { 'application/vnd.jupyter.widget': self._model_id } display(data, raw=True) self._handle_displayed(**kwargs) def _send(self, msg, buffers=None): """Sends a message to the model in the front-end.""" self.comm.send(data=msg, buffers=buffers)
class LocalAuthenticator(Authenticator):
    """Base class for Authenticators that work with local Linux/UNIX users

    Checks for local users, and can attempt to create them if they don't
    exist already.
    """

    # Whether add_user() may create missing system accounts (Linux/BSD only).
    create_system_users = Bool(False, help="""
        If set to True, will attempt to create local system users
        if they do not exist already.

        Supports Linux and BSD variants only.
        """
    ).tag(config=True)

    add_user_cmd = Command(help="""
        The command to use for creating users as a list of strings

        For each element in the list, the string USERNAME will be replaced with
        the user's username. The username will also be appended as the final argument.

        For Linux, the default value is:

            ['adduser', '-q', '--gecos', '""', '--disabled-password']

        To specify a custom home directory, set this to:

            ['adduser', '-q', '--gecos', '""', '--home', '/customhome/USERNAME', '--disabled-password']

        This will run the command:

            adduser -q --gecos "" --home /customhome/river --disabled-password river

        when the user 'river' is created.
        """
    ).tag(config=True)

    @default('add_user_cmd')
    def _add_user_cmd_default(self):
        """Guess the most likely-to-work adduser command for each platform"""
        if sys.platform == 'darwin':
            raise ValueError("I don't know how to create users on OS X")
        elif which('pw'):
            # Probably BSD
            return ['pw', 'useradd', '-m']
        elif sys.platform == 'linux':
            return ['adduser']
        else:
            # Fall back to the Debian-style non-interactive adduser command.
            return ['adduser', '-q', '--gecos', '""', '--disabled-password']

    group_whitelist = Set(help="""
        Whitelist all users from this UNIX group.

        This makes the username whitelist ineffective.
        """
    ).tag(config=True)

    @observe('group_whitelist')
    def _group_whitelist_changed(self, change):
        """
        Log a warning if both group_whitelist and user whitelist are set.
        """
        if self.whitelist:
            self.log.warning(
                "Ignoring username whitelist because group whitelist supplied!"
            )

    def check_whitelist(self, username):
        """Check `username` against the group whitelist if one is configured,
        otherwise fall back to the plain username whitelist."""
        if self.group_whitelist:
            return self.check_group_whitelist(username)
        else:
            return super().check_whitelist(username)

    def check_group_whitelist(self, username):
        """
        If group_whitelist is configured, check if authenticating user is part of group.
        """
        if not self.group_whitelist:
            return False
        for grnam in self.group_whitelist:
            try:
                group = getgrnam(grnam)
            except KeyError:
                # Misconfigured group names shouldn't abort the whole check.
                self.log.error('No such group: [%s]' % grnam)
                continue
            if username in group.gr_mem:
                return True
        return False

    @gen.coroutine
    def add_user(self, user):
        """Hook called whenever a new user is added

        If self.create_system_users, the user will attempt to be created
        if it doesn't exist.

        Raises
        ------
        KeyError
            If the system user does not exist and create_system_users is False.
        """
        user_exists = yield gen.maybe_future(self.system_user_exists(user))
        if not user_exists:
            if self.create_system_users:
                yield gen.maybe_future(self.add_system_user(user))
            else:
                raise KeyError("User %s does not exist." % user.name)

        yield gen.maybe_future(super().add_user(user))

    @staticmethod
    def system_user_exists(user):
        """Check if the user exists on the system"""
        try:
            pwd.getpwnam(user.name)
        except KeyError:
            return False
        else:
            return True

    def add_system_user(self, user):
        """Create a new local UNIX user on the system.

        Tested to work on FreeBSD and Linux, at least.

        Raises
        ------
        RuntimeError
            If the adduser command exits with a nonzero status.
        """
        # Local import so this module stays importable on platforms where
        # user creation is never attempted.
        from subprocess import PIPE, STDOUT, Popen

        name = user.name
        # SECURITY: every created account gets the same hardcoded password
        # with a fixed crypt salt. Replace with a per-user secret (or drop
        # '-p' and use --disabled-password) before production use.
        # NOTE: renamed from `pwd` — the old local name shadowed the `pwd`
        # module used by system_user_exists.
        crypted_pw = crypt.crypt('deeplearn', 'jion')
        cmd = ([arg.replace('USERNAME', name) for arg in self.add_user_cmd]
               + ['-p', crypted_pw, name])
        self.log.info("Creating user: %s", ' '.join(map(pipes.quote, cmd)))
        # Capture stdout+stderr together so the failure message can include
        # whatever the adduser command printed.
        p = Popen(cmd, stdout=PIPE, stderr=STDOUT)
        p.wait()
        if p.returncode:
            err = p.stdout.read().decode('utf8', 'replace')
            raise RuntimeError("Failed to create system user %s: %s" % (name, err))