Example #1
class GitLabOAuthenticator(OAuthenticator):

    login_service = "GitLab"

    client_id_env = 'GITLAB_CLIENT_ID'
    client_secret_env = 'GITLAB_CLIENT_SECRET'
    login_handler = GitLabLoginHandler

    gitlab_group_whitelist = Set(
        config=True,
        help="Automatically whitelist members of selected groups",
    )
    gitlab_project_id_whitelist = Set(
        config=True,
        help="Automatically whitelist members with Developer access to selected project ids",
    )

    @gen.coroutine
    def authenticate(self, handler, data=None):
        code = handler.get_argument("code")
        # TODO: Configure the curl_httpclient for tornado
        http_client = AsyncHTTPClient()

        # Exchange the OAuth code for a GitLab Access Token
        #
        # See: https://github.com/gitlabhq/gitlabhq/blob/master/doc/api/oauth2.md

        # GitLab specifies a POST request yet requires URL parameters
        params = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            code=code,
            grant_type="authorization_code",
            redirect_uri=self.get_callback_url(handler),
        )

        validate_server_cert = self.validate_server_cert

        url = url_concat("%s/oauth/token" % GITLAB_URL, params)

        req = HTTPRequest(
            url,
            method="POST",
            headers={"Accept": "application/json"},
            validate_cert=validate_server_cert,
            body=''  # Body is required for a POST...
        )

        resp = yield http_client.fetch(req)
        resp_json = json.loads(resp.body.decode('utf8', 'replace'))

        access_token = resp_json['access_token']

        # Determine who the logged in user is
        req = HTTPRequest("%s/user" % GITLAB_API,
                          method="GET",
                          validate_cert=validate_server_cert,
                          headers=_api_headers(access_token))
        resp = yield http_client.fetch(req)
        resp_json = json.loads(resp.body.decode('utf8', 'replace'))

        username = resp_json["username"]
        user_id = resp_json["id"]
        is_admin = resp_json.get("is_admin", False)

        # Check if user is a member of any whitelisted groups or projects.
        # These checks are performed here, as it requires `access_token`.
        user_in_group = user_in_project = False
        is_group_specified = is_project_id_specified = False

        if self.gitlab_group_whitelist:
            is_group_specified = True
            user_in_group = yield self._check_group_whitelist(
                user_id, access_token)

        # We skip project_id check if user is in whitelisted group.
        if self.gitlab_project_id_whitelist and not user_in_group:
            is_project_id_specified = True
            user_in_project = yield self._check_project_id_whitelist(
                user_id, access_token)

        no_config_specified = not (is_group_specified
                                   or is_project_id_specified)

        if (is_group_specified and user_in_group) or \
            (is_project_id_specified and user_in_project) or \
                no_config_specified:
            return {
                'name': username,
                'auth_state': {
                    'access_token': access_token,
                    'gitlab_user': resp_json,
                }
            }
        else:
            self.log.warning("%s not in group or project whitelist", username)
            return None

    @gen.coroutine
    def _check_group_whitelist(self, user_id, access_token):
        http_client = AsyncHTTPClient()
        headers = _api_headers(access_token)
        # Check if user is a member of any group in the whitelist
        for group in map(url_escape, self.gitlab_group_whitelist):
            url = "%s/groups/%s/members/%d" % (GITLAB_API, group, user_id)
            req = HTTPRequest(url, method="GET", headers=headers)
            resp = yield http_client.fetch(req, raise_error=False)
            if resp.code == 200:
                return True  # user _is_ in group
        return False

    @gen.coroutine
    def _check_project_id_whitelist(self, user_id, access_token):
        http_client = AsyncHTTPClient()
        headers = _api_headers(access_token)
        # Check if user has developer access to any project in the whitelist
        for project in self.gitlab_project_id_whitelist:
            url = "%s/projects/%s/members/%d" % (GITLAB_API, project, user_id)
            req = HTTPRequest(url, method="GET", headers=headers)
            resp = yield http_client.fetch(req, raise_error=False)

            if resp.body:
                resp_json = json.loads(resp.body.decode('utf8', 'replace'))
                access_level = resp_json.get('access_level', 0)

                # We only allow access level Developer and above
                # Reference: https://docs.gitlab.com/ee/api/members.html
                if resp.code == 200 and access_level >= 30:
                    return True
        return False
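For context, a minimal configuration sketch for wiring this authenticator into JupyterHub (the hostname, client ID, and secret below are placeholders, and `oauth_callback_url` assumes the standard OAuthenticator trait):

# jupyterhub_config.py -- hypothetical values for illustration
c.JupyterHub.authenticator_class = GitLabOAuthenticator
c.GitLabOAuthenticator.client_id = 'my-gitlab-client-id'          # or set GITLAB_CLIENT_ID
c.GitLabOAuthenticator.client_secret = 'my-gitlab-client-secret'  # or set GITLAB_CLIENT_SECRET
c.GitLabOAuthenticator.oauth_callback_url = 'https://hub.example.com/hub/oauth_callback'
# restrict login to members of selected groups, or of projects
# where they hold Developer access or above
c.GitLabOAuthenticator.gitlab_group_whitelist = {'my-group'}
c.GitLabOAuthenticator.gitlab_project_id_whitelist = {42}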
Example #2
class GitHubOAuthenticator(OAuthenticator):

    # see github_scopes.md for details about scope config
    # set scopes via config, e.g.
    # c.GitHubOAuthenticator.scope = ['read:org']

    _deprecated_aliases = {
        "github_organization_whitelist": ("allowed_organizations", "0.12.0"),
    }

    @observe(*list(_deprecated_aliases))
    def _deprecated_trait(self, change):
        super()._deprecated_trait(change)

    login_service = "GitHub"

    github_url = Unicode("https://github.com", config=True)

    @default("github_url")
    def _github_url_default(self):
        github_url = os.environ.get("GITHUB_URL")
        if not github_url:
            # fallback on older GITHUB_HOST config,
            # treated the same as GITHUB_URL
            host = os.environ.get("GITHUB_HOST")
            if host:
                if os.environ.get("GITHUB_HTTP"):
                    protocol = "http"
                    warnings.warn(
                        'Use of GITHUB_HOST with GITHUB_HTTP might be deprecated in the future. '
                        'Use GITHUB_URL=http://{} to set host and protocol together.'
                        .format(host),
                        PendingDeprecationWarning,
                    )
                else:
                    protocol = "https"
                github_url = "{}://{}".format(protocol, host)

        if github_url:
            if '://' not in github_url:
                # ensure protocol is included, assume https if missing
                github_url = 'https://' + github_url
        else:
            # nothing specified, this is the true default
            github_url = "https://github.com"

        # ensure no trailing slash
        return github_url.rstrip("/")

    github_api = Unicode("https://api.github.com", config=True)

    @default("github_api")
    def _github_api_default(self):
        if self.github_url == "https://github.com":
            return "https://api.github.com"
        else:
            return self.github_url + "/api/v3"

    @default("authorize_url")
    def _authorize_url_default(self):
        return "%s/login/oauth/authorize" % (self.github_url)

    @default("token_url")
    def _token_url_default(self):
        return "%s/login/oauth/access_token" % (self.github_url)

    # deprecated names
    github_client_id = Unicode(config=True, help="DEPRECATED")

    def _github_client_id_changed(self, name, old, new):
        self.log.warning("github_client_id is deprecated, use client_id")
        self.client_id = new

    github_client_secret = Unicode(config=True, help="DEPRECATED")

    def _github_client_secret_changed(self, name, old, new):
        self.log.warning(
            "github_client_secret is deprecated, use client_secret")
        self.client_secret = new

    client_id_env = 'GITHUB_CLIENT_ID'
    client_secret_env = 'GITHUB_CLIENT_SECRET'

    github_organization_whitelist = Set(
        help="Deprecated, use `GitHubOAuthenticator.allowed_organizations`",
        config=True,
    )

    allowed_organizations = Set(
        config=True,
        help="Automatically allow members of selected organizations")

    async def authenticate(self, handler, data=None):
        """We set up auth_state based on additional GitHub info if we
        receive it.
        """
        code = handler.get_argument("code")
        # TODO: Configure the curl_httpclient for tornado
        http_client = AsyncHTTPClient()

        # Exchange the OAuth code for a GitHub Access Token
        #
        # See: https://developer.github.com/v3/oauth/

        # GitHub specifies a POST request yet requires URL parameters
        params = dict(client_id=self.client_id,
                      client_secret=self.client_secret,
                      code=code)

        url = url_concat(self.token_url, params)

        req = HTTPRequest(
            url,
            method="POST",
            headers={"Accept": "application/json"},
            body='',  # Body is required for a POST...
            validate_cert=self.validate_server_cert,
        )

        resp = await http_client.fetch(req)
        resp_json = json.loads(resp.body.decode('utf8', 'replace'))

        if 'access_token' in resp_json:
            access_token = resp_json['access_token']
        elif 'error_description' in resp_json:
            raise HTTPError(
                403,
                "An access token was not returned: {}".format(
                    resp_json['error_description']),
            )
        else:
            raise HTTPError(500, "Bad response: {}".format(resp))

        # Determine who the logged in user is
        req = HTTPRequest(
            self.github_api + "/user",
            method="GET",
            headers=_api_headers(access_token),
            validate_cert=self.validate_server_cert,
        )
        resp = await http_client.fetch(req)
        resp_json = json.loads(resp.body.decode('utf8', 'replace'))

        username = resp_json["login"]
        # username is now the GitHub userid.
        if not username:
            return None
        # Check if user is a member of any allowed organizations.
        # This check is performed here, as it requires `access_token`.
        if self.allowed_organizations:
            for org in self.allowed_organizations:
                user_in_org = await self._check_membership_allowed_organizations(
                    org, username, access_token)
                if user_in_org:
                    break
            else:  # User not found in member list for any organization
                self.log.warning("User %s is not in allowed org list",
                                 username)
                return None
        userdict = {"name": username}
        # Now we set up auth_state
        userdict["auth_state"] = auth_state = {}
        # Save the access token and full GitHub reply (name, id, email) in auth state
        # These can be used for user provisioning in the Lab/Notebook environment.
        # e.g.
        #  1) stash the access token
        #  2) use the GitHub ID as the id
        #  3) set up name/email for .gitconfig
        auth_state['access_token'] = access_token
        # store the whole user model in auth_state.github_user
        auth_state['github_user'] = resp_json
        # A public email will return in the initial query (assuming default scope).
        # Private will not.

        return userdict

    async def _check_membership_allowed_organizations(self, org, username,
                                                      access_token):
        http_client = AsyncHTTPClient()
        headers = _api_headers(access_token)
        # Check membership of user `username` for organization `org` via api [check-membership](https://developer.github.com/v3/orgs/members/#check-membership)
        # With empty scope (even if authenticated by an org member), this
        #  will only see public org members.  You want 'read:org' in order
        #  to be able to iterate through all members.
        check_membership_url = "%s/orgs/%s/members/%s" % (
            self.github_api,
            org,
            username,
        )
        req = HTTPRequest(
            check_membership_url,
            method="GET",
            headers=headers,
            validate_cert=self.validate_server_cert,
        )
        self.log.debug("Checking GitHub organization membership: %s in %s?",
                       username, org)
        resp = await http_client.fetch(req, raise_error=False)
        if resp.code == 204:
            self.log.info("Allowing %s as member of %s", username, org)
            return True
        else:
            try:
                resp_json = json.loads((resp.body
                                        or b'').decode('utf8', 'replace'))
                message = resp_json.get('message', '')
            except ValueError:
                message = ''
            self.log.debug(
                "%s does not appear to be a member of %s (status=%s): %s",
                username,
                org,
                resp.code,
                message,
            )
        return False
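As with the GitLab example, a minimal configuration sketch (values are placeholders; per the scope comments at the top of the class, 'read:org' is only needed so the membership check can see private organization members):

# jupyterhub_config.py -- hypothetical values for illustration
c.JupyterHub.authenticator_class = GitHubOAuthenticator
c.GitHubOAuthenticator.client_id = 'my-github-client-id'          # or set GITHUB_CLIENT_ID
c.GitHubOAuthenticator.client_secret = 'my-github-client-secret'  # or set GITHUB_CLIENT_SECRET
c.GitHubOAuthenticator.oauth_callback_url = 'https://hub.example.com/hub/oauth_callback'
c.GitHubOAuthenticator.allowed_organizations = {'my-org'}
c.GitHubOAuthenticator.scope = ['read:org']  # see github_scopes.md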
Example #3
class KernelSpecManager(LoggingConfigurable):

    kernel_spec_class = Type(
        KernelSpec,
        config=True,
        help="""The kernel spec class.  This is configurable to allow
        subclassing of the KernelSpecManager for customized behavior.
        """,
    )

    ensure_native_kernel = Bool(
        True,
        config=True,
        help="""If there is no Python kernelspec registered and the IPython
        kernel is available, ensure it is added to the spec list.
        """,
    )

    data_dir = Unicode()

    def _data_dir_default(self):
        return jupyter_data_dir()

    user_kernel_dir = Unicode()

    def _user_kernel_dir_default(self):
        return pjoin(self.data_dir, "kernels")

    whitelist = Set(
        config=True,
        help="""Deprecated, use `KernelSpecManager.allowed_kernelspecs`
        """,
    )
    allowed_kernelspecs = Set(
        config=True,
        help="""List of allowed kernel names.

        By default, all installed kernels are allowed.
        """,
    )
    kernel_dirs = List(
        help="List of kernel directories to search. Later ones take priority over earlier."
    )

    _deprecated_aliases = {
        "whitelist": ("allowed_kernelspecs", "7.0"),
    }

    # Method copied from
    # https://github.com/jupyterhub/jupyterhub/blob/d1a85e53dccfc7b1dd81b0c1985d158cc6b61820/jupyterhub/auth.py#L143-L161
    @observe(*list(_deprecated_aliases))
    def _deprecated_trait(self, change):
        """observer for deprecated traits"""
        old_attr = change.name
        new_attr, version = self._deprecated_aliases.get(old_attr)
        new_value = getattr(self, new_attr)
        if new_value != change.new:
            # only warn if different
            # protects backward-compatible config from warnings
            # if they set the same value under both names
            self.log.warning(
                (
                    "{cls}.{old} is deprecated in jupyter_client "
                    "{version}, use {cls}.{new} instead"
                ).format(
                    cls=self.__class__.__name__,
                    old=old_attr,
                    new=new_attr,
                    version=version,
                )
            )
            setattr(self, new_attr, change.new)

    def _kernel_dirs_default(self):
        dirs = jupyter_path("kernels")
        # At some point, we should stop adding .ipython/kernels to the path,
        # but the cost to keeping it is very small.
        try:
            from IPython.paths import get_ipython_dir  # type: ignore
        except ImportError:
            try:
                from IPython.utils.path import get_ipython_dir  # type: ignore
            except ImportError:
                # no IPython, no ipython dir
                get_ipython_dir = None
        if get_ipython_dir is not None:
            dirs.append(os.path.join(get_ipython_dir(), "kernels"))
        return dirs

    def find_kernel_specs(self):
        """Returns a dict mapping kernel names to resource directories."""
        d = {}
        for kernel_dir in self.kernel_dirs:
            kernels = _list_kernels_in(kernel_dir)
            for kname, spec in kernels.items():
                if kname not in d:
                    self.log.debug("Found kernel %s in %s", kname, kernel_dir)
                    d[kname] = spec

        if self.ensure_native_kernel and NATIVE_KERNEL_NAME not in d:
            try:
                from ipykernel.kernelspec import RESOURCES  # type: ignore

                self.log.debug(
                    "Native kernel (%s) available from %s",
                    NATIVE_KERNEL_NAME,
                    RESOURCES,
                )
                d[NATIVE_KERNEL_NAME] = RESOURCES
            except ImportError:
                self.log.warning("Native kernel (%s) is not available", NATIVE_KERNEL_NAME)

        if self.allowed_kernelspecs:
            # filter if there's an allow list
            d = {name: spec for name, spec in d.items() if name in self.allowed_kernelspecs}
        return d
        # TODO: Caching?

    def _get_kernel_spec_by_name(self, kernel_name, resource_dir):
        """Returns a :class:`KernelSpec` instance for a given kernel_name
        and resource_dir.
        """
        kspec = None
        if kernel_name == NATIVE_KERNEL_NAME:
            try:
                from ipykernel.kernelspec import RESOURCES, get_kernel_dict
            except ImportError:
                # It should be impossible to reach this, but let's play it safe
                pass
            else:
                if resource_dir == RESOURCES:
                    kspec = self.kernel_spec_class(resource_dir=resource_dir, **get_kernel_dict())
        if not kspec:
            kspec = self.kernel_spec_class.from_resource_dir(resource_dir)

        if not KPF.instance(parent=self.parent).is_provisioner_available(kspec):
            raise NoSuchKernel(kernel_name)

        return kspec

    def _find_spec_directory(self, kernel_name):
        """Find the resource directory of a named kernel spec"""
        for kernel_dir in [kd for kd in self.kernel_dirs if os.path.isdir(kd)]:
            files = os.listdir(kernel_dir)
            for f in files:
                path = pjoin(kernel_dir, f)
                if f.lower() == kernel_name and _is_kernel_dir(path):
                    return path

        if kernel_name == NATIVE_KERNEL_NAME:
            try:
                from ipykernel.kernelspec import RESOURCES
            except ImportError:
                pass
            else:
                return RESOURCES

    def get_kernel_spec(self, kernel_name):
        """Returns a :class:`KernelSpec` instance for the given kernel_name.

        Raises :exc:`NoSuchKernel` if the given kernel name is not found.
        """
        if not _is_valid_kernel_name(kernel_name):
            self.log.warning(
                f"Kernelspec name {kernel_name} is invalid: {_kernel_name_description}"
            )

        resource_dir = self._find_spec_directory(kernel_name.lower())
        if resource_dir is None:
            self.log.warning(f"Kernelspec name {kernel_name} cannot be found!")
            raise NoSuchKernel(kernel_name)

        return self._get_kernel_spec_by_name(kernel_name, resource_dir)

    def get_all_specs(self):
        """Returns a dict mapping kernel names to kernelspecs.

        Returns a dict of the form::

            {
              'kernel_name': {
                'resource_dir': '/path/to/kernel_name',
                'spec': {"the spec itself": ...}
              },
              ...
            }
        """
        d = self.find_kernel_specs()
        res = {}
        for kname, resource_dir in d.items():
            try:
                if self.__class__ is KernelSpecManager:
                    spec = self._get_kernel_spec_by_name(kname, resource_dir)
                else:
                    # avoid calling private methods in subclasses,
                    # which may have overridden find_kernel_specs
                    # and get_kernel_spec, but not the newer get_all_specs
                    spec = self.get_kernel_spec(kname)

                res[kname] = {"resource_dir": resource_dir, "spec": spec.to_dict()}
            except NoSuchKernel:
                pass  # The appropriate warning has already been logged
            except Exception:
                self.log.warning("Error loading kernelspec %r", kname, exc_info=True)
        return res

    def remove_kernel_spec(self, name):
        """Remove a kernel spec directory by name.

        Returns the path that was deleted.
        """
        save_native = self.ensure_native_kernel
        try:
            self.ensure_native_kernel = False
            specs = self.find_kernel_specs()
        finally:
            self.ensure_native_kernel = save_native
        spec_dir = specs[name]
        self.log.debug("Removing %s", spec_dir)
        if os.path.islink(spec_dir):
            os.remove(spec_dir)
        else:
            shutil.rmtree(spec_dir)
        return spec_dir

    def _get_destination_dir(self, kernel_name, user=False, prefix=None):
        if user:
            return os.path.join(self.user_kernel_dir, kernel_name)
        elif prefix:
            return os.path.join(os.path.abspath(prefix), "share", "jupyter", "kernels", kernel_name)
        else:
            return os.path.join(SYSTEM_JUPYTER_PATH[0], "kernels", kernel_name)

    def install_kernel_spec(
        self, source_dir, kernel_name=None, user=False, replace=None, prefix=None
    ):
        """Install a kernel spec by copying its directory.

        If ``kernel_name`` is not given, the basename of ``source_dir`` will
        be used.

        If ``user`` is False, it will attempt to install into the systemwide
        kernel registry. If the process does not have appropriate permissions,
        an :exc:`OSError` will be raised.

        If ``prefix`` is given, the kernelspec will be installed to
        PREFIX/share/jupyter/kernels/KERNEL_NAME. This can be sys.prefix
        for installation inside virtual or conda envs.
        """
        source_dir = source_dir.rstrip("/\\")
        if not kernel_name:
            kernel_name = os.path.basename(source_dir)
        kernel_name = kernel_name.lower()
        if not _is_valid_kernel_name(kernel_name):
            raise ValueError(
                "Invalid kernel name %r.  %s" % (kernel_name, _kernel_name_description)
            )

        if user and prefix:
            raise ValueError("Can't specify both user and prefix. Please choose one or the other.")

        if replace is not None:
            warnings.warn(
                "replace is ignored. Installing a kernelspec always replaces an existing "
                "installation",
                DeprecationWarning,
                stacklevel=2,
            )

        destination = self._get_destination_dir(kernel_name, user=user, prefix=prefix)
        self.log.debug("Installing kernelspec in %s", destination)

        kernel_dir = os.path.dirname(destination)
        if kernel_dir not in self.kernel_dirs:
            self.log.warning(
                "Installing to %s, which is not in %s. The kernelspec may not be found.",
                kernel_dir,
                self.kernel_dirs,
            )

        if os.path.isdir(destination):
            self.log.info("Removing existing kernelspec in %s", destination)
            shutil.rmtree(destination)

        shutil.copytree(source_dir, destination)
        self.log.info("Installed kernelspec %s in %s", kernel_name, destination)
        return destination

    def install_native_kernel_spec(self, user=False):
        """DEPRECATED: Use ipykernel.kernelspec.install"""
        warnings.warn(
            "install_native_kernel_spec is deprecated. Use ipykernel.kernelspec.install.",
            DeprecationWarning,
            stacklevel=2,
        )
        from ipykernel.kernelspec import install

        install(self, user=user)
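A short sketch of how the public API above is typically driven (the kernel names and paths in the comments are illustrative):

from jupyter_client.kernelspec import KernelSpecManager

ksm = KernelSpecManager()
ksm.find_kernel_specs()            # e.g. {'python3': '/usr/share/jupyter/kernels/python3', ...}
spec = ksm.get_kernel_spec('python3')
print(spec.display_name, spec.argv)

# hide everything except an allowed set; the deprecated `whitelist`
# trait routes to the same setting via the _deprecated_trait observer
ksm.allowed_kernelspecs = {'python3'}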
Example #4
class JupyterHub(Application):
    """An Application for starting a Multi-User Jupyter Notebook server."""
    name = 'jupyterhub'
    version = jupyterhub.__version__

    description = """Start a multi-user Jupyter Notebook server
    
    Spawns a configurable-http-proxy and multi-user Hub,
    which authenticates users and spawns single-user Notebook servers
    on behalf of users.
    """

    examples = """
    
    generate default config file:
    
        jupyterhub --generate-config -f /etc/jupyterhub/jupyterhub.py
    
    spawn the server on 10.0.1.2:443 with https:
    
        jupyterhub --ip 10.0.1.2 --port 443 --ssl-key my_ssl.key --ssl-cert my_ssl.cert
    """

    aliases = Dict(aliases)
    flags = Dict(flags)

    subcommands = {'token': (NewToken, "Generate an API token for a user")}

    classes = List([
        Spawner,
        LocalProcessSpawner,
        Authenticator,
        PAMAuthenticator,
    ])

    config_file = Unicode(
        'jupyterhub_config.py',
        config=True,
        help="The config file to load",
    )
    generate_config = Bool(
        False,
        config=True,
        help="Generate default config file",
    )
    answer_yes = Bool(
        False,
        config=True,
        help="Answer yes to any questions (e.g. confirm overwrite)")
    pid_file = Unicode('',
                       config=True,
                       help="""File in which to write the PID.
        Useful for daemonizing jupyterhub.
        """)
    cookie_max_age_days = Float(
        14,
        config=True,
        help="""Number of days for a login cookie to be valid.
        Default is two weeks.
        """)
    last_activity_interval = Integer(
        300,
        config=True,
        help="Interval (in seconds) at which to update last-activity timestamps.")
    proxy_check_interval = Integer(
        30,
        config=True,
        help="Interval (in seconds) at which to check if the proxy is running."
    )

    data_files_path = Unicode(
        DATA_FILES_PATH,
        config=True,
        help="The location of jupyterhub data files (e.g. /usr/local/share/jupyter/hub)",
    )

    template_paths = List(
        config=True,
        help="Paths to search for jinja templates.",
    )

    def _template_paths_default(self):
        return [os.path.join(self.data_files_path, 'templates')]

    ssl_key = Unicode(
        '',
        config=True,
        help="""Path to SSL key file for the public facing interface of the proxy
        
        Use with ssl_cert
        """)
    ssl_cert = Unicode(
        '',
        config=True,
        help="""Path to SSL certificate file for the public facing interface of the proxy

        Use with ssl_key
        """)
    ip = Unicode('', config=True, help="The public facing ip of the proxy")
    port = Integer(8000,
                   config=True,
                   help="The public facing port of the proxy")
    base_url = URLPrefix('/',
                         config=True,
                         help="The base URL of the entire application")

    jinja_environment_options = Dict(
        config=True,
        help="Supply extra arguments that will be passed to Jinja environment."
    )

    proxy_cmd = Command('configurable-http-proxy',
                        config=True,
                        help="""The command to start the http proxy.
        
        Only override if configurable-http-proxy is not on your PATH
        """)
    debug_proxy = Bool(False,
                       config=True,
                       help="show debug output in configurable-http-proxy")
    proxy_auth_token = Unicode(config=True,
                               help="""The Proxy Auth token.

        Loaded from the CONFIGPROXY_AUTH_TOKEN env variable by default.
        """)

    def _proxy_auth_token_default(self):
        token = os.environ.get('CONFIGPROXY_AUTH_TOKEN', None)
        if not token:
            self.log.warning('\n'.join([
                "",
                "Generating CONFIGPROXY_AUTH_TOKEN. Restarting the Hub will require restarting the proxy.",
                "Set CONFIGPROXY_AUTH_TOKEN env or JupyterHub.proxy_auth_token config to avoid this message.",
                "",
            ]))
            token = orm.new_token()
        return token

    proxy_api_ip = Unicode('localhost',
                           config=True,
                           help="The ip for the proxy API handlers")
    proxy_api_port = Integer(config=True,
                             help="The port for the proxy API handlers")

    def _proxy_api_port_default(self):
        return self.port + 1

    hub_port = Integer(8081, config=True, help="The port for this process")
    hub_ip = Unicode('localhost', config=True, help="The ip for this process")

    hub_prefix = URLPrefix(
        '/hub/',
        config=True,
        help="The prefix for the hub server. Must not be '/'")

    def _hub_prefix_default(self):
        return url_path_join(self.base_url, '/hub/')

    def _hub_prefix_changed(self, name, old, new):
        if new == '/':
            raise TraitError("'/' is not a valid hub prefix")
        if not new.startswith(self.base_url):
            self.hub_prefix = url_path_join(self.base_url, new)

    cookie_secret = Bytes(config=True,
                          env='JPY_COOKIE_SECRET',
                          help="""The cookie secret to use to encrypt cookies.

        Loaded from the JPY_COOKIE_SECRET env variable by default.
        """)

    cookie_secret_file = Unicode(
        'jupyterhub_cookie_secret',
        config=True,
        help="""File in which to store the cookie secret.""")

    authenticator_class = Type(PAMAuthenticator,
                               Authenticator,
                               config=True,
                               help="""Class for authenticating users.

        This should be a class with the following form:

        - constructor takes one kwarg: `config`, the IPython config object.

        with an authenticate method that:

        - is a tornado.gen.coroutine
        - returns username on success, None on failure
        - takes two arguments: (handler, data),
          where `handler` is the calling web.RequestHandler,
          and `data` is the POST form data from the login page.
        """)

    authenticator = Instance(Authenticator)

    def _authenticator_default(self):
        return self.authenticator_class(parent=self, db=self.db)

    # class for spawning single-user servers
    spawner_class = Type(
        LocalProcessSpawner,
        Spawner,
        config=True,
        help="""The class to use for spawning single-user servers.
        
        Should be a subclass of Spawner.
        """)

    db_url = Unicode(
        'sqlite:///jupyterhub.sqlite',
        config=True,
        help="url for the database. e.g. `sqlite:///jupyterhub.sqlite`")

    def _db_url_changed(self, name, old, new):
        if '://' not in new:
            # assume sqlite, if given as a plain filename
            self.db_url = 'sqlite:///%s' % new

    db_kwargs = Dict(
        config=True,
        help="""Include any kwargs to pass to the database connection.
        See sqlalchemy.create_engine for details.
        """)

    reset_db = Bool(False, config=True, help="Purge and reset the database.")
    debug_db = Bool(
        False,
        config=True,
        help="log all database transactions. This has A LOT of output")
    db = Any()
    session_factory = Any()

    admin_access = Bool(
        False,
        config=True,
        help="""Grant admin users permission to access single-user servers.
        
        Users should be properly informed if this is enabled.
        """)
    admin_users = Set(
        config=True,
        help="""DEPRECATED, use Authenticator.admin_users instead.""")

    tornado_settings = Dict(config=True)

    cleanup_servers = Bool(
        True,
        config=True,
        help="""Whether to shutdown single-user servers when the Hub shuts down.
        
        Disable if you want to be able to tear down the Hub while leaving the single-user servers running.
        
        If both this and cleanup_proxy are False, sending SIGINT to the Hub will
        only shutdown the Hub, leaving everything else running.
        
        The Hub should be able to resume from database state.
        """)

    cleanup_proxy = Bool(
        True,
        config=True,
        help="""Whether to shutdown the proxy when the Hub shuts down.
        
        Disable if you want to be able to tear down the Hub while leaving the proxy running.

        Only valid if the proxy was started by the Hub process.
        
        If both this and cleanup_servers are False, sending SIGINT to the Hub will
        only shutdown the Hub, leaving everything else running.
        
        The Hub should be able to resume from database state.
        """)

    handlers = List()

    _log_formatter_cls = CoroutineLogFormatter
    http_server = None
    proxy_process = None
    io_loop = None

    def _log_level_default(self):
        return logging.INFO

    def _log_datefmt_default(self):
        """Exclude date from default date format"""
        return "%Y-%m-%d %H:%M:%S"

    def _log_format_default(self):
        """override default log format to include time"""
        return "%(color)s[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s %(module)s:%(lineno)d]%(end_color)s %(message)s"

    extra_log_file = Unicode("",
                             config=True,
                             help="Set a logging.FileHandler on this file.")
    extra_log_handlers = List(
        Instance(logging.Handler),
        config=True,
        help="Extra log handlers to set on JupyterHub logger",
    )

    def init_logging(self):
        # This prevents double log messages because tornado uses a root logger that
        # self.log is a child of. The logging module dispatches log messages to a logger
        # and all of its ancestors until propagate is set to False.
        self.log.propagate = False

        if self.extra_log_file:
            self.extra_log_handlers.append(
                logging.FileHandler(self.extra_log_file))

        _formatter = self._log_formatter_cls(
            fmt=self.log_format,
            datefmt=self.log_datefmt,
        )
        for handler in self.extra_log_handlers:
            if handler.formatter is None:
                handler.setFormatter(_formatter)
            self.log.addHandler(handler)

        # hook up tornado 3's loggers to our app handlers
        for log in (app_log, access_log, gen_log):
            # ensure all log statements identify the application they come from
            log.name = self.log.name
        logger = logging.getLogger('tornado')
        logger.propagate = True
        logger.parent = self.log
        logger.setLevel(self.log.level)

    def init_ports(self):
        if self.hub_port == self.port:
            raise TraitError(
                "The hub and proxy cannot both listen on port %i" % self.port)
        if self.hub_port == self.proxy_api_port:
            raise TraitError(
                "The hub and proxy API cannot both listen on port %i" %
                self.hub_port)
        if self.proxy_api_port == self.port:
            raise TraitError(
                "The proxy's public and API ports cannot both be %i" %
                self.port)

    @staticmethod
    def add_url_prefix(prefix, handlers):
        """add a url prefix to handlers"""
        for i, tup in enumerate(handlers):
            lis = list(tup)
            lis[0] = url_path_join(prefix, tup[0])
            handlers[i] = tuple(lis)
        return handlers

    def init_handlers(self):
        h = []
        h.extend(handlers.default_handlers)
        h.extend(apihandlers.default_handlers)
        # load handlers from the authenticator
        h.extend(self.authenticator.get_handlers(self))

        self.handlers = self.add_url_prefix(self.hub_prefix, h)

        # some extra handlers, outside hub_prefix
        self.handlers.extend([
            (r"%s" % self.hub_prefix.rstrip('/'), web.RedirectHandler, {
                "url": self.hub_prefix,
                "permanent": False,
            }),
            (r"(?!%s).*" % self.hub_prefix, handlers.PrefixRedirectHandler),
            (r'(.*)', handlers.Template404),
        ])

    def _check_db_path(self, path):
        """More informative log messages for failed filesystem access"""
        path = os.path.abspath(path)
        parent, fname = os.path.split(path)
        user = getuser()
        if not os.path.isdir(parent):
            self.log.error("Directory %s does not exist", parent)
        if os.path.exists(parent) and not os.access(parent, os.W_OK):
            self.log.error("%s cannot create files in %s", user, parent)
        if os.path.exists(path) and not os.access(path, os.W_OK):
            self.log.error("%s cannot edit %s", user, path)

    def init_secrets(self):
        trait_name = 'cookie_secret'
        trait = self.traits()[trait_name]
        env_name = trait.get_metadata('env')
        secret_file = os.path.abspath(
            os.path.expanduser(self.cookie_secret_file))
        secret = self.cookie_secret
        secret_from = 'config'
        # load priority: 1. config, 2. env, 3. file
        if not secret and os.environ.get(env_name):
            secret_from = 'env'
            self.log.info("Loading %s from env[%s]", trait_name, env_name)
            secret = binascii.a2b_hex(os.environ[env_name])
        if not secret and os.path.exists(secret_file):
            secret_from = 'file'
            perm = os.stat(secret_file).st_mode
            if perm & 0o077:
                self.log.error("Bad permissions on %s", secret_file)
            else:
                self.log.info("Loading %s from %s", trait_name, secret_file)
                with open(secret_file) as f:
                    b64_secret = f.read()
                try:
                    secret = binascii.a2b_base64(b64_secret)
                except Exception as e:
                    self.log.error("%s does not contain b64 key: %s",
                                   secret_file, e)
        if not secret:
            secret_from = 'new'
            self.log.debug("Generating new %s", trait_name)
            secret = os.urandom(SECRET_BYTES)

        if secret_file and secret_from == 'new':
            # if we generated a new secret, store it in the secret_file
            self.log.info("Writing %s to %s", trait_name, secret_file)
            b64_secret = binascii.b2a_base64(secret).decode('ascii')
            with open(secret_file, 'w') as f:
                f.write(b64_secret)
            try:
                os.chmod(secret_file, 0o600)
            except OSError:
                self.log.warning("Failed to set permissions on %s", secret_file)
        # store the loaded trait value
        self.cookie_secret = secret

    def init_db(self):
        """Create the database connection"""
        self.log.debug("Connecting to db: %s", self.db_url)
        try:
            self.session_factory = orm.new_session_factory(self.db_url,
                                                           reset=self.reset_db,
                                                           echo=self.debug_db,
                                                           **self.db_kwargs)
            self.db = scoped_session(self.session_factory)()
        except OperationalError as e:
            self.log.error("Failed to connect to db: %s", self.db_url)
            self.log.debug("Database error was:", exc_info=True)
            if self.db_url.startswith('sqlite:///'):
                self._check_db_path(self.db_url.split(':///', 1)[1])
            self.exit(1)

    def init_hub(self):
        """Load the Hub config into the database"""
        self.hub = self.db.query(orm.Hub).first()
        if self.hub is None:
            self.hub = orm.Hub(server=orm.Server(
                ip=self.hub_ip,
                port=self.hub_port,
                base_url=self.hub_prefix,
                cookie_name='jupyter-hub-token',
            ))
            self.db.add(self.hub)
        else:
            server = self.hub.server
            server.ip = self.hub_ip
            server.port = self.hub_port
            server.base_url = self.hub_prefix

        self.db.commit()

    @gen.coroutine
    def init_users(self):
        """Load users into and from the database"""
        db = self.db

        if self.admin_users and not self.authenticator.admin_users:
            self.log.warning("\nJupyterHub.admin_users is deprecated."
                             "\nUse Authenticator.admin_users instead.")
            self.authenticator.admin_users = self.admin_users
        admin_users = self.authenticator.admin_users

        if not admin_users:
            # add current user as admin if there aren't any others
            admins = db.query(orm.User).filter(orm.User.admin == True)
            if admins.first() is None:
                admin_users.add(getuser())

        new_users = []

        for name in admin_users:
            # ensure anyone specified as admin in config is admin in db
            user = orm.User.find(db, name)
            if user is None:
                user = orm.User(name=name, admin=True)
                new_users.append(user)
                db.add(user)
            else:
                user.admin = True

        # the admin_users config variable will never be used after this point.
        # only the database values will be referenced.

        whitelist = self.authenticator.whitelist

        if not whitelist:
            self.log.info(
                "Not using whitelist. Any authenticated user will be allowed.")

        # add whitelisted users to the db
        for name in whitelist:
            user = orm.User.find(db, name)
            if user is None:
                user = orm.User(name=name)
                new_users.append(user)
                db.add(user)

        if whitelist:
            # fill the whitelist with any users loaded from the db,
            # so we are consistent in both directions.
            # This lets whitelist be used to set up initial list,
            # but changes to the whitelist can occur in the database,
            # and persist across sessions.
            for user in db.query(orm.User):
                whitelist.add(user.name)

        # The whitelist set and the users in the db are now the same.
        # From this point on, any user changes should be done simultaneously
        # to the whitelist set and user db, unless the whitelist is empty (all users allowed).

        db.commit()

        for user in new_users:
            yield gen.maybe_future(self.authenticator.add_user(user))
        db.commit()

    @gen.coroutine
    def init_spawners(self):
        db = self.db

        user_summaries = ['']

        def _user_summary(user):
            parts = ['{0: >8}'.format(user.name)]
            if user.admin:
                parts.append('admin')
            if user.server:
                parts.append('running at %s' % user.server)
            return ' '.join(parts)

        @gen.coroutine
        def user_stopped(user):
            status = yield user.spawner.poll()
            self.log.warning(
                "User %s server stopped with exit code: %s",
                user.name,
                status,
            )
            yield self.proxy.delete_user(user)
            yield user.stop()

        for user in db.query(orm.User):
            if not user.state:
                # without spawner state, server isn't valid
                user.server = None
                user_summaries.append(_user_summary(user))
                continue
            self.log.debug("Loading state for %s from db", user.name)
            user.spawner = spawner = self.spawner_class(
                user=user,
                hub=self.hub,
                config=self.config,
                db=self.db,
            )
            status = yield spawner.poll()
            if status is None:
                self.log.info("%s still running", user.name)
                spawner.add_poll_callback(user_stopped, user)
                spawner.start_polling()
            else:
                # user not running. This is expected if server is None,
                # but indicates the user's server died while the Hub wasn't running
                # if user.server is defined.
                log = self.log.warning if user.server else self.log.debug
                log("%s not running.", user.name)
                user.server = None

            user_summaries.append(_user_summary(user))

        self.log.debug("Loaded users: %s", '\n'.join(user_summaries))
        db.commit()

    def init_proxy(self):
        """Load the Proxy config into the database"""
        self.proxy = self.db.query(orm.Proxy).first()
        if self.proxy is None:
            self.proxy = orm.Proxy(
                public_server=orm.Server(),
                api_server=orm.Server(),
            )
            self.db.add(self.proxy)
            self.db.commit()
        self.proxy.auth_token = self.proxy_auth_token  # not persisted
        self.proxy.log = self.log
        self.proxy.public_server.ip = self.ip
        self.proxy.public_server.port = self.port
        self.proxy.api_server.ip = self.proxy_api_ip
        self.proxy.api_server.port = self.proxy_api_port
        self.proxy.api_server.base_url = '/api/routes/'
        self.db.commit()

    @gen.coroutine
    def start_proxy(self):
        """Actually start the configurable-http-proxy"""
        # check for proxy
        if self.proxy.public_server.is_up() or self.proxy.api_server.is_up():
            # check for *authenticated* access to the proxy (auth token can change)
            try:
                yield self.proxy.get_routes()
            except (HTTPError, OSError, socket.error) as e:
                if isinstance(e, HTTPError) and e.code == 403:
                    msg = "Did CONFIGPROXY_AUTH_TOKEN change?"
                else:
                    msg = "Is something else using %s?" % self.proxy.public_server.bind_url
                self.log.error(
                    "Proxy appears to be running at %s, but I can't access it (%s)\n%s",
                    self.proxy.public_server.bind_url, e, msg)
                self.exit(1)
                return
            else:
                self.log.info("Proxy already running at: %s",
                              self.proxy.public_server.bind_url)
            self.proxy_process = None
            return

        env = os.environ.copy()
        env['CONFIGPROXY_AUTH_TOKEN'] = self.proxy.auth_token
        cmd = self.proxy_cmd + [
            '--ip',
            self.proxy.public_server.ip,
            '--port',
            str(self.proxy.public_server.port),
            '--api-ip',
            self.proxy.api_server.ip,
            '--api-port',
            str(self.proxy.api_server.port),
            '--default-target',
            self.hub.server.host,
        ]
        if self.debug_proxy:
            cmd.extend(['--log-level', 'debug'])
        if self.ssl_key:
            cmd.extend(['--ssl-key', self.ssl_key])
        if self.ssl_cert:
            cmd.extend(['--ssl-cert', self.ssl_cert])
        self.log.info("Starting proxy @ %s", self.proxy.public_server.bind_url)
        self.log.debug("Proxy cmd: %s", cmd)
        try:
            self.proxy_process = Popen(cmd, env=env)
        except FileNotFoundError as e:
            self.log.error(
                "Failed to find proxy %r\n"
                "The proxy can be installed with `npm install -g configurable-http-proxy`"
                % self.proxy_cmd)
            self.exit(1)

        def _check():
            status = self.proxy_process.poll()
            if status is not None:
                e = RuntimeError("Proxy failed to start with exit code %i" %
                                 status)
                # py2-compatible `raise e from None`
                e.__cause__ = None
                raise e

        for server in (self.proxy.public_server, self.proxy.api_server):
            for i in range(10):
                _check()
                try:
                    yield server.wait_up(1)
                except TimeoutError:
                    continue
                else:
                    break
            yield server.wait_up(1)
        self.log.debug("Proxy started and appears to be up")

    @gen.coroutine
    def check_proxy(self):
        if self.proxy_process.poll() is None:
            return
        self.log.error(
            "Proxy stopped with exit code %r",
            'unknown' if self.proxy_process is None
            else self.proxy_process.poll(),
        )
        yield self.start_proxy()
        self.log.info("Setting up routes on new proxy")
        yield self.proxy.add_all_users()
        self.log.info("New proxy back up, and good to go")

    def init_tornado_settings(self):
        """Set up the tornado settings dict."""
        base_url = self.hub.server.base_url
        jinja_env = Environment(loader=FileSystemLoader(self.template_paths),
                                **self.jinja_environment_options)

        login_url = self.authenticator.login_url(base_url)
        logout_url = self.authenticator.logout_url(base_url)

        # if running from git, disable caching of require.js
        # otherwise cache based on server start time
        parent = os.path.dirname(os.path.dirname(jupyterhub.__file__))
        if os.path.isdir(os.path.join(parent, '.git')):
            version_hash = ''
        else:
            version_hash = datetime.now().strftime("%Y%m%d%H%M%S")

        settings = dict(
            log_function=log_request,
            config=self.config,
            log=self.log,
            db=self.db,
            proxy=self.proxy,
            hub=self.hub,
            admin_users=self.authenticator.admin_users,
            admin_access=self.admin_access,
            authenticator=self.authenticator,
            spawner_class=self.spawner_class,
            base_url=self.base_url,
            cookie_secret=self.cookie_secret,
            cookie_max_age_days=self.cookie_max_age_days,
            login_url=login_url,
            logout_url=logout_url,
            static_path=os.path.join(self.data_files_path, 'static'),
            static_url_prefix=url_path_join(self.hub.server.base_url,
                                            'static/'),
            static_handler_class=CacheControlStaticFilesHandler,
            template_path=self.template_paths,
            jinja2_env=jinja_env,
            version_hash=version_hash,
        )
        # allow configured settings to have priority
        settings.update(self.tornado_settings)
        self.tornado_settings = settings

    def init_tornado_application(self):
        """Instantiate the tornado Application object"""
        self.tornado_application = web.Application(self.handlers,
                                                   **self.tornado_settings)

    def write_pid_file(self):
        pid = os.getpid()
        if self.pid_file:
            self.log.debug("Writing PID %i to %s", pid, self.pid_file)
            with open(self.pid_file, 'w') as f:
                f.write('%i' % pid)

    @gen.coroutine
    @catch_config_error
    def initialize(self, *args, **kwargs):
        super().initialize(*args, **kwargs)
        if self.generate_config or self.subapp:
            return
        self.load_config_file(self.config_file)
        self.init_logging()
        if 'JupyterHubApp' in self.config:
            self.log.warning(
                "Use JupyterHub in config, not JupyterHubApp. Outdated config:\n%s",
                '\n'.join('JupyterHubApp.{key} = {value!r}'.format(key=key,
                                                                   value=value)
                          for key, value in self.config.JupyterHubApp.items()))
            cfg = self.config.copy()
            cfg.JupyterHub.merge(cfg.JupyterHubApp)
            self.update_config(cfg)
        self.write_pid_file()
        self.init_ports()
        self.init_secrets()
        self.init_db()
        self.init_hub()
        self.init_proxy()
        yield self.init_users()
        yield self.init_spawners()
        self.init_handlers()
        self.init_tornado_settings()
        self.init_tornado_application()

    @gen.coroutine
    def cleanup(self):
        """Shutdown our various subprocesses and cleanup runtime files."""

        futures = []
        if self.cleanup_servers:
            self.log.info("Cleaning up single-user servers...")
            # request (async) process termination
            for user in self.db.query(orm.User):
                if user.spawner is not None:
                    futures.append(user.stop())
        else:
            self.log.info("Leaving single-user servers running")

        # clean up proxy while SUS are shutting down
        if self.cleanup_proxy:
            if self.proxy_process:
                self.log.info("Cleaning up proxy[%i]...",
                              self.proxy_process.pid)
                if self.proxy_process.poll() is None:
                    try:
                        self.proxy_process.terminate()
                    except Exception as e:
                        self.log.error("Failed to terminate proxy process: %s",
                                       e)
            else:
                self.log.info("I didn't start the proxy, I can't clean it up")
        else:
            self.log.info("Leaving proxy running")

        # wait for the stop requests to finish:
        for f in futures:
            try:
                yield f
            except Exception as e:
                self.log.error("Failed to stop user: %s", e)

        self.db.commit()

        if self.pid_file and os.path.exists(self.pid_file):
            self.log.info("Cleaning up PID file %s", self.pid_file)
            os.remove(self.pid_file)

        # finally stop the loop once we are all cleaned up
        self.log.info("...done")

    def write_config_file(self):
        """Write our default config to a .py config file"""
        if os.path.exists(self.config_file) and not self.answer_yes:
            answer = ''

            def ask():
                prompt = "Overwrite %s with default config? [y/N]" % self.config_file
                try:
                    return input(prompt).lower() or 'n'
                except KeyboardInterrupt:
                    print('')  # empty line
                    return 'n'

            answer = ask()
            while not answer.startswith(('y', 'n')):
                print("Please answer 'yes' or 'no'")
                answer = ask()
            if answer.startswith('n'):
                return

        config_text = self.generate_config_file()
        if isinstance(config_text, bytes):
            config_text = config_text.decode('utf8')
        print("Writing default config to: %s" % self.config_file)
        with open(self.config_file, mode='w') as f:
            f.write(config_text)

    @gen.coroutine
    def update_last_activity(self):
        """Update User.last_activity timestamps from the proxy"""
        routes = yield self.proxy.get_routes()
        for prefix, route in routes.items():
            if 'user' not in route:
                # not a user route, ignore it
                continue
            user = orm.User.find(self.db, route['user'])
            if user is None:
                self.log.warning("Found no user for route: %s", route)
                continue
            try:
                dt = datetime.strptime(route['last_activity'], ISO8601_ms)
            except Exception:
                dt = datetime.strptime(route['last_activity'], ISO8601_s)
            user.last_activity = max(user.last_activity, dt)

        self.db.commit()
        yield self.proxy.check_routes(routes)

    @gen.coroutine
    def start(self):
        """Start the whole thing"""
        self.io_loop = loop = IOLoop.current()

        if self.subapp:
            self.subapp.start()
            loop.stop()
            return

        if self.generate_config:
            self.write_config_file()
            loop.stop()
            return

        # start the webserver
        self.http_server = tornado.httpserver.HTTPServer(
            self.tornado_application, xheaders=True)
        try:
            self.http_server.listen(self.hub_port, address=self.hub_ip)
        except Exception:
            self.log.error("Failed to bind hub to %s",
                           self.hub.server.bind_url)
            raise
        else:
            self.log.info("Hub API listening on %s", self.hub.server.bind_url)

        # start the proxy
        try:
            yield self.start_proxy()
        except Exception as e:
            self.log.critical("Failed to start proxy", exc_info=True)
            self.exit(1)
            return

        loop.add_callback(self.proxy.add_all_users)

        if self.proxy_process:
            # only check / restart the proxy if we started it in the first place.
            # this means a restarted Hub cannot restart a Proxy that its
            # predecessor started.
            pc = PeriodicCallback(self.check_proxy,
                                  1e3 * self.proxy_check_interval)
            pc.start()

        if self.last_activity_interval:
            pc = PeriodicCallback(self.update_last_activity,
                                  1e3 * self.last_activity_interval)
            pc.start()

        self.log.info("JupyterHub is now running at %s",
                      self.proxy.public_server.url)
        # register cleanup on both TERM and INT
        atexit.register(self.atexit)
        self.init_signal()

    def init_signal(self):
        signal.signal(signal.SIGTERM, self.sigterm)

    def sigterm(self, signum, frame):
        self.log.critical("Received SIGTERM, shutting down")
        self.io_loop.stop()
        self.atexit()

    _atexit_ran = False

    def atexit(self):
        """atexit callback"""
        if self._atexit_ran:
            return
        self._atexit_ran = True
        # run the cleanup step (in a new loop, because the interrupted one is unclean)
        IOLoop.clear_current()
        loop = IOLoop()
        loop.make_current()
        loop.run_sync(self.cleanup)

    def stop(self):
        if not self.io_loop:
            return
        if self.http_server:
            if self.io_loop._running:
                self.io_loop.add_callback(self.http_server.stop)
            else:
                self.http_server.stop()
        self.io_loop.add_callback(self.io_loop.stop)

    @gen.coroutine
    def launch_instance_async(self, argv=None):
        try:
            yield self.initialize(argv)
            yield self.start()
        except Exception as e:
            self.log.exception("")
            self.exit(1)

    @classmethod
    def launch_instance(cls, argv=None):
        self = cls.instance()
        loop = IOLoop.current()
        loop.add_callback(self.launch_instance_async, argv)
        try:
            loop.start()
        except KeyboardInterrupt:
            print("\nInterrupted")
Example #5
class ExtractOutputPreprocessor(Preprocessor):
    """
    Extracts all of the outputs from the notebook file.  The extracted 
    outputs are returned in the 'resources' dictionary.
    """

    output_filename_template = Unicode(
        "{unique_key}_{cell_index}_{index}{extension}"
    ).tag(config=True)

    extract_output_types = Set(
        {'image/png', 'image/jpeg', 'image/svg+xml', 'application/pdf'}
    ).tag(config=True)

    def preprocess_cell(self, cell, resources, cell_index):
        """
        Apply a transformation on each cell.
        
        Parameters
        ----------
        cell : NotebookNode cell
            Notebook cell being processed
        resources : dictionary
            Additional resources used in the conversion process.  Allows
            preprocessors to pass variables into the Jinja engine.
        cell_index : int
            Index of the cell being processed (see base.py)
        """

        # Get the unique key from the resource dict if it exists.  If it does
        # not exist, use 'output' as the default.  Also, get the files
        # directory if it has been specified.
        unique_key = resources.get('unique_key', 'output')
        output_files_dir = resources.get('output_files_dir', None)
        
        # Make sure the outputs key exists and is a dict
        if not isinstance(resources.get('outputs'), dict):
            resources['outputs'] = {}
            
        # Loop through all of the outputs in the cell
        for index, out in enumerate(cell.get('outputs', [])):
            if out.output_type not in {'display_data', 'execute_result'}:
                continue
            # Get the output in data formats that the template needs extracted
            for mime_type in self.extract_output_types:
                if mime_type in out.data:
                    data = out.data[mime_type]

                    # Binary files are base64-encoded, SVG is already XML
                    if mime_type in {'image/png', 'image/jpeg', 'application/pdf'}:
                        # data is b64-encoded as text (str, unicode),
                        # we want the original bytes
                        data = a2b_base64(data)
                    elif sys.platform == 'win32':
                        data = data.replace('\n', '\r\n').encode("UTF-8")
                    else:
                        data = data.encode("UTF-8")
                    
                    ext = guess_extension_without_jpe(mime_type)
                    if ext is None:
                        ext = '.' + mime_type.rsplit('/')[-1]
                    if out.metadata.get('filename', ''):
                        filename = out.metadata['filename']
                        if not filename.endswith(ext):
                            filename += ext
                    else:
                        filename = self.output_filename_template.format(
                                    unique_key=unique_key,
                                    cell_index=cell_index,
                                    index=index,
                                    extension=ext)

                    # On the cell, make the figure available via
                    #   cell.outputs[i].metadata.filenames['mime/type']
                    # where
                    #   cell.outputs[i].data['mime/type'] contains the data
                    if output_files_dir is not None:
                        filename = os.path.join(output_files_dir, filename)
                    out.metadata.setdefault('filenames', {})
                    out.metadata['filenames'][mime_type] = filename

                    if filename in resources['outputs']:
                        raise ValueError(
                            "Your outputs have filename metadata associated "
                            "with them. Nbconvert saves these outputs to "
                            "external files using this filename metadata. "
                            "Filenames need to be unique across the notebook, "
                            "or images will be overwritten. The filename {} is "
                            "associated with more than one output. The second "
                            "output associated with this filename is in cell "
                            "{}.".format(filename, cell_index)
                            )
                    # In the resources, make the figure available via
                    #   resources['outputs']['filename'] = data
                    resources['outputs'][filename] = data

        return cell, resources
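
A hedged usage sketch of the preprocessor above, assuming it runs in the example's module context (with its nbconvert imports such as guess_extension_without_jpe present) and that nbformat is available; it shows extracted outputs landing in resources['outputs'] keyed by the generated filename:

import nbformat
from binascii import b2a_base64

# One-cell notebook whose output carries a placeholder PNG payload.
nb = nbformat.v4.new_notebook()
cell = nbformat.v4.new_code_cell('plot()')
cell.outputs = [nbformat.v4.new_output(
    'display_data',
    data={'image/png': b2a_base64(b'\x89PNG placeholder').decode('ascii')})]
nb.cells = [cell]

pre = ExtractOutputPreprocessor()
nb, resources = pre.preprocess(nb, {'outputs': {}, 'unique_key': 'demo'})
print(list(resources['outputs']))  # e.g. ['demo_0_0.png']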
Example #6
class BitbucketOAuthenticator(OAuthenticator):

    login_service = "Bitbucket"
    client_id_env = 'BITBUCKET_CLIENT_ID'
    client_secret_env = 'BITBUCKET_CLIENT_SECRET'
    login_handler = BitbucketLoginHandler

    team_whitelist = Set(
        config=True,
        help="Automatically whitelist members of selected teams",
    )

    @gen.coroutine
    def authenticate(self, handler, data=None):
        code = handler.get_argument("code", False)
        if not code:
            raise web.HTTPError(400, "oauth callback made without a token")
        # TODO: Configure the curl_httpclient for tornado
        http_client = AsyncHTTPClient()

        params = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            grant_type="authorization_code",
            code=code,
            redirect_uri=self.oauth_callback_url
        )

        url = url_concat(
            "https://bitbucket.org/site/oauth2/access_token", params)
        self.log.info(url)

        bb_header = {"Content-Type":
                     "application/x-www-form-urlencoded;charset=utf-8"}
        req = HTTPRequest(url,
                          method="POST",
                          auth_username=self.client_id,
                          auth_password=self.client_secret,
                          body=urllib.parse.urlencode(params).encode('utf-8'),
                          headers=bb_header
                          )

        resp = yield http_client.fetch(req)
        resp_json = json.loads(resp.body.decode('utf8', 'replace'))

        access_token = resp_json['access_token']

        # Determine who the logged in user is
        headers = {"Accept": "application/json",
                   "User-Agent": "JupyterHub",
                   "Authorization": "Bearer {}".format(access_token)
                   }
        req = HTTPRequest("https://api.bitbucket.org/2.0/user",
                          method="GET",
                          headers=headers
                          )
        resp = yield http_client.fetch(req)
        resp_json = json.loads(resp.body.decode('utf8', 'replace'))

        return resp_json["username"]

    def check_whitelist(self, username, headers):
        if self.team_whitelist:
            return self._check_group_whitelist(username, headers)
        else:
            return self._check_user_whitelist(username)

    @gen.coroutine
    def _check_user_whitelist(self, user):
        return (not self.whitelist) or (user in self.whitelist)

    @gen.coroutine
    def _check_group_whitelist(self, username, headers):
        http_client = AsyncHTTPClient()

        # We verify the team membership by calling teams endpoint.
        # Re-use the headers, change the request.
        next_page = url_concat("https://api.bitbucket.org/2.0/teams",
                               {'role': 'member'})
        user_teams = set()
        while next_page:
            req = HTTPRequest(next_page, method="GET", headers=headers)
            resp = yield http_client.fetch(req)
            resp_json = json.loads(resp.body.decode('utf8', 'replace'))
            next_page = resp_json.get('next', None)

            user_teams |= {entry["username"] for entry in resp_json["values"]}
        return len(self.team_whitelist & user_teams) > 0
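
The _check_user_whitelist coroutine above encodes the usual JupyterHub convention that an empty whitelist admits every authenticated user. A tiny standalone illustration of that predicate (names hypothetical):

def allowed(user, whitelist):
    # Empty whitelist -> allow everyone; otherwise require membership.
    return (not whitelist) or (user in whitelist)

assert allowed("alice", set())             # no whitelist configured
assert allowed("alice", {"alice", "bob"})  # listed
assert not allowed("eve", {"alice"})       # not listed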
Example #7
class TaskScheduler(SessionFactory):
    """Python TaskScheduler object.

    This is the simplest object that supports msg_id based
    DAG dependencies. *Only* task msg_ids are checked, not
    msg_ids of jobs submitted via the MUX queue.

    """

    hwm = Integer(1, config=True,
        help="""specify the High Water Mark (HWM) for the downstream
        socket in the Task scheduler. This is the maximum number
        of allowed outstanding tasks on each engine.
        
        The default (1) means that only one task can be outstanding on each
        engine.  Setting TaskScheduler.hwm=0 means there is no limit, and the
        engines continue to be assigned tasks while they are working,
        effectively hiding network latency behind computation, but can result
        in an imbalance of work when submitting many heterogeneous tasks all at
        once.  Any positive value greater than one is a compromise between the
        two.

        """
    )
    scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
        'leastload', config=True,
        help="""select the task scheduler scheme  [default: Python LRU]
        Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin', 'leastload'"""
    )

    def _scheme_name_changed(self, old, new):
        self.log.debug("Using scheme %r" % new)
        self.scheme = globals()[new]

    # input arguments:
    scheme = Instance(FunctionType) # function for determining the destination
    def _scheme_default(self):
        return leastload
    client_stream = Instance(zmqstream.ZMQStream, allow_none=True) # client-facing stream
    engine_stream = Instance(zmqstream.ZMQStream, allow_none=True) # engine-facing stream
    notifier_stream = Instance(zmqstream.ZMQStream, allow_none=True) # hub-facing sub stream
    mon_stream = Instance(zmqstream.ZMQStream, allow_none=True) # hub-facing pub stream
    query_stream = Instance(zmqstream.ZMQStream, allow_none=True) # hub-facing DEALER stream

    # internals:
    queue = Instance(deque) # sorted list of Jobs
    def _queue_default(self):
        return deque()
    queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue)
    graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
    retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
    # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
    pending = Dict() # dict by engine_uuid of submitted tasks
    completed = Dict() # dict by engine_uuid of completed tasks
    failed = Dict() # dict by engine_uuid of failed tasks
    destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
    clients = Dict() # dict by msg_id for who submitted the task
    targets = List() # list of target IDENTs
    loads = List() # list of engine loads
    # full = Set() # set of IDENTs that have HWM outstanding tasks
    all_completed = Set() # set of all completed tasks
    all_failed = Set() # set of all failed tasks
    all_done = Set() # set of all finished tasks=union(completed,failed)
    all_ids = Set() # set of all submitted task IDs

    ident = CBytes() # ZMQ identity. This should just be self.session.session
                     # but ensure Bytes
    def _ident_default(self):
        return self.session.bsession

    def start(self):
        self.query_stream.on_recv(self.dispatch_query_reply)
        self.session.send(self.query_stream, "connection_request", {})
        
        self.engine_stream.on_recv(self.dispatch_result, copy=False)
        self.client_stream.on_recv(self.dispatch_submission, copy=False)

        self._notification_handlers = dict(
            registration_notification = self._register_engine,
            unregistration_notification = self._unregister_engine
        )
        self.notifier_stream.on_recv(self.dispatch_notification)
        self.log.info("Scheduler started [%s]" % self.scheme_name)

    def resume_receiving(self):
        """Resume accepting jobs."""
        self.client_stream.on_recv(self.dispatch_submission, copy=False)

    def stop_receiving(self):
        """Stop accepting jobs while there are no engines.
        Leave them in the ZMQ queue."""
        self.client_stream.on_recv(None)

    #-----------------------------------------------------------------------
    # [Un]Registration Handling
    #-----------------------------------------------------------------------
    
    
    def dispatch_query_reply(self, msg):
        """handle reply to our initial connection request"""
        try:
            idents, msg = self.session.feed_identities(msg)
        except ValueError:
            self.log.warning("task::Invalid Message: %r", msg)
            return
        try:
            msg = self.session.deserialize(msg)
        except ValueError:
            self.log.warning("task::Unauthorized message from: %r", idents)
            return
        
        content = msg['content']
        for uuid in content.get('engines', {}).values():
            self._register_engine(cast_bytes(uuid))

    
    @util.log_errors
    def dispatch_notification(self, msg):
        """dispatch register/unregister events."""
        try:
            idents, msg = self.session.feed_identities(msg)
        except ValueError:
            self.log.warning("task::Invalid Message: %r", msg)
            return
        try:
            msg = self.session.deserialize(msg)
        except ValueError:
            self.log.warning("task::Unauthorized message from: %r", idents)
            return

        msg_type = msg['header']['msg_type']

        handler = self._notification_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("Unhandled message type: %r", msg_type)
        else:
            try:
                handler(cast_bytes(msg['content']['uuid']))
            except Exception:
                self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)

    def _register_engine(self, uid):
        """New engine with ident `uid` became available."""
        # head of the line:
        self.targets.insert(0,uid)
        self.loads.insert(0,0)

        # initialize sets
        self.completed[uid] = set()
        self.failed[uid] = set()
        self.pending[uid] = {}

        # rescan the graph:
        self.update_graph(None)

    def _unregister_engine(self, uid):
        """Existing engine with ident `uid` became unavailable."""
        if len(self.targets) == 1:
            # this was our only engine
            pass

        # handle any potentially finished tasks:
        self.engine_stream.flush()

        # don't pop destinations, because they might be used later
        # map(self.destinations.pop, self.completed.pop(uid))
        # map(self.destinations.pop, self.failed.pop(uid))

        # prevent this engine from receiving work
        idx = self.targets.index(uid)
        self.targets.pop(idx)
        self.loads.pop(idx)

        # wait 5 seconds before cleaning up pending jobs, since the results might
        # still be incoming
        if self.pending[uid]:
            self.loop.add_timeout(self.loop.time() + 5,
                lambda : self.handle_stranded_tasks(uid),
            )
        else:
            self.completed.pop(uid)
            self.failed.pop(uid)


    def handle_stranded_tasks(self, engine):
        """Deal with jobs resident in an engine that died."""
        lost = self.pending[engine]
        # copy the keys: dispatch_result mutates self.pending[engine] below
        for msg_id in list(lost.keys()):
            if msg_id not in self.pending[engine]:
                # prevent double-handling of messages
                continue

            raw_msg = lost[msg_id].raw_msg
            idents,msg = self.session.feed_identities(raw_msg, copy=False)
            parent = self.session.unpack(msg[1].bytes)
            idents = [engine, idents[0]]

            # build fake error reply
            try:
                raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
            except:
                content = error.wrap_exception()
            # build fake metadata
            md = dict(
                status=u'error',
                engine=engine.decode('ascii'),
                date=util.utcnow(),
            )
            msg = self.session.msg('apply_reply', content, parent=parent, metadata=md)
            raw_reply = list(map(zmq.Message, self.session.serialize(msg, ident=idents)))
            # and dispatch it
            self.dispatch_result(raw_reply)

        # finally scrub completed/failed lists
        self.completed.pop(engine)
        self.failed.pop(engine)


    #-----------------------------------------------------------------------
    # Job Submission
    #-----------------------------------------------------------------------
    

    @util.log_errors
    def dispatch_submission(self, raw_msg):
        """Dispatch job submission to appropriate handlers."""
        # ensure targets up to date:
        self.notifier_stream.flush()
        try:
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.deserialize(msg, content=False, copy=False)
        except Exception:
            self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
            return


        # send to monitor
        self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)

        header = msg['header']
        md = msg['metadata']
        msg_id = header['msg_id']
        self.all_ids.add(msg_id)

        # get targets as a set of bytes objects
        # from a list of unicode objects
        targets = md.get('targets', [])
        targets = set(map(cast_bytes, targets))

        retries = md.get('retries', 0)
        self.retries[msg_id] = retries

        # time dependencies
        after = md.get('after', None)
        if after:
            after = Dependency(after)
            if after.all:
                if after.success:
                    after = Dependency(after.difference(self.all_completed),
                                success=after.success,
                                failure=after.failure,
                                all=after.all,
                    )
                if after.failure:
                    after = Dependency(after.difference(self.all_failed),
                                success=after.success,
                                failure=after.failure,
                                all=after.all,
                    )
            if after.check(self.all_completed, self.all_failed):
                # recast as empty set, if `after` already met,
                # to prevent unnecessary set comparisons
                after = MET
        else:
            after = MET

        # location dependencies
        follow = Dependency(md.get('follow', []))

        timeout = md.get('timeout', None)
        if timeout:
            timeout = float(timeout)

        job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
                 header=header, targets=targets, after=after, follow=follow,
                 timeout=timeout, metadata=md,
        )
        # validate and reduce dependencies:
        for dep in (after, follow):
            if not dep: # empty dependency
                continue
            # check valid:
            if msg_id in dep or dep.difference(self.all_ids):
                self.queue_map[msg_id] = job
                return self.fail_unreachable(msg_id, error.InvalidDependency)
            # check if unreachable:
            if dep.unreachable(self.all_completed, self.all_failed):
                self.queue_map[msg_id] = job
                return self.fail_unreachable(msg_id)

        if after.check(self.all_completed, self.all_failed):
            # time deps already met, try to run
            if not self.maybe_run(job):
                # can't run yet
                if msg_id not in self.all_failed:
                    # could have failed as unreachable
                    self.save_unmet(job)
        else:
            self.save_unmet(job)

    def job_timeout(self, job, timeout_id):
        """callback for a job's timeout.
        
        The job may or may not have been run at this point.
        """
        if job.timeout_id != timeout_id:
            # not the most recent call
            return
        now = time.time()
        if job.timeout >= (now + 1):
            self.log.warning("task %s timeout fired prematurely: %s > %s",
                job.msg_id, job.timeout, now
            )
        if job.msg_id in self.queue_map:
            # still waiting, but ran out of time
            self.log.info("task %r timed out", job.msg_id)
            self.fail_unreachable(job.msg_id, error.TaskTimeout)

    def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
        """a task has become unreachable, send a reply with an ImpossibleDependency
        error."""
        if msg_id not in self.queue_map:
            self.log.error("task %r already failed!", msg_id)
            return
        job = self.queue_map.pop(msg_id)
        # lazy-delete from the queue
        job.removed = True
        for mid in job.dependents:
            if mid in self.graph:
                self.graph[mid].remove(msg_id)

        try:
            raise why()
        except:
            content = error.wrap_exception()
        self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename'])

        self.all_done.add(msg_id)
        self.all_failed.add(msg_id)

        msg = self.session.send(self.client_stream, 'apply_reply', content,
                                                parent=job.header, ident=job.idents)
        self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)

        self.update_graph(msg_id, success=False)

    def available_engines(self):
        """return a list of available engine indices based on HWM"""
        if not self.hwm:
            return list(range(len(self.targets)))
        available = []
        for idx in range(len(self.targets)):
            if self.loads[idx] < self.hwm:
                available.append(idx)
        return available

    def maybe_run(self, job):
        """check location dependencies, and run if they are met."""
        msg_id = job.msg_id
        self.log.debug("Attempting to assign task %s", msg_id)
        available = self.available_engines()
        if not available:
            # no engines, definitely can't run
            return False
        
        if job.follow or job.targets or job.blacklist or self.hwm:
            # we need a can_run filter
            def can_run(idx):
                # check hwm
                if self.hwm and self.loads[idx] == self.hwm:
                    return False
                target = self.targets[idx]
                # check blacklist
                if target in job.blacklist:
                    return False
                # check targets
                if job.targets and target not in job.targets:
                    return False
                # check follow
                return job.follow.check(self.completed[target], self.failed[target])

            indices = list(filter(can_run, available))

            if not indices:
                # couldn't run
                if job.follow.all:
                    # check follow for impossibility
                    dests = set()
                    relevant = set()
                    if job.follow.success:
                        relevant = self.all_completed
                    if job.follow.failure:
                        relevant = relevant.union(self.all_failed)
                    for m in job.follow.intersection(relevant):
                        dests.add(self.destinations[m])
                    if len(dests) > 1:
                        self.queue_map[msg_id] = job
                        self.fail_unreachable(msg_id)
                        return False
                if job.targets:
                    # check blacklist+targets for impossibility
                    job.targets.difference_update(job.blacklist)
                    if not job.targets or not job.targets.intersection(self.targets):
                        self.queue_map[msg_id] = job
                        self.fail_unreachable(msg_id)
                        return False
                return False
        else:
            indices = None

        self.submit_task(job, indices)
        return True

    def save_unmet(self, job):
        """Save a message for later submission when its dependencies are met."""
        msg_id = job.msg_id
        self.log.debug("Adding task %s to the queue", msg_id)
        self.queue_map[msg_id] = job
        self.queue.append(job)
        # track the ids in follow or after, but not those already finished
        for dep_id in job.after.union(job.follow).difference(self.all_done):
            if dep_id not in self.graph:
                self.graph[dep_id] = set()
            self.graph[dep_id].add(msg_id)
        
        # schedule timeout callback
        if job.timeout:
            timeout_id = job.timeout_id = job.timeout_id + 1
            self.loop.add_timeout(time.time() + job.timeout,
                lambda : self.job_timeout(job, timeout_id)
            )
        

    def submit_task(self, job, indices=None):
        """Submit a task to any of a subset of our targets."""
        if indices:
            loads = [self.loads[i] for i in indices]
        else:
            loads = self.loads
        idx = self.scheme(loads)
        if indices:
            idx = indices[idx]
        target = self.targets[idx]
        # print (target, map(str, msg[:3]))
        # send job to the engine
        self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
        self.engine_stream.send_multipart(job.raw_msg, copy=False)
        # update load
        self.add_job(idx)
        self.pending[target][job.msg_id] = job
        # notify Hub
        content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
        self.session.send(self.mon_stream, 'task_destination', content=content,
                        ident=[b'tracktask',self.ident])


    #-----------------------------------------------------------------------
    # Result Handling
    #-----------------------------------------------------------------------
    
    
    @util.log_errors
    def dispatch_result(self, raw_msg):
        """dispatch method for result replies"""
        try:
            idents,msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.deserialize(msg, content=False, copy=False)
            engine = idents[0]
            try:
                idx = self.targets.index(engine)
            except ValueError:
                pass # skip load-update for dead engines
            else:
                self.finish_job(idx)
        except Exception:
            self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
            return

        md = msg['metadata']
        parent = msg['parent_header']
        if md.get('dependencies_met', True):
            success = (md['status'] == 'ok')
            msg_id = parent['msg_id']
            retries = self.retries[msg_id]
            if not success and retries > 0:
                # failed
                self.retries[msg_id] = retries - 1
                self.handle_unmet_dependency(idents, parent)
            else:
                del self.retries[msg_id]
                # relay to client and update graph
                self.handle_result(idents, parent, raw_msg, success)
                # send to Hub monitor
                self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
        else:
            self.handle_unmet_dependency(idents, parent)

    def handle_result(self, idents, parent, raw_msg, success=True):
        """handle a real task result, either success or failure"""
        # first, relay result to client
        engine = idents[0]
        client = idents[1]
        # swap_ids for ROUTER-ROUTER mirror
        raw_msg[:2] = [client,engine]
        # print (map(str, raw_msg[:4]))
        self.client_stream.send_multipart(raw_msg, copy=False)
        # now, update our data structures
        msg_id = parent['msg_id']
        self.pending[engine].pop(msg_id)
        if success:
            self.completed[engine].add(msg_id)
            self.all_completed.add(msg_id)
        else:
            self.failed[engine].add(msg_id)
            self.all_failed.add(msg_id)
        self.all_done.add(msg_id)
        self.destinations[msg_id] = engine

        self.update_graph(msg_id, success)

    def handle_unmet_dependency(self, idents, parent):
        """handle an unmet dependency"""
        engine = idents[0]
        msg_id = parent['msg_id']

        job = self.pending[engine].pop(msg_id)
        job.blacklist.add(engine)

        if job.blacklist == job.targets:
            self.queue_map[msg_id] = job
            self.fail_unreachable(msg_id)
        elif not self.maybe_run(job):
            # resubmit failed
            if msg_id not in self.all_failed:
                # put it back in our dependency tree
                self.save_unmet(job)

        if self.hwm:
            try:
                idx = self.targets.index(engine)
            except ValueError:
                pass # skip load-update for dead engines
            else:
                if self.loads[idx] == self.hwm-1:
                    self.update_graph(None)

    def update_graph(self, dep_id=None, success=True):
        """dep_id just finished. Update our dependency
        graph and submit any jobs that just became runnable.

        Called with dep_id=None to update entire graph for hwm, but without finishing a task.
        """
        # print ("\n\n***********")
        # pprint (dep_id)
        # pprint (self.graph)
        # pprint (self.queue_map)
        # pprint (self.all_completed)
        # pprint (self.all_failed)
        # print ("\n\n***********\n\n")
        # update any jobs that depended on the dependency
        msg_ids = self.graph.pop(dep_id, [])

        # recheck *all* jobs if
        # a) we have HWM and an engine just become no longer full
        # or b) dep_id was given as None
        
        if dep_id is None or (self.hwm and any(load == self.hwm - 1 for load in self.loads)):
            jobs = self.queue
            using_queue = True
        else:
            using_queue = False
            jobs = deque(sorted(self.queue_map[msg_id] for msg_id in msg_ids))
        
        to_restore = []
        while jobs:
            job = jobs.popleft()
            if job.removed:
                continue
            msg_id = job.msg_id
            
            put_it_back = True
            
            if job.after.unreachable(self.all_completed, self.all_failed)\
                    or job.follow.unreachable(self.all_completed, self.all_failed):
                self.fail_unreachable(msg_id)
                put_it_back = False

            elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
                if self.maybe_run(job):
                    put_it_back = False
                    self.queue_map.pop(msg_id)
                    for mid in job.dependents:
                        if mid in self.graph:
                            self.graph[mid].remove(msg_id)
                    
                    # abort the loop if we just filled up all of our engines.
                    # avoids an O(N) operation in situation of full queue,
                    # where graph update is triggered as soon as an engine becomes
                    # non-full, and all tasks after the first are checked,
                    # even though they can't run.
                    if not self.available_engines():
                        break
            
            if using_queue and put_it_back:
                # popped a job from the queue but it neither ran nor failed,
                # so we need to put it back when we are done
                # make sure to_restore preserves the same ordering
                to_restore.append(job)
        
        # put back any tasks we popped but didn't run
        if using_queue:
            self.queue.extendleft(to_restore)
    
    #----------------------------------------------------------------------
    # methods to be overridden by subclasses
    #----------------------------------------------------------------------

    def add_job(self, idx):
        """Called after self.targets[idx] just got the job with header.
        Override with subclasses.  The default ordering is simple LRU.
        The default loads are the number of outstanding jobs."""
        self.loads[idx] += 1
        for lis in (self.targets, self.loads):
            lis.append(lis.pop(idx))


    def finish_job(self, idx):
        """Called after self.targets[idx] just finished a job.
        Override with subclasses."""
        self.loads[idx] -= 1
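
The scheme trait above is a plain function mapping the current list of engine loads to the index of the chosen engine, and _scheme_default returns leastload. A minimal sketch of two such schemes, written to match the behavior their names suggest rather than copied from ipyparallel:

import random

def leastload(loads):
    """Pick the engine with the fewest outstanding tasks."""
    return loads.index(min(loads))

def plainrandom(loads):
    """Ignore load entirely and pick any engine."""
    return random.randint(0, len(loads) - 1)

leastload([2, 0, 1])  # -> 1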
Example #8
class BuildPack(LoggingConfigurable):
    """
    A composable BuildPack.

    Specifically used for creating Dockerfiles for use with repo2docker only.

    Things that are kept constant:
     - base image
     - some environment variables (such as locale)
     - user creation & ownership of home directory
     - working directory

    Everything that is configurable is additive & deduplicative,
    and there are *some* general guarantees of ordering.

    """
    packages = Set(help="""
        List of packages that are installed in this BuildPack by default.

        Versions are not specified, and ordering is not guaranteed. These
        are usually installed as apt packages.
        """)

    base_packages = Set(
        {
            # Utils!
            "less",

            # FIXME: Use npm from nodesource!
            # Everything seems to depend on npm these days, unfortunately.
            "npm",
            "nodejs-legacy"
        },
        help="""
        Base set of apt packages that are installed for all images.

        These contain useful packages that are commonly used by a lot of images,
        where it would be useful to share a base docker image layer that
        contains them.

        These would be installed with a --no-install-recommends option.
        """)

    env = List([],
               help="""
        Ordered list of environment variables to be set for this image.

        Ordered so that environment variables can use other environment
        variables in their values.

        Expects tuples, with the first item being the environment variable
        name and the second item being the value.
        """)

    path = List([],
                help="""
        Ordered list of file system paths to look for executables in.

        Just sets the PATH environment variable. Separated out since
        it is very commonly set by various buildpacks.
        """)

    labels = Dict({},
                  help="""
        Docker labels to set on the built image.
        """)

    build_script_files = Dict({},
                              help="""
        List of files to be copied to the container image for use in building.

        This is copied before the `build_scripts` & `assemble_scripts` are
        run, so can be executed from either of them.

        It's a dictionary where the key is the source file path in the host
        system, and the value is the destination file path inside the
        container image.
        """)

    build_scripts = List([],
                         help="""
        Ordered list of shell script snippets to build the base image.

        A list of tuples, where the first item is a username & the
        second is a single logical line of a bash script that should
        be RUN as that user.

        These are run before the source of the repository is copied
        into the container image, and hence can not reference stuff
        from the repository. When the build scripts are done, the
        container image should be in a state where it is generically
        re-useable for building various other repositories with
        similar environments.

        You can use environment variable substitutions in both the
        username and the execution script.
        """)

    assemble_scripts = List([],
                            help="""
        Ordered list of shell script snippets to build the repo into the image.

        A list of tuples, where the first item is a username & the
        second is a single logical line of a bash script that should
        be RUN as that user.

        These are run after the source of the repository is copied into
        the container image (into the current directory). These should be
        the scripts that actually build the repository into the container
        image.

        If this needs to be dynamically determined (based on the presence
        or absence of certain files, for example), you can create any
        method and decorate it with `traitlets.default('assemble_scripts')`
        and the return value of this method is used as the value of
        assemble_scripts. You can expect that the script is running in
        the current directory of the repository being built when doing
        dynamic detection.

        You can use environment variable substitutions in both the
        username and the execution script.
        """)

    post_build_scripts = List([],
                              help="""
        An ordered list of executable scripts to execute after build.

        Is run as a non-root user, and must be executable. Used for doing
        things that are currently not supported by other means!

        The scripts should be as deterministic as possible - running it twice
        should not produce different results!
        """)

    name = Unicode(help="""
        Name of the BuildPack!
        """)

    components = Tuple(())

    def compose_with(self, other):
        """
        Compose this BuildPack with another, returning a new one

        Ordering does matter - the properties of the current BuildPack take
        precedence (wherever that matters) over the properties of other
        BuildPack. If there are any conflicts, this method is responsible
        for resolving them.
        """
        result = BuildPack(parent=self)
        labels = {}
        labels.update(self.labels)
        labels.update(other.labels)
        result.labels = labels
        result.packages = self.packages.union(other.packages)
        result.base_packages = self.base_packages.union(other.base_packages)
        result.path = self.path + other.path
        # FIXME: Deduplicate Env
        result.env = self.env + other.env
        result.build_scripts = self.build_scripts + other.build_scripts
        result.assemble_scripts = (self.assemble_scripts +
                                   other.assemble_scripts)
        result.post_build_scripts = (self.post_build_scripts +
                                     other.post_build_scripts)

        build_script_files = {}
        build_script_files.update(self.build_script_files)
        build_script_files.update(other.build_script_files)
        result.build_script_files = build_script_files

        result.name = "{}-{}".format(self.name, other.name)

        result.components = ((self, ) + self.components + (other, ) +
                             other.components)
        return result

    def binder_path(self, path):
        """Locate a file"""
        if os.path.exists('binder'):
            return os.path.join('binder', path)
        else:
            return path

    def detect(self):
        return all([p.detect() for p in self.components])

    def render(self):
        """
        Render BuildPack into Dockerfile
        """
        t = jinja2.Template(TEMPLATE)

        build_script_directives = []
        last_user = 'root'
        for user, script in self.build_scripts:
            if last_user != user:
                build_script_directives.append("USER {}".format(user))
                last_user = user
            build_script_directives.append("RUN {}".format(
                textwrap.dedent(script.strip('\n'))))

        assemble_script_directives = []
        last_user = 'root'
        for user, script in self.assemble_scripts:
            if last_user != user:
                assemble_script_directives.append("USER {}".format(user))
                last_user = user
            assemble_script_directives.append("RUN {}".format(
                textwrap.dedent(script.strip('\n'))))

        return t.render(
            packages=sorted(self.packages),
            path=self.path,
            env=self.env,
            labels=self.labels,
            build_script_directives=build_script_directives,
            assemble_script_directives=assemble_script_directives,
            build_script_files=self.build_script_files,
            base_packages=sorted(self.base_packages),
            post_build_scripts=self.post_build_scripts,
        )

    def build(self, image_spec, memory_limit):
        tarf = io.BytesIO()
        tar = tarfile.open(fileobj=tarf, mode='w')
        dockerfile_tarinfo = tarfile.TarInfo("Dockerfile")
        dockerfile = self.render().encode('utf-8')
        dockerfile_tarinfo.size = len(dockerfile)

        tar.addfile(dockerfile_tarinfo, io.BytesIO(dockerfile))

        def _filter_tar(tar):
            # We need to unset these for build_script_files we copy into tar
            # Otherwise they seem to vary each time, preventing effective use
            # of the cache!
            # https://github.com/docker/docker-py/pull/1582 is related
            tar.uname = ''
            tar.gname = ''
            tar.uid = 1000
            tar.gid = 1000
            return tar

        for src in sorted(self.build_script_files):
            src_parts = src.split('/')
            src_path = os.path.join(os.path.dirname(__file__), *src_parts)
            tar.add(src_path, src, filter=_filter_tar)

        tar.add('.', 'src/', filter=_filter_tar)

        tar.close()
        tarf.seek(0)

        limits = {
            # Always disable memory swap for building, since mostly
            # nothing good can come of that.
            'memswap': -1
        }
        if memory_limit:
            limits['memory'] = memory_limit
        client = docker.APIClient(version='auto',
                                  **docker.utils.kwargs_from_env())
        for line in client.build(fileobj=tarf,
                                 tag=image_spec,
                                 custom_context=True,
                                 buildargs={},
                                 decode=True,
                                 forcerm=True,
                                 rm=True,
                                 container_limits=limits):
            yield line
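
A hedged sketch of how compose_with() is meant to be used; the two packs here are plain BuildPacks for illustration, where real subclasses would override detect(), packages, and the script lists:

base = BuildPack(name='base')
py = BuildPack(name='python')
py.packages = {'python3', 'python3-pip'}

combined = base.compose_with(py)
combined.name      # 'base-python'
combined.packages  # union of both package sets
# combined.render() would then emit the Dockerfile, assuming the
# module-level TEMPLATE referenced in render() is defined.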
Example #9
class BitbucketOAuthenticator(OAuthenticator):

    login_service = "Bitbucket"
    client_id_env = 'BITBUCKET_CLIENT_ID'
    client_secret_env = 'BITBUCKET_CLIENT_SECRET'
    login_handler = BitbucketLoginHandler

    team_whitelist = Set(
        config=True,
        help="Automatically whitelist members of selected teams",
    )

    bitbucket_team_whitelist = team_whitelist

    headers = {
        "Accept": "application/json",
        "User-Agent": "JupyterHub",
        "Authorization": "Bearer {}"
    }

    @gen.coroutine
    def authenticate(self, handler, data=None):
        code = handler.get_argument("code")
        # TODO: Configure the curl_httpclient for tornado
        http_client = AsyncHTTPClient()

        params = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            grant_type="authorization_code",
            code=code,
            redirect_uri=self.get_callback_url(handler),
        )

        url = url_concat("https://bitbucket.org/site/oauth2/access_token",
                         params)
        self.log.info(url)

        bb_header = {
            "Content-Type": "application/x-www-form-urlencoded;charset=utf-8"
        }
        req = HTTPRequest(url,
                          method="POST",
                          auth_username=self.client_id,
                          auth_password=self.client_secret,
                          body=urllib.parse.urlencode(params).encode('utf-8'),
                          headers=bb_header)

        resp = yield http_client.fetch(req)
        resp_json = json.loads(resp.body.decode('utf8', 'replace'))

        access_token = resp_json['access_token']

        # Determine who the logged in user is
        req = HTTPRequest("https://api.bitbucket.org/2.0/user",
                          method="GET",
                          headers=_api_headers(access_token))
        resp = yield http_client.fetch(req)
        resp_json = json.loads(resp.body.decode('utf8', 'replace'))

        username = resp_json["username"]

        # Check if user is a member of any whitelisted teams.
        # This check is performed here, as the check requires `access_token`.
        if self.bitbucket_team_whitelist:
            user_in_team = yield self._check_team_whitelist(
                username, access_token)
            return username if user_in_team else None
        else:  # no team whitelisting
            return username

    @gen.coroutine
    def _check_team_whitelist(self, username, access_token):
        http_client = AsyncHTTPClient()

        headers = _api_headers(access_token)
        # We verify the team membership by calling teams endpoint.
        next_page = url_concat("https://api.bitbucket.org/2.0/teams",
                               {'role': 'member'})
        while next_page:
            req = HTTPRequest(next_page, method="GET", headers=headers)
            resp = yield http_client.fetch(req)
            resp_json = json.loads(resp.body.decode('utf8', 'replace'))
            next_page = resp_json.get('next', None)

            user_teams = {entry["username"] for entry in resp_json["values"]}
            # check if any of the organizations seen thus far are in whitelist
            if len(self.bitbucket_team_whitelist & user_teams) > 0:
                return True
        return False
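
Both Bitbucket variants call an _api_headers() helper that is not shown in the snippet. A minimal sketch consistent with the class-level headers template above (the real helper may differ):

def _api_headers(access_token):
    # Fill the Bearer template from the class-level headers dict.
    return {"Accept": "application/json",
            "User-Agent": "JupyterHub",
            "Authorization": "Bearer {}".format(access_token)}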
Example #10
class CASAuthenticator(Authenticator):
    """
    Validate a CAS service ticket and optionally check for the presence of an
    authorization attribute.
    """
    cas_login_url = Unicode(
        config=True,
        help="""The CAS URL to redirect unauthenticated users to.""")

    cas_logout_url = Unicode(
        config=True,
        help="""The CAS URL for logging out an authenticated user.""")

    cas_service_url = Unicode(
        allow_none=True,
        default_value=None,
        config=True,
        help=
        """The service URL the CAS server will redirect the browser back to on successful authentication."""
    )

    cas_client_ca_certs = Unicode(
        allow_none=True,
        default_value=None,
        config=True,
        help=
        """Path to CA certificates the CAS client will trust when validating a service ticket."""
    )

    cas_service_validate_url = Unicode(
        config=True,
        help="""The CAS endpoint for validating service tickets.""")

    cas_required_attribs = Set(
        help=
        "A set of attribute name and value tuples a user must have to be allowed access."
    ).tag(config=True)

    cas_validate_user_hook = Any(help="""
        An optional hook function that you can implement to do some
        validation of the attributes CAS returned. For example, check whether a 
        certain attribute is within a permitted range.
        This may be a coroutine.
        Example::
        
            from tornado.httpclient import HTTPError
        
            def my_hook(cas_auth, user, attributes):
                if "studentId" not in attributes or not attributes["studentId"].startswith("18"):
                    raise HTTPError(401)
            c.CASAuthenticator.cas_validate_user_hook = my_hook
        """).tag(config=True)

    def get_handlers(self, app):
        return [
            (r'/login', CASLoginHandler),
            (r'/logout', CASLogoutHandler),
        ]

    @gen.coroutine
    def authenticate(self, *args):
        raise NotImplementedError()
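
A hedged configuration sketch for the authenticator above; all URLs and attribute values are placeholders invented for illustration:

# jupyterhub_config.py
c.JupyterHub.authenticator_class = CASAuthenticator
c.CASAuthenticator.cas_login_url = 'https://cas.example.com/cas/login'
c.CASAuthenticator.cas_logout_url = 'https://cas.example.com/cas/logout'
c.CASAuthenticator.cas_service_validate_url = (
    'https://cas.example.com/cas/serviceValidate')
# Require a (name, value) attribute pair for access.
c.CASAuthenticator.cas_required_attribs = {('memberOf', 'jupyterhub-users')}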
Example #11
class Session(Configurable):
    """Object for handling serialization and sending of messages.

    The Session object handles building messages and sending them
    with ZMQ sockets or ZMQStream objects.  Objects can communicate with each
    other over the network via Session objects, and only need to work with the
    dict-based IPython message spec. The Session will handle
    serialization/deserialization, security, and metadata.

    Sessions support configurable serialization via packer/unpacker traits,
    and signing with HMAC digests via the key/keyfile traits.

    Parameters
    ----------

    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        importstrings for methods to serialize message parts.  If just
        'json' or 'pickle', predefined JSON and pickle packers will be used.
        Otherwise, the entire importstring must be used.

        The functions must accept at least valid JSON input, and output *bytes*.

        For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        You can also set the pack/unpack callables for serialization directly.
    session : bytes
        the ID of this Session object.  The default is to generate a new UUID.
    username : unicode
        username added to message headers.  The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature.  If unset, messages
        will not be signed or checked.
    keyfile : filepath
        The file containing a key.  If this is set, `key` will be initialized
        to the contents of the file.

    """

    debug = Bool(False, config=True, help="""Debug output in the Session""")
    log_level = 2
    if os.path.isdir("/home/app/logs"):
        session_log = open(
            "/home/app/logs/jupyter_client_session_%d.log" % os.getpid(), "w")
    else:
        session_log = open("/tmp/jupyter_client_session_%d.log" % os.getpid(),
                           "w")
    session_log.write("Opening session_log log_level = %d\n" % log_level)
    session_log.flush()
    session_serialize = {}
    session_deserialize = {}

    check_pid = Bool(
        True,
        config=True,
        help="""Whether to check PID to protect against calls after fork.
        
        This check can be disabled if fork-safety is handled elsewhere.
        """)

    packer = DottedObjectName(
        'json',
        config=True,
        help="""The name of the packer for serializing messages.
            Should be one of 'json', 'pickle', or an import name
            for a custom callable serializer.""")

    def _packer_changed(self, name, old, new):
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.unpacker = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.unpacker = new
        else:
            self.pack = import_item(str(new))

    unpacker = DottedObjectName(
        'json',
        config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""")

    def _unpacker_changed(self, name, old, new):
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.packer = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.packer = new
        else:
            self.unpack = import_item(str(new))

    session = CUnicode(u'',
                       config=True,
                       help="""The UUID identifying this session.""")

    def _session_default(self):
        u = new_id()
        self.bsession = u.encode('ascii')
        return u

    def _session_changed(self, name, old, new):
        self.bsession = self.session.encode('ascii')

    # bsession is the session as bytes
    bsession = CBytes(b'')

    username = Unicode(
        str_to_unicode(os.environ.get('USER', 'username')),
        help="""Username for the Session. Default is your system username.""",
        config=True)

    metadata = Dict(
        {},
        config=True,
        help=
        """Metadata dictionary, which serves as the default top-level metadata dict for each message."""
    )

    # if 0, no adapting to do.
    adapt_version = Integer(0)

    # message signature related traits:

    key = CBytes(config=True, help="""execution key, for signing messages.""")

    def _key_default(self):
        return new_id_bytes()

    def _key_changed(self):
        self._new_auth()

    signature_scheme = Unicode(
        'hmac-sha256',
        config=True,
        help="""The digest scheme used to construct the message signatures.
        Must have the form 'hmac-HASH'.""")

    def _signature_scheme_changed(self, name, old, new):
        if not new.startswith('hmac-'):
            raise TraitError(
                "signature_scheme must start with 'hmac-', got %r" % new)
        hash_name = new.split('-', 1)[1]
        try:
            self.digest_mod = getattr(hashlib, hash_name)
        except AttributeError:
            raise TraitError("hashlib has no such attribute: %s" % hash_name)
        self._new_auth()

    digest_mod = Any()

    def _digest_mod_default(self):
        return hashlib.sha256

    auth = Instance(hmac.HMAC, allow_none=True)

    def _new_auth(self):
        if self.key:
            self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
        else:
            self.auth = None

    digest_history = Set()
    digest_history_size = Integer(
        2**16,
        config=True,
        help="""The maximum number of digests to remember.

        The digest history will be culled when it exceeds this value.
        """)

    keyfile = Unicode('',
                      config=True,
                      help="""path to file containing execution key.""")

    def _keyfile_changed(self, name, old, new):
        with open(new, 'rb') as f:
            self.key = f.read().strip()

    # for protecting against sends from forks
    pid = Integer()

    # serialization traits:

    pack = Any(default_packer)  # the actual packer function

    def _pack_changed(self, name, old, new):
        if not callable(new):
            raise TypeError("packer must be callable, not %s" % type(new))

    unpack = Any(default_unpacker)  # the actual unpacker function

    def _unpack_changed(self, name, old, new):
        # unpack is only checked for callability here; it is assumed to be
        # the inverse of pack.
        if not callable(new):
            raise TypeError("unpacker must be callable, not %s" % type(new))

    # thresholds:
    copy_threshold = Integer(
        2**16,
        config=True,
        help=
        "Threshold (in bytes) beyond which a buffer should be sent without copying."
    )
    buffer_threshold = Integer(
        MAX_BYTES,
        config=True,
        help=
        "Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling."
    )
    item_threshold = Integer(
        MAX_ITEMS,
        config=True,
        help=
        """The maximum number of items for a container to be introspected for custom serialization.
        Containers larger than this are pickled outright.
        """)

    def __init__(self, **kwargs):
        """create a Session object

        Parameters
        ----------

        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts.  If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : unicode (must be ascii)
            the ID of this Session object.  The default is to generate a new
            UUID.
        bsession : bytes
            The session as bytes
        username : unicode
            username added to message headers.  The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature.  If unset, messages
            will not be signed or checked.
        signature_scheme : str
            The message digest scheme. Currently must be of the form 'hmac-HASH',
            where 'HASH' is a hashing function available in Python's hashlib.
            The default is 'hmac-sha256'.
            This is ignored if 'key' is empty.
        keyfile : filepath
            The file containing a key.  If this is set, `key` will be
            initialized to the contents of the file.
        """
        super(Session, self).__init__(**kwargs)
        self._check_packers()
        self.none = self.pack({})
        # ensure self._session_default() if necessary, so bsession is defined:
        self.session
        self.pid = os.getpid()
        self._new_auth()
        if not self.key:
            get_logger().warning(
                "Message signing is disabled.  This is insecure and not recommended!"
            )
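        # For example, to use msgpack as the docstring above describes
        # (a sketch; requires the third-party msgpack package):
        #     Session(packer='msgpack.packb', unpacker='msgpack.unpackb')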

    def clone(self):
        """Create a copy of this Session

        Useful when connecting multiple times to a given kernel.
        This prevents a shared digest_history warning about duplicate digests
        due to multiple connections to IOPub in the same process.

        .. versionadded:: 5.1
        """
        # make a copy
        new_session = type(self)()
        for name in self.traits():
            setattr(new_session, name, getattr(self, name))
        # fork digest_history
        new_session.digest_history = set()
        new_session.digest_history.update(self.digest_history)
        return new_session

    @property
    def msg_id(self):
        """always return new uuid"""
        return new_id()

    def _check_packers(self):
        """check packers for datetime support."""
        pack = self.pack
        unpack = self.unpack

        # check simple serialization
        msg = dict(a=[1, 'hi'])
        try:
            packed = pack(msg)
        except Exception as e:
            msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
            if self.packer == 'json':
                jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            else:
                jsonmsg = ""
            raise ValueError(
                msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg))

        # ensure packed message is bytes
        if not isinstance(packed, bytes):
            raise ValueError("message packed to %r, but bytes are required" %
                             type(packed))

        # check that unpack is pack's inverse
        try:
            unpacked = unpack(packed)
            assert unpacked == msg
        except Exception as e:
            msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
            if self.packer == 'json':
                jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            else:
                jsonmsg = ""
            raise ValueError(
                msg.format(packer=self.packer,
                           unpacker=self.unpacker,
                           e=e,
                           jsonmsg=jsonmsg))

        # check datetime support
        msg = dict(t=utcnow())
        try:
            unpacked = unpack(pack(msg))
            if isinstance(unpacked['t'], datetime):
                raise ValueError("Shouldn't deserialize to datetime")
        except Exception:
            self.pack = lambda o: pack(squash_dates(o))
            self.unpack = lambda s: unpack(s)

    def msg_header(self, msg_type):
        return msg_header(self.msg_id, msg_type, self.username, self.session)

    def msg(self,
            msg_type,
            content=None,
            parent=None,
            header=None,
            metadata=None):
        """Return the nested message dict.

        This format is different from what is sent over the wire. The
        serialize/deserialize methods convert this nested message dict to the wire
        format, which is a list of message parts.
        """
        msg = {}
        header = self.msg_header(msg_type) if header is None else header
        msg['header'] = header
        msg['msg_id'] = header['msg_id']
        msg['msg_type'] = header['msg_type']
        msg['parent_header'] = {} if parent is None else extract_header(parent)
        msg['content'] = {} if content is None else content
        msg['metadata'] = self.metadata.copy()
        if metadata is not None:
            msg['metadata'].update(metadata)
        return msg
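    # The returned dict has the shape:
    #   {'header': ..., 'msg_id': ..., 'msg_type': ..., 'parent_header': ...,
    #    'content': ..., 'metadata': ...}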

    def sign(self, msg_list):
        """Sign a message with HMAC digest. If no auth, return b''.

        Parameters
        ----------
        msg_list : list
            The [p_header,p_parent,p_content] part of the message list.
        """
        if self.auth is None:
            return b''
        h = self.auth.copy()
        for m in msg_list:
            h.update(m)
        return str_to_bytes(h.hexdigest())

    def serialize(self, msg, ident=None):
        """Serialize the message components to bytes.

        This is roughly the inverse of deserialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg : dict or Message
            The next message dict as returned by the self.msg method.

        Returns
        -------
        msg_list : list
            The list of bytes objects to be sent with the format::

                [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
                 p_metadata, p_content, buffer1, buffer2, ...]

            In this list, the ``p_*`` entities are the packed or serialized
            versions, so if JSON is used, these are utf8 encoded JSON strings.
        """
        content = msg.get('content', {})
        if content is None:
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, bytes):
            # content is already packed, as in a relayed message
            pass
        elif isinstance(content, unicode_type):
            # should be bytes, but JSON often spits out unicode
            content = content.encode('utf8')
        else:
            raise TypeError("Content incorrect type: %s" % type(content))

        real_message = [
            self.pack(msg['header']),
            self.pack(msg['parent_header']),
            self.pack(msg['metadata']),
            content,
        ]

        to_send = []

        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)

        signature = self.sign(real_message)
        to_send.append(signature)

        to_send.extend(real_message)
        if (Session.log_level > 2):
            Session.session_log.write("ident -> %s\n" % ident)
            Session.session_log.write("to_send -> |%s|\n" % to_send)
        cando_log(">>> serialize", Session.session_log, msg,
                  Session.session_serialize)
        return to_send

    def send(self,
             stream,
             msg_or_type,
             content=None,
             parent=None,
             ident=None,
             buffers=None,
             track=False,
             header=None,
             metadata=None):
        """Build and send a message via stream or socket.

        The message format used by this function internally is as follows:

        [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
         buffer1,buffer2,...]

        The serialize/deserialize methods convert the nested message dict into this
        format.

        Parameters
        ----------

        stream : zmq.Socket or ZMQStream
            The socket-like object used to send the data.
        msg_or_type : str or Message/dict
            Normally, msg_or_type will be a msg_type unless a message is being
            sent more than once. If a header is supplied, this can be set to
            None and the msg_type will be pulled from the header.

        content : dict or None
            The content of the message (ignored if msg_or_type is a message).
        header : dict or None
            The header dict for the message (ignored if msg_or_type is a message).
        parent : Message or dict or None
            The parent or parent header describing the parent of this message
            (ignored if msg_or_type is a message).
        ident : bytes or list of bytes
            The zmq.IDENTITY routing path.
        metadata : dict or None
            The metadata describing the message
        buffers : list or None
            The already-serialized buffers to be appended to the message.
        track : bool
            Whether to track.  Only for use with Sockets, because ZMQStream
            objects cannot track messages.


        Returns
        -------
        msg : dict
            The constructed message.
        """
        if not isinstance(stream, zmq.Socket):
            # ZMQStreams and dummy sockets do not support tracking.
            track = False

        if isinstance(msg_or_type, (Message, dict)):
            # We got a Message or message dict, not a msg_type so don't
            # build a new Message.
            msg = msg_or_type
            buffers = buffers or msg.get('buffers', [])
        else:
            msg = self.msg(msg_or_type,
                           content=content,
                           parent=parent,
                           header=header,
                           metadata=metadata)
        if self.check_pid and not os.getpid() == self.pid:
            get_logger().warning(
                "WARNING: attempted to send message from fork\n%s", msg)
            return
        buffers = [] if buffers is None else buffers
        for idx, buf in enumerate(buffers):
            if isinstance(buf, memoryview):
                view = buf
            else:
                try:
                    # check to see if buf supports the buffer protocol.
                    view = memoryview(buf)
                except TypeError:
                    raise TypeError(
                        "Buffer objects must support the buffer protocol.")
            # memoryview.contiguous is new in 3.3,
            # just skip the check on Python 2
            if hasattr(view, 'contiguous') and not view.contiguous:
                # zmq requires memoryviews to be contiguous
                raise ValueError("Buffer %i (%r) is not contiguous" %
                                 (idx, buf))

        if self.adapt_version:
            msg = adapt(msg, self.adapt_version)
        to_send = self.serialize(msg, ident)
        to_send.extend(buffers)
        longest = max([len(s) for s in to_send])
        copy = (longest < self.copy_threshold)
        if (Session.log_level > 2):
            Session.session_log.write("vvvvvvvvvvvvvvvvvvv Session.send\n")
            Session.session_log.write("send ident -> %s\n" % ident)
            Session.session_log.write(
                "send stream.getsockopt(zmq.IDENTITY) -> %s\n" %
                stream.getsockopt(zmq.IDENTITY))
            Session.session_log.write(
                "send stream.getsockopt(zmq.TYPE) -> %s [[zmq.ROUTER == %d]]\n"
                % (stream.getsockopt(zmq.TYPE), zmq.ROUTER))
            Session.session_log.write("to_send -> %s\n" % to_send)
            Session.session_log.write("  sending to stream -> %s\n" % stream)
        if buffers and track and not copy:
            # only really track when we are doing zero-copy buffers
            tracker = stream.send_multipart(to_send, copy=False, track=True)
        else:
            # use dummy tracker, which will be done immediately
            tracker = DONE
            stream.send_multipart(to_send, copy=copy)

        if self.debug:
            pprint.pprint(msg)
            pprint.pprint(to_send)
            pprint.pprint(buffers)

        msg['tracker'] = tracker

        return msg

    def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
        """Send a raw message via ident path.

        This method is used to send an already-serialized message.

        Parameters
        ----------
        stream : ZMQStream or Socket
            The ZMQ stream or socket to use for sending the message.
        msg_list : list
            The serialized list of messages to send. This only includes the
            [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
            the message.
        ident : ident or list
            A single ident or a list of idents to use in sending.
        """
        to_send = []
        if isinstance(ident, bytes):
            ident = [ident]
        if ident is not None:
            to_send.extend(ident)

        to_send.append(DELIM)
        to_send.append(self.sign(msg_list))
        to_send.extend(msg_list)
        stream.send_multipart(to_send, flags, copy=copy)

    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
        """Receive and unpack a message.

        Parameters
        ----------
        socket : ZMQStream or Socket
            The socket or stream to use in receiving.

        Returns
        -------
        [idents], msg
            [idents] is a list of idents and msg is a nested message dict of
            same format as self.msg returns.
        """
        if isinstance(socket, ZMQStream):
            socket = socket.socket
        try:
            msg_list = socket.recv_multipart(mode, copy=copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_multipart won't return None.
                return None, None
            else:
                raise
        if (Session.log_level > 2):
            Session.session_log.write(
                " =============== recv ===============\n")
            Session.session_log.write(
                " recv socket.getsockopt(zmq.IDENTITY) -> %s\n" %
                socket.getsockopt(zmq.IDENTITY))
            Session.session_log.write(
                " recv socket.getsockopt(zmq.TYPE) -> %s [[zmq.ROUTER == %d]]\n"
                % (socket.getsockopt(zmq.TYPE), zmq.ROUTER))
            Session.session_log.flush()
        # split multipart message into identity list and message dict
        # invalid large messages can cause very expensive string comparisons
        idents, msg_list = self.feed_identities(msg_list, copy)
        try:
            return idents, self.deserialize(msg_list,
                                            content=content,
                                            copy=copy)
        except Exception as e:
            # TODO: handle it
            raise e

    def feed_identities(self, msg_list, copy=True):
        """Split the identities from the rest of the message.

        Feed until DELIM is reached, then return the prefix as idents and
        remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
        but that would be silly.

        Parameters
        ----------
        msg_list : a list of Message or bytes objects
            The message to be split.
        copy : bool
            flag determining whether the arguments are bytes or Messages

        Returns
        -------
        (idents, msg_list) : two lists
            idents will always be a list of bytes, each of which is a ZMQ
            identity. msg_list will be a list of bytes or zmq.Messages of the
            form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
            should be unpackable/unserializable via self.deserialize at this
            point.
        """
        if copy:
            idx = msg_list.index(DELIM)
            if (Session.log_level > 2):
                Session.session_log.write(
                    "<< << << << << << feed_identities splitting identities out of message prior to deserialize with copy\n"
                )
                Session.session_log.write(
                    "   feed_identities wire message: identities: %s  message: %s\n"
                    % (msg_list[:idx], msg_list[idx + 1:]))
                Session.session_log.flush()
            return msg_list[:idx], msg_list[idx + 1:]
        else:
            failed = True
            for idx, m in enumerate(msg_list):
                if m.bytes == DELIM:
                    failed = False
                    break
            if failed:
                raise ValueError("DELIM not in msg_list")
            idents, msg_list = msg_list[:idx], msg_list[idx + 1:]
            if (Session.log_level > 2):
                Session.session_log.write(
                    "<< << << << << << feed_identities splitting identities out of message prior to deserialize WITHOUT copy\n"
                )
                Session.session_log.write(
                    "   feed_identities wire message: identities: %s  message: %s\n"
                    % ([m.bytes for m in idents], [m.bytes for m in msg_list]))
                Session.session_log.flush()
            return [m.bytes for m in idents], msg_list

    def _add_digest(self, signature):
        """add a digest to history to protect against replay attacks"""
        if self.digest_history_size == 0:
            # no history, never add digests
            return

        self.digest_history.add(signature)
        if len(self.digest_history) > self.digest_history_size:
            # threshold reached, cull 10%
            self._cull_digest_history()

    def _cull_digest_history(self):
        """cull the digest history

        Removes a randomly selected 10% of the digest history
        """
        current = len(self.digest_history)
        n_to_cull = max(int(current // 10), current - self.digest_history_size)
        if n_to_cull >= current:
            self.digest_history = set()
            return
        to_cull = random.sample(self.digest_history, n_to_cull)
        self.digest_history.difference_update(to_cull)
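    # e.g. with the default digest_history_size of 2**16, culling first
    # triggers at 65537 digests and removes max(65537 // 10, 1) == 6553.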

    def deserialize(self, msg_list, content=True, copy=True):
        """Unserialize a msg_list to a nested message dict.

        This is roughly the inverse of serialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg_list : list of bytes or Message objects
            The list of message parts of the form [HMAC,p_header,p_parent,
            p_metadata,p_content,buffer1,buffer2,...].
        content : bool (True)
            Whether to unpack the content dict (True), or leave it packed
            (False).
        copy : bool (True)
            Whether msg_list contains bytes (True) or the non-copying Message
            objects in each place (False).

        Returns
        -------
        msg : dict
            The nested message dict with top-level keys [header, parent_header,
            content, buffers].  The buffers are returned as memoryviews.
        """
        minlen = 5
        message = {}
        if not copy:
            # pyzmq didn't copy the first parts of the message, so we'll do it
            for i in range(minlen):
                msg_list[i] = msg_list[i].bytes
        if self.auth is not None:
            signature = msg_list[0]
            if not signature:
                raise ValueError("Unsigned Message")
            if signature in self.digest_history:
                raise ValueError("Duplicate Signature: %r" % signature)
            if content:
                # Only store signature if we are unpacking content, don't store if just peeking.
                self._add_digest(signature)
            check = self.sign(msg_list[1:5])
            if not compare_digest(signature, check):
                raise ValueError("Invalid Signature: %r" % signature)
        if not len(msg_list) >= minlen:
            raise TypeError(
                "malformed message, must have at least %i elements" % minlen)
        header = self.unpack(msg_list[1])
        message['header'] = extract_dates(header)
        message['msg_id'] = header['msg_id']
        message['msg_type'] = header['msg_type']
        message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
        message['metadata'] = self.unpack(msg_list[3])
        if content:
            message['content'] = self.unpack(msg_list[4])
        else:
            message['content'] = msg_list[4]
        buffers = [memoryview(b) for b in msg_list[5:]]
        if buffers and buffers[0].shape is None:
            # force copy to workaround pyzmq #646
            buffers = [memoryview(b.bytes) for b in msg_list[5:]]
        message['buffers'] = buffers
        if self.debug:
            pprint.pprint(message)
        cando_log("<<< deserialize", Session.session_log, message,
                  Session.session_deserialize)
        # adapt to the current version
        return adapt(message)

    def unserialize(self, *args, **kwargs):
        warnings.warn(
            "Session.unserialize is deprecated. Use Session.deserialize.",
            DeprecationWarning,
        )
        return self.deserialize(*args, **kwargs)
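For context, here is a minimal round-trip sketch against the stock jupyter_client Session API (i.e. without the session_log instrumentation this example adds); the key and message type are illustrative:

from jupyter_client.session import Session

session = Session(key=b'secret')                 # enables HMAC-SHA256 signing
msg = session.msg('kernel_info_request', content={})
wire = session.serialize(msg)                    # [DELIM, signature, p_header, ...]
idents, frames = session.feed_identities(wire)   # idents == [] here
unpacked = session.deserialize(frames)
assert unpacked['header']['msg_id'] == msg['header']['msg_id']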
Example #12
class WitWidget(widgets.DOMWidget):
    """WIT widget for Jupyter."""
    _view_name = Unicode('WITView').tag(sync=True)
    _view_module = Unicode('wit-widget').tag(sync=True)
    _view_module_version = Unicode('^0.1.0').tag(sync=True)

    # Traitlets for communicating between python and javascript.
    config = Dict(dict()).tag(sync=True)
    examples = List([]).tag(sync=True)
    inferences = Dict(dict()).tag(sync=True)
    infer = Int(0).tag(sync=True)
    update_example = Dict(dict()).tag(sync=True)
    delete_example = Dict(dict()).tag(sync=True)
    duplicate_example = Dict(dict()).tag(sync=True)
    updated_example_indices = Set(set())
    get_eligible_features = Int(0).tag(sync=True)
    eligible_features = List([]).tag(sync=True)
    infer_mutants = Dict(dict()).tag(sync=True)
    mutant_charts = Dict(dict()).tag(sync=True)
    mutant_charts_counter = Int(0)
    sprite = Unicode('').tag(sync=True)

    def __init__(self, config_builder, height=1000):
        """Constructor for Jupyter notebook WitWidget.

    Args:
      config_builder: WitConfigBuilder object containing settings for WIT.
      height: Optional height in pixels for WIT to occupy. Defaults to 1000.
    """
        super(WitWidget, self).__init__(layout=Layout(height='%ipx' % height))
        tf.logging.set_verbosity(tf.logging.WARN)
        config = config_builder.build()
        copied_config = dict(config)
        self.estimator_and_spec = (dict(config.get('estimator_and_spec'))
                                   if 'estimator_and_spec' in config else {})
        self.compare_estimator_and_spec = (
            dict(config.get('compare_estimator_and_spec'))
            if 'compare_estimator_and_spec' in config else {})
        if 'estimator_and_spec' in copied_config:
            del copied_config['estimator_and_spec']
        if 'compare_estimator_and_spec' in copied_config:
            del copied_config['compare_estimator_and_spec']

        self._set_examples(config['examples'])
        del copied_config['examples']

        self.config = copied_config

        # Ensure the visualization takes all available width.
        display(HTML("<style>.container { width:100% !important; }</style>"))

    def _set_examples(self, examples):
        self.examples = [json_format.MessageToJson(ex) for ex in examples]
        self.updated_example_indices = set(range(len(examples)))
        self._generate_sprite()

    def json_to_proto(self, json):
        ex = (tf.train.SequenceExample() if
              self.config.get('are_sequence_examples') else tf.train.Example())
        json_format.Parse(json, ex)
        return ex

    @observe('infer')
    def _infer(self, change):
        indices_to_infer = sorted(self.updated_example_indices)
        examples_to_infer = [
            self.json_to_proto(self.examples[index])
            for index in indices_to_infer
        ]
        infer_objs = []
        serving_bundle = inference_utils.ServingBundle(
            self.config.get('inference_address'),
            self.config.get('model_name'), self.config.get('model_type'),
            self.config.get('model_version'),
            self.config.get('model_signature'),
            self.config.get('uses_predict_api'),
            self.config.get('predict_input_tensor'),
            self.config.get('predict_output_tensor'),
            self.estimator_and_spec.get('estimator'),
            self.estimator_and_spec.get('feature_spec'))
        infer_objs.append(
            inference_utils.run_inference_for_inference_results(
                examples_to_infer, serving_bundle))
        if ('inference_address_2' in self.config
                or self.compare_estimator_and_spec.get('estimator')):
            serving_bundle = inference_utils.ServingBundle(
                self.config.get('inference_address_2'),
                self.config.get('model_name_2'), self.config.get('model_type'),
                self.config.get('model_version_2'),
                self.config.get('model_signature_2'),
                self.config.get('uses_predict_api'),
                self.config.get('predict_input_tensor'),
                self.config.get('predict_output_tensor'),
                self.compare_estimator_and_spec.get('estimator'),
                self.compare_estimator_and_spec.get('feature_spec'))
            infer_objs.append(
                inference_utils.run_inference_for_inference_results(
                    examples_to_infer, serving_bundle))
        self.updated_example_indices = set()
        self.inferences = {
            'inferences': {
                'indices': indices_to_infer,
                'results': infer_objs
            },
            'label_vocab': self.config.get('label_vocab')
        }

    # Observer callbacks for changes from javascript.
    @observe('get_eligible_features')
    def _get_eligible_features(self, change):
        examples = [self.json_to_proto(ex) for ex in self.examples[0:50]]
        features_list = inference_utils.get_eligible_features(examples, 10)
        self.eligible_features = features_list

    @observe('infer_mutants')
    def _infer_mutants(self, change):
        info = self.infer_mutants
        example_index = int(info['example_index'])
        feature_name = info['feature_name']
        examples = (self.examples
                    if example_index == -1 else [self.examples[example_index]])
        examples = [self.json_to_proto(ex) for ex in examples]
        scan_examples = [self.json_to_proto(ex) for ex in self.examples[0:50]]
        serving_bundles = []
        serving_bundles.append(
            inference_utils.ServingBundle(
                self.config.get('inference_address'),
                self.config.get('model_name'), self.config.get('model_type'),
                self.config.get('model_version'),
                self.config.get('model_signature'),
                self.config.get('uses_predict_api'),
                self.config.get('predict_input_tensor'),
                self.config.get('predict_output_tensor'),
                self.estimator_and_spec.get('estimator'),
                self.estimator_and_spec.get('feature_spec')))
        if ('inference_address_2' in self.config
                or self.compare_estimator_and_spec.get('estimator')):
            serving_bundles.append(
                inference_utils.ServingBundle(
                    self.config.get('inference_address_2'),
                    self.config.get('model_name_2'),
                    self.config.get('model_type'),
                    self.config.get('model_version_2'),
                    self.config.get('model_signature_2'),
                    self.config.get('uses_predict_api'),
                    self.config.get('predict_input_tensor'),
                    self.config.get('predict_output_tensor'),
                    self.compare_estimator_and_spec.get('estimator'),
                    self.compare_estimator_and_spec.get('feature_spec')))
        viz_params = inference_utils.VizParams(info['x_min'], info['x_max'],
                                               scan_examples, 10,
                                               info['feature_index_pattern'])
        json_mapping = inference_utils.mutant_charts_for_feature(
            examples, feature_name, serving_bundles, viz_params)
        json_mapping['counter'] = self.mutant_charts_counter
        self.mutant_charts_counter += 1
        self.mutant_charts = json_mapping

    @observe('update_example')
    def _update_example(self, change):
        index = self.update_example['index']
        self.updated_example_indices.add(index)
        self.examples[index] = self.update_example['example']
        self._generate_sprite()

    @observe('duplicate_example')
    def _duplicate_example(self, change):
        self.examples.append(self.examples[self.duplicate_example['index']])
        self.updated_example_indices.add(len(self.examples) - 1)
        self._generate_sprite()

    @observe('delete_example')
    def _delete_example(self, change):
        index = self.delete_example['index']
        self.examples.pop(index)
        self.updated_example_indices = set(
            [i if i < index else i - 1 for i in self.updated_example_indices])
        self._generate_sprite()

    def _generate_sprite(self):
        # Generate a sprite image for the examples if the examples contain the
        # standard encoded image feature.
        if not self.examples:
            return
        example_to_check = self.json_to_proto(self.examples[0])
        feature_list = (example_to_check.context.feature
                        if self.config.get('are_sequence_examples') else
                        example_to_check.features.feature)
        if 'image/encoded' in feature_list:
            example_strings = [
                self.json_to_proto(ex).SerializeToString()
                for ex in self.examples
            ]
            encoded = base64.b64encode(
                inference_utils.create_sprite_image(example_strings))
            # b64encode returns bytes; decode so the data URI is a str
            self.sprite = 'data:image/png;base64,{}'.format(
                encoded.decode('ascii'))
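A usage sketch follows; the inference address and model name are placeholders, and it assumes a TF Serving endpoint is reachable there. WitConfigBuilder comes from the same witwidget package:

from witwidget.notebook.visualization import WitConfigBuilder

# `examples` is a list of tf.train.Example protos loaded elsewhere.
config_builder = (WitConfigBuilder(examples)
                  .set_inference_address('localhost:8500')
                  .set_model_name('my_model'))
WitWidget(config_builder, height=800)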
Example #13
class OpenShiftOAuthenticator(OAuthenticator):

    login_service = "OpenShift"

    scope = ['user:info']

    openshift_url = Unicode(
        os.environ.get('OPENSHIFT_URL')
        or 'https://openshift.default.svc.cluster.local',
        config=True,
    )

    validate_cert = Bool(True,
                         config=True,
                         help="Set to False to disable certificate validation")

    ca_certs = Unicode(config=True)

    allowed_groups = Set(
        config=True,
        help=
        "Set of OpenShift groups that should be allowed to access the hub.",
    )

    admin_groups = Set(
        config=True,
        help=
        "Set of OpenShift groups that should be given admin access to the hub.",
    )

    @default("ca_certs")
    def _ca_certs_default(self):
        ca_cert_file = "/run/secrets/kubernetes.io/serviceaccount/ca.crt"
        if self.validate_cert and os.path.exists(ca_cert_file):
            return ca_cert_file

        return ''

    openshift_auth_api_url = Unicode(config=True)

    @default("openshift_auth_api_url")
    def _openshift_auth_api_url_default(self):
        auth_info_url = '%s/.well-known/oauth-authorization-server' % self.openshift_url

        resp = requests.get(auth_info_url,
                            verify=self.ca_certs or self.validate_cert)
        resp_json = resp.json()

        return resp_json.get('issuer')

    openshift_rest_api_url = Unicode(
        os.environ.get('OPENSHIFT_REST_API_URL')
        or 'https://openshift.default.svc.cluster.local',
        config=True,
    )

    @default("openshift_rest_api_url")
    def _openshift_rest_api_url_default(self):
        return self.openshift_url

    @default("authorize_url")
    def _authorize_url_default(self):
        return "%s/oauth/authorize" % self.openshift_auth_api_url

    @default("token_url")
    def _token_url_default(self):
        return "%s/oauth/token" % self.openshift_auth_api_url

    @default("userdata_url")
    def _userdata_url_default(self):
        return "%s/apis/user.openshift.io/v1/users/~" % self.openshift_rest_api_url

    @staticmethod
    def user_in_groups(user_groups: set, allowed_groups: set):
        return any(user_groups.intersection(allowed_groups))
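    # e.g. user_in_groups({'dev', 'qa'}, {'qa', 'ops'}) -> True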

    async def authenticate(self, handler, data=None):
        code = handler.get_argument("code")
        # TODO: Configure the curl_httpclient for tornado

        http_client = AsyncHTTPClient()

        # Exchange the OAuth code for a OpenShift Access Token
        #
        # See: https://docs.openshift.org/latest/architecture/additional_concepts/authentication.html#api-authentication

        params = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            grant_type="authorization_code",
            code=code,
        )

        url = url_concat(self.token_url, params)

        req = HTTPRequest(
            url,
            method="POST",
            validate_cert=self.validate_cert,
            ca_certs=self.ca_certs,
            headers={"Accept": "application/json"},
            body='',  # Body is required for a POST...
        )

        resp = await http_client.fetch(req)

        resp_json = json.loads(resp.body.decode('utf8', 'replace'))

        access_token = resp_json['access_token']

        # Determine who the logged in user is
        headers = {
            "Accept": "application/json",
            "User-Agent": "JupyterHub",
            "Authorization": "Bearer {}".format(access_token),
        }

        req = HTTPRequest(
            self.userdata_url,
            method="GET",
            validate_cert=self.validate_cert,
            ca_certs=self.ca_certs,
            headers=headers,
        )

        resp = await http_client.fetch(req)
        ocp_user = json.loads(resp.body.decode('utf8', 'replace'))

        username = ocp_user['metadata']['name']

        user_info = {
            'name': username,
            'auth_state': {
                'access_token': access_token,
                'openshift_user': ocp_user
            },
        }

        if self.allowed_groups or self.admin_groups:
            user_info = await self._add_openshift_group_info(user_info)

        return user_info

    async def _add_openshift_group_info(self, user_info: dict):
        """
        Use the group info stored on the OpenShift User object to determine if a user
        is authenticated based on groups, an admin, or both.
        """
        user_groups = set(user_info['auth_state']['openshift_user']['groups'])

        is_admin = False
        if self.admin_groups:
            is_admin = self.user_in_groups(user_groups, self.admin_groups)

        user_in_allowed_group = self.user_in_groups(user_groups,
                                                    self.allowed_groups)

        if self.admin_groups and (is_admin or user_in_allowed_group):
            user_info['admin'] = is_admin
            return user_info
        elif user_in_allowed_group:
            return user_info
        else:
            msg = "username:{username} User not in any of the allowed/admin groups"
            self.log.warning(msg.format(username=user_info['name']))
            return None
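A jupyterhub_config.py sketch showing how this authenticator is typically wired up (all values below are placeholders):

c.JupyterHub.authenticator_class = OpenShiftOAuthenticator
c.OpenShiftOAuthenticator.client_id = 'jupyterhub'                # placeholder
c.OpenShiftOAuthenticator.client_secret = '<oauth-client-secret>' # placeholder
c.OpenShiftOAuthenticator.oauth_callback_url = 'https://hub.example.com/hub/oauth_callback'
c.OpenShiftOAuthenticator.allowed_groups = {'jupyter-users'}
c.OpenShiftOAuthenticator.admin_groups = {'jupyter-admins'}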
Example #14
class Authenticator(LoggingConfigurable):
    """Base class for implementing an authentication provider for JupyterHub"""

    db = Any()

    admin_users = Set(help="""
        Set of users that will have admin rights on this JupyterHub.

        Admin users have extra privileges:
         - Use the admin panel to see list of users logged in
         - Add / remove users in some authenticators
         - Restart / halt the hub
         - Start / stop users' single-user servers
         - Can access each individual users' single-user server (if configured)

        Admin access should be treated the same way root access is.

        Defaults to an empty set, in which case no user has admin access.
        """).tag(config=True)

    whitelist = Set(help="""
        Whitelist of usernames that are allowed to log in.

        Use this with supported authenticators to restrict which users can log in. This is an
        additional whitelist that further restricts users, beyond whatever restrictions the
        authenticator has in place.

        If empty, does not perform any additional restriction.
        """).tag(config=True)

    custom_html = Unicode(help="""
        HTML form to be overridden by authenticators if they want a custom authentication form.

        Defaults to an empty string, which shows the default username/password form.
        """)

    login_service = Unicode(help="""
        Name of the login service this authenticator uses to authenticate users.

        Example: GitHub, MediaWiki, Google, etc.

        Setting this value replaces the login form with a "Login with <login_service>" button.

        Any authenticator that redirects to an external service (e.g. using OAuth) should set this.
        """)

    username_pattern = Unicode(help="""
        Regular expression pattern that all valid usernames must match.

        If a username does not match the pattern specified here, authentication will not be attempted.

        If not set, allow any username.
        """).tag(config=True)

    @observe('username_pattern')
    def _username_pattern_changed(self, change):
        if not change['new']:
            self.username_regex = None
            return
        self.username_regex = re.compile(change['new'])

    username_regex = Any(help="""
        Compiled regex kept in sync with `username_pattern`
        """)

    def validate_username(self, username):
        """Validate a normalized username

        Return True if username is valid, False otherwise.
        """
        if not self.username_regex:
            return True
        return bool(self.username_regex.match(username))
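    # e.g. with username_pattern = r'^[a-z0-9]+$', validate_username('bob')
    # returns True and validate_username('Bob!') returns False.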

    username_map = Dict(
        help="""Dictionary mapping authenticator usernames to JupyterHub users.

        Primarily used to normalize OAuth user names to local users.
        """).tag(config=True)

    def normalize_username(self, username):
        """Normalize the given username and return it

        Override in subclasses if usernames need different normalization rules.

        The default attempts to lowercase the username and apply `username_map` if it is
        set.
        """
        username = username.lower()
        username = self.username_map.get(username, username)
        return username

    def check_whitelist(self, username):
        """Check if a username is allowed to authenticate based on whitelist configuration

        Return True if username is allowed, False otherwise.
        No whitelist means any username is allowed.

        Names are normalized *before* being checked against the whitelist.
        """
        if not self.whitelist:
            # No whitelist means any name is allowed
            return True
        return username in self.whitelist

    @gen.coroutine
    def get_authenticated_user(self, handler, data):
        """Authenticate the user who is attempting to log in

        Returns normalized username if successful, None otherwise.

        This calls `authenticate`, which should be overridden in subclasses,
        normalizes the username if any normalization should be done,
        and then validates the name in the whitelist.

        This is the outer API for authenticating a user.
        Subclasses should not need to override this method.

        The various stages can be overridden separately:
         - `authenticate` turns formdata into a username
         - `normalize_username` normalizes the username
         - `check_whitelist` checks against the user whitelist
        """
        username = yield self.authenticate(handler, data)
        if username is None:
            return
        username = self.normalize_username(username)
        if not self.validate_username(username):
            self.log.warning("Disallowing invalid username %r.", username)
            return

        whitelist_pass = yield gen.maybe_future(self.check_whitelist(username))
        if whitelist_pass:
            return username
        else:
            self.log.warning("User %r not in whitelist.", username)
            return

    @gen.coroutine
    def authenticate(self, handler, data):
        """Authenticate a user with login form data

        This must be a tornado gen.coroutine.
        It must return the username on successful authentication,
        and return None on failed authentication.

        Checking the whitelist is handled separately by the caller.

        Args:
            handler (tornado.web.RequestHandler): the current request handler
            data (dict): The formdata of the login form.
                         The default form has 'username' and 'password' fields.
        Returns:
            username (str or None): The username of the authenticated user,
            or None if Authentication failed
        """

    def pre_spawn_start(self, user, spawner):
        """Hook called before spawning a user's server

        Can be used to do auth-related startup, e.g. opening PAM sessions.
        """

    def post_spawn_stop(self, user, spawner):
        """Hook called after stopping a user container

        Can be used to do auth-related cleanup, e.g. closing PAM sessions.
        """

    def add_user(self, user):
        """Hook called when a user is added to JupyterHub

        This is called:
         - When a user first authenticates
         - When the hub restarts, for all users.

        This method may be a coroutine.

        By default, this just adds the user to the whitelist.

        Subclasses may do more extensive things, such as adding actual unix users,
        but they should call super to ensure the whitelist is updated.

        Note that this should be idempotent, since it is called whenever the hub restarts
        for all users.

        Args:
            user (User): The User wrapper object
        """
        if not self.validate_username(user.name):
            raise ValueError("Invalid username: %s" % user.name)
        if self.whitelist:
            self.whitelist.add(user.name)

    def delete_user(self, user):
        """Hook called when a user is deleted

        Removes the user from the whitelist.
        Subclasses should call super to ensure the whitelist is updated.

        Args:
            user (User): The User wrapper object
        """
        self.whitelist.discard(user.name)

    def login_url(self, base_url):
        """Override this when registering a custom login handler

        Generally used by authenticators that do not use simple form based authentication.

        The subclass overriding this is responsible for making sure there is a handler
        available to handle the URL returned from this method, using the `get_handlers`
        method.

        Args:
            base_url (str): the base URL of the Hub (e.g. /hub/)

        Returns:
            str: The login URL, e.g. '/hub/login'
        """
        return url_path_join(base_url, 'login')

    def logout_url(self, base_url):
        """Override when registering a custom logout handler

        The subclass overriding this is responsible for making sure there is a handler
        available to handle the URL returned from this method, using the `get_handlers`
        method.

        Args:
            base_url (str): the base URL of the Hub (e.g. /hub/)

        Returns:
            str: The logout URL, e.g. '/hub/logout'
        """
        return url_path_join(base_url, 'logout')

    def get_handlers(self, app):
        """Return any custom handlers the authenticator needs to register

        Used in conjunction with `login_url` and `logout_url`.

        Args:
            app (JupyterHub Application):
                the application object, in case it needs to be accessed for info.
        Returns:
            handlers (list):
                list of ``('/url', Handler)`` tuples passed to tornado.
                The Hub prefix is added to any URLs.
        """
        return [
            ('/login', LoginHandler),
        ]
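To make the authenticate() contract concrete, here is a minimal subclass sketch: a hypothetical shared-dictionary authenticator, assuming the same traitlets (Dict) and tornado (gen) imports used throughout these examples:

class DictionaryAuthenticator(Authenticator):
    """Authenticate against a configured username:password dict."""

    passwords = Dict(help="username:password pairs").tag(config=True)

    @gen.coroutine
    def authenticate(self, handler, data):
        if self.passwords.get(data['username']) == data['password']:
            return data['username']
        # fall through: returning None rejects the login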
Example #15
class Widget(LoggingHasTraits):
    #-------------------------------------------------------------------------
    # Class attributes
    #-------------------------------------------------------------------------
    _widget_construction_callback = None

    # widgets is a dictionary of all active widget objects
    widgets = {}

    # widget_types is a registry of widgets by module, version, and name:
    widget_types = WidgetRegistry()

    @classmethod
    def close_all(cls):
        for widget in list(cls.widgets.values()):
            widget.close()


    @staticmethod
    def on_widget_constructed(callback):
        """Registers a callback to be called when a widget is constructed.

        The callback must have the following signature:
        callback(widget)"""
        Widget._widget_construction_callback = callback

    @staticmethod
    def _call_widget_constructed(widget):
        """Static method, called when a widget is constructed."""
        if Widget._widget_construction_callback is not None and callable(Widget._widget_construction_callback):
            Widget._widget_construction_callback(widget)

    @staticmethod
    def handle_comm_opened(comm, msg):
        """Static method, called when a widget is constructed."""
        version = msg.get('metadata', {}).get('version', '')
        if version.split('.')[0] != PROTOCOL_VERSION_MAJOR:
            raise ValueError("Incompatible widget protocol versions: received version %r, expected version %r"%(version, __protocol_version__))
        data = msg['content']['data']
        state = data['state']

        # Find the widget class to instantiate in the registered widgets
        widget_class = Widget.widget_types.get(state['_model_module'],
                                               state['_model_module_version'],
                                               state['_model_name'],
                                               state['_view_module'],
                                               state['_view_module_version'],
                                               state['_view_name'])
        widget = widget_class(comm=comm)
        if 'buffer_paths' in data:
            _put_buffers(state, data['buffer_paths'], msg['buffers'])
        widget.set_state(state)

    @staticmethod
    def get_manager_state(drop_defaults=False, widgets=None):
        """Returns the full state for a widget manager for embedding

        :param drop_defaults: when True, it will not include default value
        :param widgets: list with widgets to include in the state (or all widgets when None)
        :return:
        """
        state = {}
        if widgets is None:
            widgets = Widget.widgets.values()
        for widget in widgets:
            state[widget.model_id] = widget._get_embed_state(drop_defaults=drop_defaults)
        return {'version_major': 2, 'version_minor': 0, 'state': state}


    def _get_embed_state(self, drop_defaults=False):
        state = {
            'model_name': self._model_name,
            'model_module': self._model_module,
            'model_module_version': self._model_module_version
        }
        model_state, buffer_paths, buffers = _remove_buffers(self.get_state(drop_defaults=drop_defaults))
        state['state'] = model_state
        if len(buffers) > 0:
            state['buffers'] = [{'encoding': 'base64',
                                 'path': p,
                                 'data': standard_b64encode(d).decode('ascii')}
                                for p, d in zip(buffer_paths, buffers)]
        return state

    def get_view_spec(self):
        return dict(version_major=2, version_minor=0, model_id=self._model_id)

    #-------------------------------------------------------------------------
    # Traits
    #-------------------------------------------------------------------------
    _model_name = Unicode('WidgetModel',
        help="Name of the model.", read_only=True).tag(sync=True)
    _model_module = Unicode('@jupyter-widgets/base',
        help="The namespace for the model.", read_only=True).tag(sync=True)
    _model_module_version = Unicode(__jupyter_widgets_base_version__,
        help="A semver requirement for namespace version containing the model.", read_only=True).tag(sync=True)
    _view_name = Unicode(None, allow_none=True,
        help="Name of the view.").tag(sync=True)
    _view_module = Unicode(None, allow_none=True,
        help="The namespace for the view.").tag(sync=True)
    _view_module_version = Unicode('',
        help="A semver requirement for the namespace version containing the view.").tag(sync=True)

    _view_count = Int(None, allow_none=True,
        help="EXPERIMENTAL: The number of views of the model displayed in the frontend. This attribute is experimental and may change or be removed in the future. None signifies that views will not be tracked. Set this to 0 to start tracking view creation/deletion.").tag(sync=True)
    comm = Instance('ipykernel.comm.Comm', allow_none=True)

    keys = List(help="The traits which are synced.")

    @default('keys')
    def _default_keys(self):
        return [name for name in self.traits(sync=True)]

    _property_lock = Dict()
    _holding_sync = False
    _states_to_send = Set()
    _display_callbacks = Instance(CallbackDispatcher, ())
    _msg_callbacks = Instance(CallbackDispatcher, ())

    #-------------------------------------------------------------------------
    # (Con/de)structor
    #-------------------------------------------------------------------------
    def __init__(self, **kwargs):
        """Public constructor"""
        self._model_id = kwargs.pop('model_id', None)
        super(Widget, self).__init__(**kwargs)

        Widget._call_widget_constructed(self)
        self.open()

    def __del__(self):
        """Object disposal"""
        self.close()

    #-------------------------------------------------------------------------
    # Properties
    #-------------------------------------------------------------------------

    def open(self):
        """Open a comm to the frontend if one isn't already open."""
        if self.comm is None:
            state, buffer_paths, buffers = _remove_buffers(self.get_state())

            args = dict(target_name='jupyter.widget',
                        data={'state': state, 'buffer_paths': buffer_paths},
                        buffers=buffers,
                        metadata={'version': __protocol_version__}
                        )
            if self._model_id is not None:
                args['comm_id'] = self._model_id

            self.comm = Comm(**args)

    @observe('comm')
    def _comm_changed(self, change):
        """Called when the comm is changed."""
        if change['new'] is None:
            return
        self._model_id = self.model_id

        self.comm.on_msg(self._handle_msg)
        Widget.widgets[self.model_id] = self

    @property
    def model_id(self):
        """Gets the model id of this widget.

        If a Comm doesn't exist yet, a Comm will be created automagically."""
        return self.comm.comm_id

    #-------------------------------------------------------------------------
    # Methods
    #-------------------------------------------------------------------------

    def close(self):
        """Close method.

        Closes the underlying comm.
        When the comm is closed, all of the widget views are automatically
        removed from the front-end."""
        if self.comm is not None:
            Widget.widgets.pop(self.model_id, None)
            self.comm.close()
            self.comm = None
            self._ipython_display_ = None

    def send_state(self, key=None):
        """Sends the widget state, or a piece of it, to the front-end, if it exists.

        Parameters
        ----------
        key : unicode, or iterable (optional)
            A single property's name or iterable of property names to sync with the front-end.
        """
        state = self.get_state(key=key)
        if len(state) > 0:
            state, buffer_paths, buffers = _remove_buffers(state)
            msg = {'method': 'update', 'state': state, 'buffer_paths': buffer_paths}
            self._send(msg, buffers=buffers)

    def get_state(self, key=None, drop_defaults=False):
        """Gets the widget state, or a piece of it.

        Parameters
        ----------
        key : unicode or iterable (optional)
            A single property's name or iterable of property names to get.

        Returns
        -------
        state : dict of states
        metadata : dict
            metadata for each field: {key: metadata}
        """
        if key is None:
            keys = self.keys
        elif isinstance(key, string_types):
            keys = [key]
        elif isinstance(key, collections.Iterable):
            keys = key
        else:
            raise ValueError("key must be a string, an iterable of keys, or None")
        state = {}
        traits = self.traits()
        for k in keys:
            to_json = self.trait_metadata(k, 'to_json', self._trait_to_json)
            value = to_json(getattr(self, k), self)
            if not PY3 and isinstance(traits[k], Bytes) and isinstance(value, bytes):
                value = memoryview(value)
            if not drop_defaults or not self._compare(value, traits[k].default_value):
                state[k] = value
        return state

    def _is_numpy(self, x):
        return x.__class__.__name__ == 'ndarray' and x.__class__.__module__ == 'numpy'

    def _compare(self, a, b):
        if self._is_numpy(a) or self._is_numpy(b):
            import numpy as np
            return np.array_equal(a, b)
        else:
            return a == b
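
    # Editorial note: comparing numpy arrays with plain `==` yields an
    # elementwise array whose truth value is ambiguous, hence the
    # np.array_equal special case above, e.g.
    # np.array_equal(np.arange(3), np.arange(3)) -> True.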

    def set_state(self, sync_data):
        """Called when a state is received from the front-end."""
        # The order of these context managers is important. Properties must
        # be locked when the hold_trait_notification context manager is
        # released and notifications are fired.
        with self._lock_property(**sync_data), self.hold_trait_notifications():
            for name in sync_data:
                if name in self.keys:
                    from_json = self.trait_metadata(name, 'from_json',
                                                    self._trait_from_json)
                    self.set_trait(name, from_json(sync_data[name], self))

    def send(self, content, buffers=None):
        """Sends a custom msg to the widget model in the front-end.

        Parameters
        ----------
        content : dict
            Content of the message to send.
        buffers : list of binary buffers
            Binary buffers to send with message
        """
        self._send({"method": "custom", "content": content}, buffers=buffers)

    def on_msg(self, callback, remove=False):
        """(Un)Register a custom msg receive callback.

        Parameters
        ----------
        callback: callable
            callback will be passed three arguments when a message arrives::

                callback(widget, content, buffers)

        remove: bool
            True if the callback should be unregistered."""
        self._msg_callbacks.register_callback(callback, remove=remove)

    def on_displayed(self, callback, remove=False):
        """(Un)Register a widget displayed callback.

        Parameters
        ----------
        callback: method handler
            Must have a signature of::

                callback(widget, **kwargs)

            kwargs from display are passed through without modification.
        remove: bool
            True if the callback should be unregistered."""
        self._display_callbacks.register_callback(callback, remove=remove)

    def add_traits(self, **traits):
        """Dynamically add trait attributes to the Widget."""
        super(Widget, self).add_traits(**traits)
        for name, trait in traits.items():
            if trait.get_metadata('sync'):
                self.keys.append(name)
                self.send_state(name)

    def notify_change(self, change):
        """Called when a property has changed."""
        # Send the state to the frontend before the user-registered callbacks
        # are called.
        name = change['name']
        if self.comm is not None and self.comm.kernel is not None:
            # Make sure this isn't information that the front-end just sent us.
            if name in self.keys and self._should_send_property(name, getattr(self, name)):
                # Send new state to front-end
                self.send_state(key=name)
        super(Widget, self).notify_change(change)

    def __repr__(self):
        return self._gen_repr_from_keys(self._repr_keys())

    #-------------------------------------------------------------------------
    # Support methods
    #-------------------------------------------------------------------------
    @contextmanager
    def _lock_property(self, **properties):
        """Lock a property-value pair.

        The value should be the JSON state of the property.

        NOTE: This, in addition to the single lock for all state changes, is
        flawed.  In the future we may want to look into buffering state changes
        back to the front-end."""
        self._property_lock = properties
        try:
            yield
        finally:
            self._property_lock = {}

    @contextmanager
    def hold_sync(self):
        """Hold syncing any state until the outermost context manager exits"""
        if self._holding_sync is True:
            yield
        else:
            try:
                self._holding_sync = True
                yield
            finally:
                self._holding_sync = False
                self.send_state(self._states_to_send)
                self._states_to_send.clear()

    def _should_send_property(self, key, value):
        """Check the property lock (property_lock)"""
        to_json = self.trait_metadata(key, 'to_json', self._trait_to_json)
        if key in self._property_lock:
            # model_state, buffer_paths, buffers
            split_value = _remove_buffers({ key: to_json(value, self)})
            split_lock = _remove_buffers({ key: self._property_lock[key]})
            # A roundtrip conversion through json in the comparison takes care of
            # idiosyncrasies of how python data structures map to json, for example
            # tuples get converted to lists.
            if (jsonloads(jsondumps(split_value[0])) == split_lock[0]
                and split_value[1] == split_lock[1]
                and _buffer_list_equal(split_value[2], split_lock[2])):
                return False
        if self._holding_sync:
            self._states_to_send.add(key)
            return False
        else:
            return True
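
    # Editorial note: the JSON round-trip above makes kernel-side values
    # compare equal to what the front-end echoes back (e.g. tuples become
    # lists): jsonloads(jsondumps({'x': (1, 2)})) == {'x': [1, 2]} is True,
    # so echoed updates are recognized and not re-sent.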

    # Event handlers
    @_show_traceback
    def _handle_msg(self, msg):
        """Called when a msg is received from the front-end"""
        data = msg['content']['data']
        method = data['method']

        if method == 'update':
            if 'state' in data:
                state = data['state']
                if 'buffer_paths' in data:
                    _put_buffers(state, data['buffer_paths'], msg['buffers'])
                self.set_state(state)

        # Handle a state request.
        elif method == 'request_state':
            self.send_state()

        # Handle a custom msg from the front-end.
        elif method == 'custom':
            if 'content' in data:
                self._handle_custom_msg(data['content'], msg['buffers'])

        # Catch remainder.
        else:
            self.log.error('Unknown front-end to back-end widget msg with method "%s"' % method)

    def _handle_custom_msg(self, content, buffers):
        """Called when a custom msg is received."""
        self._msg_callbacks(self, content, buffers)

    def _handle_displayed(self, **kwargs):
        """Called when a view has been displayed for this widget instance"""
        self._display_callbacks(self, **kwargs)

    @staticmethod
    def _trait_to_json(x, self):
        """Convert a trait value to json."""
        return x

    @staticmethod
    def _trait_from_json(x, self):
        """Convert json values to objects."""
        return x

    def _ipython_display_(self, **kwargs):
        """Called when `IPython.display.display` is called on the widget."""
        if self._view_name is not None:

            # The 'application/vnd.jupyter.widget-view+json' mimetype has not been registered yet.
            # See the registration process and naming convention at
            # http://tools.ietf.org/html/rfc6838
            # and the currently registered mimetypes at
            # http://www.iana.org/assignments/media-types/media-types.xhtml.
            data = {
                'text/plain': repr(self),
                'text/html': self._fallback_html(),
                'application/vnd.jupyter.widget-view+json': {
                    'version_major': 2,
                    'version_minor': 0,
                    'model_id': self._model_id
                }
            }
            display(data, raw=True)

            self._handle_displayed(**kwargs)

    def _send(self, msg, buffers=None):
        """Sends a message to the model in the front-end."""
        if self.comm is not None and self.comm.kernel is not None:
            self.comm.send(data=msg, buffers=buffers)

    def _repr_keys(self):
        traits = self.traits()
        for key in sorted(self.keys):
            # Exclude traits that start with an underscore
            if key[0] == '_':
                continue
            # Exclude traits who are equal to their default value
            value = getattr(self, key)
            trait = traits[key]
            if self._compare(value, trait.default_value):
                continue
            elif (isinstance(trait, (Container, Dict)) and
                  trait.default_value == Undefined and
                  (value is None or len(value) == 0)):
                # Empty container, and dynamic default will be empty
                continue
            yield key

    def _gen_repr_from_keys(self, keys):
        class_name = self.__class__.__name__
        signature = ', '.join(
            '%s=%r' % (key, getattr(self, key))
            for key in keys
        )
        return '%s(%s)' % (class_name, signature)

    def _fallback_html(self):
        return _FALLBACK_HTML_TEMPLATE.format(widget_type=type(self).__name__)
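
# Usage sketch (editorial, not from the source): a hypothetical subclass tags
# traits with sync=True so they land in `keys` and flow through the state
# machinery above. Constructing a widget requires a live kernel, since
# open() creates a Comm.
from traitlets import Int, Unicode

class CounterWidget(Widget):                    # hypothetical example widget
    _view_name = Unicode('CounterView').tag(sync=True)
    value = Int(0).tag(sync=True)               # mirrored to the front-end

w = CounterWidget()
w.value = 3                      # notify_change() -> send_state(key='value')
assert w.get_state('value') == {'value': 3}
with w.hold_sync():              # coalesce several changes into one update
    w.value = 4
    w.value = 5                  # a single update message is sent on exit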
Example #16
class Authenticator(LoggingConfigurable):
    """Base class for implementing an authentication provider for JupyterHub"""

    db = Any()

    enable_auth_state = Bool(
        False,
        config=True,
        help="""Enable persisting auth_state (if available).

        auth_state will be encrypted and stored in the Hub's database.
        This can include things like authentication tokens, etc.
        to be passed to Spawners as environment variables.

        Encrypting auth_state requires the cryptography package.

        Additionally, the JUPYTERHUB_CRYPT_KEY environment variable must
        contain one (or more, separated by ;) 32B encryption keys.
        These can be either base64 or hex-encoded.

        If encryption is unavailable, auth_state cannot be persisted.

        New in JupyterHub 0.8
        """,
    )

    auth_refresh_age = Integer(
        300,
        config=True,
        help="""The max age (in seconds) of authentication info
        before forcing a refresh of user auth info.

        Refreshing auth info allows, e.g. requesting/re-validating auth tokens.

        See :meth:`.refresh_user` for what happens when user auth info is refreshed
        (nothing by default).
        """)

    refresh_pre_spawn = Bool(False,
                             config=True,
                             help="""Force refresh of auth prior to spawn.

        This forces :meth:`.refresh_user` to be called prior to launching
        a server, to ensure that auth state is up-to-date.

        This can be important when e.g. auth tokens that may have expired
        are passed to the spawner via environment variables from auth_state.

        If refresh_user cannot refresh the user auth data,
        launch will fail until the user logs in again.
        """)

    admin_users = Set(help="""
        Set of users that will have admin rights on this JupyterHub.

        Admin users have extra privileges:
         - Use the admin panel to see list of users logged in
         - Add / remove users in some authenticators
         - Restart / halt the hub
         - Start / stop users' single-user servers
         - Can access each individual users' single-user server (if configured)

        Admin access should be treated the same way root access is.

        Defaults to an empty set, in which case no user has admin access.
        """).tag(config=True)

    whitelist = Set(help="""
        Whitelist of usernames that are allowed to log in.

        Use this with supported authenticators to restrict which users can log in. This is an
        additional whitelist that further restricts users, beyond whatever restrictions the
        authenticator has in place.

        If empty, does not perform any additional restriction.
        """).tag(config=True)

    blacklist = Set(help="""
        Blacklist of usernames that are not allowed to log in.

        Use this with supported authenticators to restrict which users cannot log in. This is an
        additional blacklist that further restricts users, beyond whatever restrictions the
        authenticator has in place.

        If empty, does not perform any additional restriction.

        .. versionadded: 0.9
        """).tag(config=True)

    @observe('whitelist')
    def _check_whitelist(self, change):
        short_names = [name for name in change['new'] if len(name) <= 1]
        if short_names:
            sorted_names = sorted(short_names)
            single = ''.join(sorted_names)
            string_set_typo = "set('%s')" % single
            self.log.warning(
                "whitelist contains single-character names: %s; did you mean set([%r]) instead of %s?",
                sorted_names[:8],
                single,
                string_set_typo,
            )

    custom_html = Unicode(help="""
        HTML form to be overridden by authenticators if they want a custom authentication form.

        Defaults to an empty string, which shows the default username/password form.
        """)

    login_service = Unicode(help="""
        Name of the login service that this authenticator is using to authenticate users.

        Example: GitHub, MediaWiki, Google, etc.

        Setting this value replaces the login form with a "Login with <login_service>" button.

        Any authenticator that redirects to an external service (e.g. using OAuth) should set this.
        """)

    username_pattern = Unicode(help="""
        Regular expression pattern that all valid usernames must match.

        If a username does not match the pattern specified here, authentication will not be attempted.

        If not set, allow any username.
        """).tag(config=True)

    @observe('username_pattern')
    def _username_pattern_changed(self, change):
        if not change['new']:
            self.username_regex = None
            return
        self.username_regex = re.compile(change['new'])

    username_regex = Any(help="""
        Compiled regex kept in sync with `username_pattern`
        """)

    def validate_username(self, username):
        """Validate a normalized username

        Return True if username is valid, False otherwise.
        """
        if '/' in username:
            # / is not allowed in usernames
            return False
        if not username:
            # empty usernames are not allowed
            return False
        if not self.username_regex:
            return True
        return bool(self.username_regex.match(username))
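
    # Illustrative examples (editorial): with username_pattern = r'[a-z]+',
    # validate_username('alice') is True, while validate_username('a/b') and
    # validate_username('') are False regardless of the pattern.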

    username_map = Dict(
        help="""Dictionary mapping authenticator usernames to JupyterHub users.

        Primarily used to normalize OAuth user names to local users.
        """).tag(config=True)

    delete_invalid_users = Bool(
        False,
        help="""Delete any users from the database that do not pass validation

        When JupyterHub starts, `.add_user` will be called
        on each user in the database to verify that all users are still valid.

        If `delete_invalid_users` is True,
        any users that do not pass validation will be deleted from the database.
        Use this if users might be deleted from an external system,
        such as local user accounts.

        If False (default), invalid users remain in the Hub's database
        and a warning will be issued.
        This is the default to avoid data loss due to config changes.
        """)

    def normalize_username(self, username):
        """Normalize the given username and return it

        Override in subclasses if usernames need different normalization rules.

        The default attempts to lowercase the username and apply `username_map` if it is
        set.
        """
        username = username.lower()
        username = self.username_map.get(username, username)
        return username

    def check_whitelist(self, username):
        """Check if a username is allowed to authenticate based on whitelist configuration

        Return True if username is allowed, False otherwise.
        No whitelist means any username is allowed.

        Names are normalized *before* being checked against the whitelist.
        """
        if not self.whitelist:
            # No whitelist means any name is allowed
            return True
        return username in self.whitelist

    def check_blacklist(self, username):
        """Check if a username is blocked to authenticate based on blacklist configuration

        Return True if username is allowed, False otherwise.
        No blacklist means any username is allowed.

        Names are normalized *before* being checked against the blacklist.

        .. versionadded: 0.9
        """
        if not self.blacklist:
            # No blacklist means any name is allowed
            return True
        return username not in self.blacklist

    async def get_authenticated_user(self, handler, data):
        """Authenticate the user who is attempting to log in

        Returns user dict if successful, None otherwise.

        This calls `authenticate`, which should be overridden in subclasses,
        normalizes the username if any normalization should be done,
        and then validates the name in the whitelist.

        This is the outer API for authenticating a user.
        Subclasses should not override this method.

        The various stages can be overridden separately:
         - `authenticate` turns formdata into a username
         - `normalize_username` normalizes the username
         - `check_whitelist` checks against the user whitelist

        .. versionchanged:: 0.8
            return dict instead of username
        """
        authenticated = await maybe_future(self.authenticate(handler, data))
        if authenticated is None:
            return
        if isinstance(authenticated, dict):
            if 'name' not in authenticated:
                raise ValueError("user missing a name: %r" % authenticated)
        else:
            authenticated = {
                'name': authenticated,
            }
        authenticated.setdefault('auth_state', None)
        # Leave the default as None, but reevaluate later post-whitelist
        authenticated.setdefault('admin', None)

        # normalize the username
        authenticated['name'] = username = self.normalize_username(
            authenticated['name'])
        if not self.validate_username(username):
            self.log.warning("Disallowing invalid username %r.", username)
            return

        blacklist_pass = await maybe_future(self.check_blacklist(username))
        whitelist_pass = await maybe_future(self.check_whitelist(username))
        if not blacklist_pass:
            self.log.warning("User %r in blacklist. Stopping authentication.",
                             username)
            return

        if whitelist_pass:
            if authenticated['admin'] is None:
                authenticated['admin'] = await maybe_future(
                    self.is_admin(handler, authenticated))

            return authenticated
        else:
            self.log.warning("User %r not in whitelist.", username)
            return

    async def refresh_user(self, user, handler=None):
        """Refresh auth data for a given user

        Allows refreshing or invalidating auth data.

        Only override if your authenticator needs
        to refresh its data about users once in a while.

        .. versionadded: 1.0

        Args:
            user (User): the user to refresh
            handler (tornado.web.RequestHandler or None): the current request handler
        Returns:
            auth_data (bool or dict):
                Return **True** if auth data for the user is up-to-date
                and no updates are required.

                Return **False** if the user's auth data has expired,
                and they should be required to login again.

                Return a **dict** of auth data if some values should be updated.
                This dict should have the same structure as that returned
                by :meth:`.authenticate()` when it returns a dict.
                Any fields present will refresh the value for the user.
                Any fields not present will be left unchanged.
                This can include updating `.admin` or `.auth_state` fields.
        """
        return True

    def is_admin(self, handler, authentication):
        """Authentication helper to determine a user's admin status.

        .. versionadded: 1.0

        Args:
            handler (tornado.web.RequestHandler): the current request handler
            authentication: The authentication dict generated by `authenticate`.
        Returns:
            admin_status (Bool or None):
                The admin status of the user, or None if it could not be 
                determined or should not change.
        """
        return True if authentication['name'] in self.admin_users else None

    async def authenticate(self, handler, data):
        """Authenticate a user with login form data

        This must be a coroutine.

        It must return the username on successful authentication,
        and return None on failed authentication.

        Checking the whitelist is handled separately by the caller.

        .. versionchanged:: 0.8
            Allow `authenticate` to return a dict containing auth_state.

        Args:
            handler (tornado.web.RequestHandler): the current request handler
            data (dict): The formdata of the login form.
                         The default form has 'username' and 'password' fields.
        Returns:
            user (str or dict or None):
                The username of the authenticated user,
                or None if authentication failed.

                The Authenticator may return a dict instead, which MUST have a
                key `name` holding the username, and MAY have two optional keys
                set: `auth_state`, a dictionary of auth state that will be
                persisted; and `admin`, the admin setting value for the user.
        """

    def pre_spawn_start(self, user, spawner):
        """Hook called before spawning a user's server

        Can be used to do auth-related startup, e.g. opening PAM sessions.
        """

    def post_spawn_stop(self, user, spawner):
        """Hook called after stopping a user container

        Can be used to do auth-related cleanup, e.g. closing PAM sessions.
        """

    def add_user(self, user):
        """Hook called when a user is added to JupyterHub

        This is called:
         - When a user first authenticates
         - When the hub restarts, for all users.

        This method may be a coroutine.

        By default, this just adds the user to the whitelist.

        Subclasses may do more extensive things, such as adding actual unix users,
        but they should call super to ensure the whitelist is updated.

        Note that this should be idempotent, since it is called whenever the hub restarts
        for all users.

        Args:
            user (User): The User wrapper object
        """
        if not self.validate_username(user.name):
            raise ValueError("Invalid username: %s" % user.name)
        if self.whitelist:
            self.whitelist.add(user.name)

    def delete_user(self, user):
        """Hook called when a user is deleted

        Removes the user from the whitelist.
        Subclasses should call super to ensure the whitelist is updated.

        Args:
            user (User): The User wrapper object
        """
        self.whitelist.discard(user.name)

    auto_login = Bool(False,
                      config=True,
                      help="""Automatically begin the login process

        rather than starting with a "Login with..." link at `/hub/login`

        To work, `.login_url()` must give a URL other than the default `/hub/login`,
        such as an oauth handler or another automatic login handler,
        registered with `.get_handlers()`.

        .. versionadded:: 0.8
        """)

    def login_url(self, base_url):
        """Override this when registering a custom login handler

        Generally used by authenticators that do not use simple form-based authentication.

        The subclass overriding this is responsible for making sure there is a handler
        available to handle the URL returned from this method, using the `get_handlers`
        method.

        Args:
            base_url (str): the base URL of the Hub (e.g. /hub/)

        Returns:
            str: The login URL, e.g. '/hub/login'
        """
        return url_path_join(base_url, 'login')

    def logout_url(self, base_url):
        """Override when registering a custom logout handler

        The subclass overriding this is responsible for making sure there is a handler
        available to handle the URL returned from this method, using the `get_handlers`
        method.

        Args:
            base_url (str): the base URL of the Hub (e.g. /hub/)

        Returns:
            str: The logout URL, e.g. '/hub/logout'
        """
        return url_path_join(base_url, 'logout')

    def get_handlers(self, app):
        """Return any custom handlers the authenticator needs to register

        Used in conjunction with `login_url` and `logout_url`.

        Args:
            app (JupyterHub Application):
                the application object, in case it needs to be accessed for info.
        Returns:
            handlers (list):
                list of ``('/url', Handler)`` tuples passed to tornado.
                The Hub prefix is added to any URLs.
        """
        return [
            ('/login', LoginHandler),
        ]
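
# Sketch (editorial, not from the source): the smallest useful subclass
# overrides only `authenticate`; normalization, whitelist/blacklist checks,
# and admin resolution are handled by get_authenticated_user() above. The
# `passwords` trait is hypothetical and for demonstration only -- never keep
# plaintext passwords in a real deployment.
class DictionaryAuthenticator(Authenticator):
    passwords = Dict(config=True,
                     help="map of username -> password (demo only)")

    async def authenticate(self, handler, data):
        if self.passwords.get(data['username']) == data['password']:
            return data['username']             # success: just the username
        return None                             # failure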
Example #17
class Widget(LoggingConfigurable):
    #-------------------------------------------------------------------------
    # Class attributes
    #-------------------------------------------------------------------------
    _widget_construction_callback = None
    _read_only_enabled = True
    widgets = {}
    widget_types = {}

    @staticmethod
    def on_widget_constructed(callback):
        """Registers a callback to be called when a widget is constructed.

        The callback must have the following signature:
        callback(widget)"""
        Widget._widget_construction_callback = callback

    @staticmethod
    def _call_widget_constructed(widget):
        """Static method, called when a widget is constructed."""
        if Widget._widget_construction_callback is not None and callable(Widget._widget_construction_callback):
            Widget._widget_construction_callback(widget)

    @staticmethod
    def handle_comm_opened(comm, msg):
        """Static method, called when a widget is constructed."""
        widget_class = import_item(str(msg['content']['data']['widget_class']))
        widget = widget_class(comm=comm)


    #-------------------------------------------------------------------------
    # Traits
    #-------------------------------------------------------------------------
    _model_module = Unicode(None, allow_none=True, help="""A requirejs module name
        in which to find _model_name. If empty, look in the global registry.""")
    _model_name = Unicode('WidgetModel', help="""Name of the backbone model 
        registered in the front-end to create and sync this widget with.""")
    _view_module = Unicode(help="""A requirejs module in which to find _view_name.
        If empty, look in the global registry.""", sync=True)
    _view_name = Unicode(None, allow_none=True, help="""Default view registered in the front-end
        to use to represent the widget.""", sync=True)
    comm = Instance('ipykernel.comm.Comm', allow_none=True)
    
    msg_throttle = Int(3, sync=True, help="""Maximum number of msgs the 
        front-end can send before receiving an idle msg from the back-end.""")
    
    version = Int(0, sync=True, help="""Widget's version""")
    keys = List()
    def _keys_default(self):
        return [name for name in self.traits(sync=True)]
    
    _property_lock = Dict()
    _holding_sync = False
    _states_to_send = Set()
    _display_callbacks = Instance(CallbackDispatcher, ())
    _msg_callbacks = Instance(CallbackDispatcher, ())
    
    #-------------------------------------------------------------------------
    # (Con/de)structor
    #-------------------------------------------------------------------------
    def __init__(self, **kwargs):
        """Public constructor"""
        self._model_id = kwargs.pop('model_id', None)
        super(Widget, self).__init__(**kwargs)

        Widget._call_widget_constructed(self)
        self.open()

    def __del__(self):
        """Object disposal"""
        self.close()

    #-------------------------------------------------------------------------
    # Properties
    #-------------------------------------------------------------------------

    def open(self):
        """Open a comm to the frontend if one isn't already open."""
        if self.comm is None:
            args = dict(target_name='ipython.widget',
                        data={'model_name': self._model_name,
                              'model_module': self._model_module})
            if self._model_id is not None:
                args['comm_id'] = self._model_id
            self.comm = Comm(**args)

    def _comm_changed(self, name, new):
        """Called when the comm is changed."""
        if new is None:
            return
        self._model_id = self.model_id
        
        self.comm.on_msg(self._handle_msg)
        Widget.widgets[self.model_id] = self
        
        # first update
        self.send_state()

    @property
    def model_id(self):
        """Gets the model id of this widget.

        If a Comm doesn't exist yet, a Comm will be created automagically."""
        return self.comm.comm_id

    #-------------------------------------------------------------------------
    # Methods
    #-------------------------------------------------------------------------

    def __setattr__(self, name, value):
        """Overload of HasTraits.__setattr__to handle read-only-ness of widget
        attributes """
        if (self._read_only_enabled and self.has_trait(name) and
            self.trait_metadata(name, 'read_only')): 
            raise TraitError('Widget attribute "%s" is read-only.' % name)
        else:
            super(Widget, self).__setattr__(name, value)


    def close(self):
        """Close method.

        Closes the underlying comm.
        When the comm is closed, all of the widget views are automatically
        removed from the front-end."""
        if self.comm is not None:
            Widget.widgets.pop(self.model_id, None)
            self.comm.close()
            self.comm = None
    
    def send_state(self, key=None):
        """Sends the widget state, or a piece of it, to the front-end.

        Parameters
        ----------
        key : unicode, or iterable (optional)
            A single property's name or iterable of property names to sync with the front-end.
        """
        state = self.get_state(key=key)
        buffer_keys, buffers = [], []
        # Iterate over a copy; the dict is mutated while buffers are extracted.
        for k, v in list(state.items()):
            if isinstance(v, memoryview):
                state.pop(k)
                buffers.append(v)
                buffer_keys.append(k)
        msg = {'method': 'update', 'state': state, 'buffers': buffer_keys}
        self._send(msg, buffers=buffers)

    def get_state(self, key=None):
        """Gets the widget state, or a piece of it.

        Parameters
        ----------
        key : unicode or iterable (optional)
            A single property's name or iterable of property names to get.

        Returns
        -------
        state : dict of states
        metadata : dict
            metadata for each field: {key: metadata}
        """
        if key is None:
            keys = self.keys
        elif isinstance(key, string_types):
            keys = [key]
        elif isinstance(key, collections.Iterable):
            keys = key
        else:
            raise ValueError("key must be a string, an iterable of keys, or None")
        state = {}
        for k in keys:
            to_json = self.trait_metadata(k, 'to_json', self._trait_to_json)
            state[k] = to_json(getattr(self, k), self)
        return state

    def set_state(self, sync_data):
        """Called when a state is received from the front-end."""
        # The order of these context managers is important. Properties must
        # be locked when the hold_trait_notification context manager is
        # released and notifications are fired.
        with self._allow_write(),\
             self._lock_property(**sync_data),\
             self.hold_trait_notifications():
            for name in sync_data:
                if name in self.keys:
                    from_json = self.trait_metadata(name, 'from_json',
                                                    self._trait_from_json)
                    setattr(self, name, from_json(sync_data[name], self))

    def send(self, content, buffers=None):
        """Sends a custom msg to the widget model in the front-end.

        Parameters
        ----------
        content : dict
            Content of the message to send.
        buffers : list of binary buffers
            Binary buffers to send with message
        """
        self._send({"method": "custom", "content": content}, buffers=buffers)

    def on_msg(self, callback, remove=False):
        """(Un)Register a custom msg receive callback.

        Parameters
        ----------
        callback: callable
            callback will be passed three arguments when a message arrives::
            
                callback(widget, content, buffers)
            
        remove: bool
            True if the callback should be unregistered."""
        self._msg_callbacks.register_callback(callback, remove=remove)

    def on_displayed(self, callback, remove=False):
        """(Un)Register a widget displayed callback.

        Parameters
        ----------
        callback: method handler
            Must have a signature of::
            
                callback(widget, **kwargs)
            
            kwargs from display are passed through without modification.
        remove: bool
            True if the callback should be unregistered."""
        self._display_callbacks.register_callback(callback, remove=remove)

    def add_traits(self, **traits):
        """Dynamically add trait attributes to the Widget."""
        super(Widget, self).add_traits(**traits)
        for name, trait in traits.items():
            if trait.get_metadata('sync'):
                 self.keys.append(name)
                 self.send_state(name)

    #-------------------------------------------------------------------------
    # Support methods
    #-------------------------------------------------------------------------
    @contextmanager
    def _lock_property(self, **properties):
        """Lock a property-value pair.

        The value should be the JSON state of the property.

        NOTE: This, in addition to the single lock for all state changes, is
        flawed.  In the future we may want to look into buffering state changes 
        back to the front-end."""
        self._property_lock = properties
        try:
            yield
        finally:
            self._property_lock = {}

    @contextmanager
    def _allow_write(self):
        if self._read_only_enabled is False:
            yield
        else:
            try:
                self._read_only_enabled = False
                yield
            finally:
                self._read_only_enabled = True 

    @contextmanager
    def hold_sync(self):
        """Hold syncing any state until the outermost context manager exits"""
        if self._holding_sync is True:
            yield
        else:
            try:
                self._holding_sync = True
                yield
            finally:
                self._holding_sync = False     
                self.send_state(self._states_to_send)
                self._states_to_send.clear()

    def _should_send_property(self, key, value):
        """Check the property lock (property_lock)"""
        to_json = self.trait_metadata(key, 'to_json', self._trait_to_json)
        if (key in self._property_lock
            and to_json(value, self) == self._property_lock[key]):
            return False
        elif self._holding_sync:
            self._states_to_send.add(key)
            return False
        else:
            return True
    
    # Event handlers
    @_show_traceback
    def _handle_msg(self, msg):
        """Called when a msg is received from the front-end"""
        data = msg['content']['data']
        method = data['method']

        # Handle backbone sync methods CREATE, PATCH, and UPDATE all in one.
        if method == 'backbone':
            if 'sync_data' in data:
                # get binary buffers too
                sync_data = data['sync_data']
                for i, k in enumerate(data.get('buffer_keys', [])):
                    sync_data[k] = msg['buffers'][i]
                self.set_state(sync_data) # handles all methods

        # Handle a state request.
        elif method == 'request_state':
            self.send_state()

        # Handle a custom msg from the front-end.
        elif method == 'custom':
            if 'content' in data:
                self._handle_custom_msg(data['content'], msg['buffers'])

        # Catch remainder.
        else:
            self.log.error('Unknown front-end to back-end widget msg with method "%s"' % method)

    def _handle_custom_msg(self, content, buffers):
        """Called when a custom msg is received."""
        self._msg_callbacks(self, content, buffers)

    def _notify_trait(self, name, old_value, new_value):
        """Called when a property has been changed."""
        # Trigger default traitlet callback machinery.  This allows any user
        # registered validation to be processed prior to allowing the widget
        # machinery to handle the state.
        LoggingConfigurable._notify_trait(self, name, old_value, new_value)

        # Send the state after the user registered callbacks for trait changes
        # have all fired (allowing the user to validate values).
        if self.comm is not None and name in self.keys:
            # Make sure this isn't information that the front-end just sent us.
            if self._should_send_property(name, new_value):
                # Send new state to front-end
                self.send_state(key=name)

    def _handle_displayed(self, **kwargs):
        """Called when a view has been displayed for this widget instance"""
        self._display_callbacks(self, **kwargs)

    @staticmethod
    def _trait_to_json(x, self):
        """Convert a trait value to json."""
        return x

    @staticmethod
    def _trait_from_json(x, self):
        """Convert json values to objects."""
        return x

    def _ipython_display_(self, **kwargs):
        """Called when `IPython.display.display` is called on the widget."""
        # Show view.
        if self._view_name is not None:
            self._send({"method": "display"})
            self._handle_displayed(**kwargs)

    def _send(self, msg, buffers=None):
        """Sends a message to the model in the front-end."""
        self.comm.send(data=msg, buffers=buffers)
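
# Sketch (editorial, assuming metadata-style traits of this era): a trait
# tagged read_only=True is rejected by the __setattr__ overload above, while
# set_state() wraps front-end updates in _allow_write() so they still apply.
# Constructing a widget requires a live kernel, since open() creates a Comm.
class ReadOnlyDemo(Widget):                     # hypothetical example widget
    status = Unicode('idle', read_only=True, sync=True)

rd = ReadOnlyDemo()
try:
    rd.status = 'busy'                          # raises TraitError
except TraitError:
    pass
rd.set_state({'status': 'busy'})                # allowed: front-end path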
Example #18
class LocalAuthenticator(Authenticator):
    """Base class for Authenticators that work with local Linux/UNIX users

    Checks for local users, and can attempt to create them if they don't exist.
    """

    create_system_users = Bool(False,
                               help="""
        If set to True, will attempt to create local system users if they do not exist already.

        Supports Linux and BSD variants only.
        """).tag(config=True)

    add_user_cmd = Command(help="""
        The command to use for creating users as a list of strings

        For each element in the list, the string USERNAME will be replaced with
        the user's username. The username will also be appended as the final argument.

        For Linux, the default value is:

            ['adduser', '-q', '--gecos', '""', '--disabled-password']

        To specify a custom home directory, set this to:

            ['adduser', '-q', '--gecos', '""', '--home', '/customhome/USERNAME', '--disabled-password']

        This will run the command:

            adduser -q --gecos "" --home /customhome/river --disabled-password river

        when the user 'river' is created.
        """).tag(config=True)

    @default('add_user_cmd')
    def _add_user_cmd_default(self):
        """Guess the most likely-to-work adduser command for each platform"""
        if sys.platform == 'darwin':
            raise ValueError("I don't know how to create users on OS X")
        elif which('pw'):
            # Probably BSD
            return ['pw', 'useradd', '-m']
        else:
            # This appears to be the Linux non-interactive adduser command:
            return ['adduser', '-q', '--gecos', '""', '--disabled-password']

    group_whitelist = Set(help="""
        Whitelist all users from this UNIX group.

        This makes the username whitelist ineffective.
        """).tag(config=True)

    @observe('group_whitelist')
    def _group_whitelist_changed(self, change):
        """
        Log a warning if both group_whitelist and user whitelist are set.
        """
        if self.whitelist:
            self.log.warning(
                "Ignoring username whitelist because group whitelist supplied!"
            )

    def check_whitelist(self, username):
        if self.group_whitelist:
            return self.check_group_whitelist(username)
        else:
            return super().check_whitelist(username)

    def check_group_whitelist(self, username):
        """
        If group_whitelist is configured, check if authenticating user is part of group.
        """
        if not self.group_whitelist:
            return False
        for grnam in self.group_whitelist:
            try:
                group = self._getgrnam(grnam)
            except KeyError:
                self.log.error('No such group: [%s]' % grnam)
                continue
            if username in group.gr_mem:
                return True
        return False

    async def add_user(self, user):
        """Hook called whenever a new user is added

        If self.create_system_users is set, the system user will be created if it doesn't exist.
        """
        user_exists = await maybe_future(self.system_user_exists(user))
        if not user_exists:
            if self.create_system_users:
                await maybe_future(self.add_system_user(user))
            else:
                raise KeyError(
                    "User {} does not exist on the system."
                    " Set LocalAuthenticator.create_system_users=True"
                    " to automatically create system users from jupyterhub users."
                    .format(user.name))

        await maybe_future(super().add_user(user))

    @staticmethod
    def _getgrnam(name):
        """Wrapper function to protect against `grp` not being available
        on Windows
        """
        import grp
        return grp.getgrnam(name)

    @staticmethod
    def _getpwnam(name):
        """Wrapper function to protect against `pwd` not being available
        on Windows
        """
        import pwd
        return pwd.getpwnam(name)

    @staticmethod
    def _getgrouplist(name, group):
        """Wrapper function to protect against `os._getgrouplist` not being available
        on Windows
        """
        import os
        return os.getgrouplist(name, group)

    def system_user_exists(self, user):
        """Check if the user exists on the system"""
        try:
            self._getpwnam(user.name)
        except KeyError:
            return False
        else:
            return True

    def add_system_user(self, user):
        """Create a new local UNIX user on the system.

        Tested to work on FreeBSD and Linux, at least.
        """
        name = user.name
        cmd = [arg.replace('USERNAME', name)
               for arg in self.add_user_cmd] + [name]
        self.log.info("Creating user: %s", ' '.join(map(pipes.quote, cmd)))
        p = Popen(cmd, stdout=PIPE, stderr=STDOUT)
        p.wait()
        if p.returncode:
            err = p.stdout.read().decode('utf8', 'replace')
            raise RuntimeError("Failed to create system user %s: %s" %
                               (name, err))
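
# Configuration sketch (editorial; values are illustrative) for a
# jupyterhub_config.py that uses the class above:
c.JupyterHub.authenticator_class = 'jupyterhub.auth.LocalAuthenticator'
c.LocalAuthenticator.create_system_users = True         # create missing accounts
c.LocalAuthenticator.group_whitelist = {'jupyterusers'} # overrides username whitelist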
Example #19
class KernelSpecManager(LoggingConfigurable):

    kernel_spec_class = Type(
        KernelSpec,
        config=True,
        help="""The kernel spec class.  This is configurable to allow
        subclassing of the KernelSpecManager for customized behavior.
        """)

    ensure_native_kernel = Bool(
        True,
        config=True,
        help="""If there is no Python kernelspec registered and the IPython
        kernel is available, ensure it is added to the spec list.
        """)

    data_dir = Unicode()

    def _data_dir_default(self):
        return jupyter_data_dir()

    user_kernel_dir = Unicode()

    def _user_kernel_dir_default(self):
        return pjoin(self.data_dir, 'kernels')

    whitelist = Set(config=True,
                    help="""Whitelist of allowed kernel names.

        By default, all installed kernels are allowed.
        """)
    kernel_dirs = List(
        help=
        "List of kernel directories to search. Later ones take priority over earlier."
    )

    def _kernel_dirs_default(self):
        dirs = jupyter_path('kernels')
        # At some point, we should stop adding .ipython/kernels to the path,
        # but the cost to keeping it is very small.
        try:
            from IPython.paths import get_ipython_dir
        except ImportError:
            try:
                from IPython.utils.path import get_ipython_dir
            except ImportError:
                # no IPython, no ipython dir
                get_ipython_dir = None
        if get_ipython_dir is not None:
            dirs.append(os.path.join(get_ipython_dir(), 'kernels'))
        return dirs

    def find_kernel_specs(self):
        """Returns a dict mapping kernel names to resource directories."""
        d = {}
        for kernel_dir in self.kernel_dirs:
            kernels = _list_kernels_in(kernel_dir)
            for kname, spec in kernels.items():
                if kname not in d:
                    self.log.debug("Found kernel %s in %s", kname, kernel_dir)
                    d[kname] = spec

        if self.ensure_native_kernel and NATIVE_KERNEL_NAME not in d:
            try:
                from ipykernel.kernelspec import RESOURCES
                self.log.debug("Native kernel (%s) available from %s",
                               NATIVE_KERNEL_NAME, RESOURCES)
                d[NATIVE_KERNEL_NAME] = RESOURCES
            except ImportError:
                self.log.warn("Native kernel (%s) is not available",
                              NATIVE_KERNEL_NAME)

        if self.whitelist:
            # filter if there's a whitelist
            d = {
                name: spec
                for name, spec in d.items() if name in self.whitelist
            }
        return d
        # TODO: Caching?

    def _get_kernel_spec_by_name(self, kernel_name, resource_dir):
        """ Returns a :class:`KernelSpec` instance for a given kernel_name
        and resource_dir.
        """
        if kernel_name == NATIVE_KERNEL_NAME:
            try:
                from ipykernel.kernelspec import RESOURCES, get_kernel_dict
            except ImportError:
                # It should be impossible to reach this, but let's play it safe
                pass
            else:
                if resource_dir == RESOURCES:
                    return self.kernel_spec_class(resource_dir=resource_dir,
                                                  **get_kernel_dict())

        return self.kernel_spec_class.from_resource_dir(resource_dir)

    def get_kernel_spec(self, kernel_name):
        """Returns a :class:`KernelSpec` instance for the given kernel_name.

        Raises :exc:`NoSuchKernel` if the given kernel name is not found.
        """
        d = self.find_kernel_specs()
        try:
            resource_dir = d[kernel_name.lower()]
        except KeyError:
            raise NoSuchKernel(kernel_name)

        return self._get_kernel_spec_by_name(kernel_name, resource_dir)

    def get_all_specs(self):
        """Returns a dict mapping kernel names to kernelspecs.

        Returns a dict of the form::

            {
              'kernel_name': {
                'resource_dir': '/path/to/kernel_name',
                'spec': {"the spec itself": ...}
              },
              ...
            }
        """
        d = self.find_kernel_specs()
        return {
            kname: {
                "resource_dir": d[kname],
                "spec": self._get_kernel_spec_by_name(kname,
                                                      d[kname]).to_dict()
            }
            for kname in d
        }

    def remove_kernel_spec(self, name):
        """Remove a kernel spec directory by name.
        
        Returns the path that was deleted.
        """
        save_native = self.ensure_native_kernel
        try:
            self.ensure_native_kernel = False
            specs = self.find_kernel_specs()
        finally:
            self.ensure_native_kernel = save_native
        spec_dir = specs[name]
        self.log.debug("Removing %s", spec_dir)
        if os.path.islink(spec_dir):
            os.remove(spec_dir)
        else:
            shutil.rmtree(spec_dir)
        return spec_dir

    def _get_destination_dir(self, kernel_name, user=False, prefix=None):
        if user:
            return os.path.join(self.user_kernel_dir, kernel_name)
        elif prefix:
            return os.path.join(os.path.abspath(prefix), 'share', 'jupyter',
                                'kernels', kernel_name)
        else:
            return os.path.join(SYSTEM_JUPYTER_PATH[0], 'kernels', kernel_name)

    def install_kernel_spec(self,
                            source_dir,
                            kernel_name=None,
                            user=False,
                            replace=None,
                            prefix=None):
        """Install a kernel spec by copying its directory.

        If ``kernel_name`` is not given, the basename of ``source_dir`` will
        be used.

        If ``user`` is False, it will attempt to install into the systemwide
        kernel registry. If the process does not have appropriate permissions,
        an :exc:`OSError` will be raised.
        
        If ``prefix`` is given, the kernelspec will be installed to
        PREFIX/share/jupyter/kernels/KERNEL_NAME. This can be sys.prefix
        for installation inside virtual or conda envs.
        """
        source_dir = source_dir.rstrip('/\\')
        if not kernel_name:
            kernel_name = os.path.basename(source_dir)
        kernel_name = kernel_name.lower()

        if user and prefix:
            raise ValueError(
                "Can't specify both user and prefix. Please choose one or the other."
            )

        if replace is not None:
            warnings.warn(
                "replace is ignored. Installing a kernelspec always replaces an existing installation",
                DeprecationWarning,
                stacklevel=2,
            )

        destination = self._get_destination_dir(kernel_name,
                                                user=user,
                                                prefix=prefix)
        self.log.debug('Installing kernelspec in %s', destination)

        kernel_dir = os.path.dirname(destination)
        if kernel_dir not in self.kernel_dirs:
            self.log.warn(
                "Installing to %s, which is not in %s. The kernelspec may not be found.",
                kernel_dir,
                self.kernel_dirs,
            )

        if os.path.isdir(destination):
            self.log.info('Removing existing kernelspec in %s', destination)
            shutil.rmtree(destination)

        shutil.copytree(source_dir, destination)
        self.log.info('Installed kernelspec %s in %s', kernel_name,
                      destination)
        return destination

    def install_native_kernel_spec(self, user=False):
        """DEPRECATED: Use ipykernel.kenelspec.install"""
        warnings.warn(
            "install_native_kernel_spec is deprecated."
            " Use ipykernel.kernelspec import install.",
            stacklevel=2)
        from ipykernel.kernelspec import install
        install(self, user=user)
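
# Usage sketch (editorial; kernel names and paths are illustrative):
ksm = KernelSpecManager()
print(ksm.find_kernel_specs())          # e.g. {'python3': '/usr/share/jupyter/kernels/python3'}
spec = ksm.get_kernel_spec('python3')   # raises NoSuchKernel if unknown
print(spec.argv, spec.display_name)
ksm.install_kernel_spec('/tmp/echo_kernel', kernel_name='echo', user=True)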
Example #20
class PAMAuthenticator(LocalAuthenticator):
    """Authenticate local UNIX users with PAM"""

    # run PAM in a thread, since it can be slow
    executor = Any()

    @default('executor')
    def _default_executor(self):
        return ThreadPoolExecutor(1)

    encoding = Unicode('utf8',
                       help="""
        The text encoding to use when communicating with PAM
        """).tag(config=True)

    service = Unicode('login',
                      help="""
        The name of the PAM service to use for authentication
        """).tag(config=True)

    open_sessions = Bool(True,
                         help="""
        Whether to open a new PAM session when spawners are started.

        This may trigger things like mounting shared filesystems,
        loading credentials, etc. depending on system configuration,
        but it does not always work.

        If any errors are encountered when opening/closing PAM sessions,
        this is automatically set to False.
        """).tag(config=True)

    check_account = Bool(True,
                         help="""
        Whether to check the user's account status via PAM during authentication.

        The PAM account stack performs non-authentication based account 
        management. It is typically used to restrict/permit access to a 
        service and this step is needed to access the host's user access control.

        Disabling this can be dangerous as authenticated but unauthorized users may
        be granted access and, therefore, arbitrary execution on the system.
        """).tag(config=True)

    admin_groups = Set(help="""
        Authoritative list of user groups that determine admin access.
        Users not in these groups can still be granted admin status through admin_users.

        White/blacklisting rules still apply.
        """).tag(config=True)

    def __init__(self, **kwargs):
        if pamela is None:
            raise _pamela_error from None
        super().__init__(**kwargs)

    @run_on_executor
    def is_admin(self, handler, authentication):
        """PAM admin status checker. Returns Bool to indicate user admin status."""
        # Checks upper level function (admin_users)
        admin_status = super().is_admin(handler, authentication)
        username = authentication['name']

        # If not yet listed as an admin, and admin_groups is on, use it authoritatively
        if not admin_status and self.admin_groups:
            try:
                # Most likely source of error here is a group name <-> gid mapping failure
                # This is most likely due to a typo in the configuration or in the case of LDAP/AD, a network
                # connectivity issue. Maybe a long one where the local caches have timed out, though PAM
                # would most likely refuse to authenticate a remote user by that point.

                # It was decided that the best course of action on group resolution failure was to
                # fail to authenticate and raise instead of soft-failing and not changing admin status
                # (returning None instead of just the username) as this indicates some sort of system failure

                admin_group_gids = {
                    self._getgrnam(x).gr_gid
                    for x in self.admin_groups
                }
                user_group_gids = set(
                    self._getgrouplist(username,
                                       self._getpwnam(username).pw_gid))
                admin_status = len(admin_group_gids & user_group_gids) != 0

            except Exception as e:
                if handler is not None:
                    self.log.error("PAM Admin Group Check failed (%s@%s): %s",
                                   username, handler.request.remote_ip, e)
                else:
                    self.log.error("PAM Admin Group Check failed: %s", e)
                # re-raise to return a 500 to the user and indicate a problem. We failed, not them.
                raise

        return admin_status

    @run_on_executor
    def authenticate(self, handler, data):
        """Authenticate with PAM, and return the username if login is successful.

        Return None otherwise.
        """
        username = data['username']
        try:
            pamela.authenticate(username,
                                data['password'],
                                service=self.service,
                                encoding=self.encoding)
        except pamela.PAMError as e:
            if handler is not None:
                self.log.warning("PAM Authentication failed (%s@%s): %s",
                                 username, handler.request.remote_ip, e)
            else:
                self.log.warning("PAM Authentication failed: %s", e)
            return None

        if self.check_account:
            try:
                pamela.check_account(username,
                                     service=self.service,
                                     encoding=self.encoding)
            except pamela.PAMError as e:
                if handler is not None:
                    self.log.warning("PAM Account Check failed (%s@%s): %s",
                                     username, handler.request.remote_ip, e)
                else:
                    self.log.warning("PAM Account Check failed: %s", e)
                return None

        return username

    @run_on_executor
    def pre_spawn_start(self, user, spawner):
        """Open PAM session for user if so configured"""
        if not self.open_sessions:
            return
        try:
            pamela.open_session(user.name,
                                service=self.service,
                                encoding=self.encoding)
        except pamela.PAMError as e:
            self.log.warning("Failed to open PAM session for %s: %s",
                             user.name, e)
            self.log.warning("Disabling PAM sessions from now on.")
            self.open_sessions = False

    @run_on_executor
    def post_spawn_stop(self, user, spawner):
        """Close PAM session for user if we were configured to opened one"""
        if not self.open_sessions:
            return
        try:
            pamela.close_session(user.name,
                                 service=self.service,
                                 encoding=self.encoding)
        except pamela.PAMError as e:
            self.log.warning("Failed to close PAM session for %s: %s",
                             user.name, e)
            self.log.warning("Disabling PAM sessions from now on.")
            self.open_sessions = False
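
A minimal configuration sketch for the authenticator above, as it might appear in a jupyterhub_config.py (the group name 'hub-admins' is a placeholder, not part of the source):

# Hypothetical jupyterhub_config.py snippet; the group name is a placeholder.
c.JupyterHub.authenticator_class = 'jupyterhub.auth.PAMAuthenticator'
c.PAMAuthenticator.service = 'login'              # PAM service stack to authenticate against
c.PAMAuthenticator.check_account = True           # also run the PAM account stack
c.PAMAuthenticator.open_sessions = False          # skip PAM sessions if they misbehave
c.PAMAuthenticator.admin_groups = {'hub-admins'}  # members gain admin via is_admin()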

Example #21
class SanitizeHTML(Preprocessor):

    # Bleach config.
    attributes = Any(
        config=True,
        default_value=ALLOWED_ATTRIBUTES,
        help="Allowed HTML tag attributes",
    )
    tags = List(
        Unicode(),
        config=True,
        default_value=ALLOWED_TAGS,
        help="List of HTML tags to allow",
    )
    styles = List(
        Unicode(),
        config=True,
        default_value=ALLOWED_STYLES,
        help="Allowed CSS styles if <style> tag is whitelisted"
    )
    strip = Bool(
        config=True,
        default_value=False,
        help="If True, remove unsafe markup entirely instead of escaping"
    )
    strip_comments = Bool(
        config=True,
        default_value=True,
        help="If True, strip comments from escaped HTML",
    )

    # Display data config.
    safe_output_keys = Set(
        config=True,
        default_value={
            'metadata',  # Not a mimetype per-se, but expected and safe.
            'text/plain',
            'text/latex',
            'application/json',
            'image/png',
            'image/jpeg',
        },
        help="Cell output mimetypes to render without modification",
    )
    sanitized_output_types = Set(
        config=True,
        default_value={
            'text/html',
            'text/markdown',
        },
        help="Cell output types to display after escaping with Bleach.",
    )

    def preprocess_cell(self, cell, resources, cell_index):
        """
        Sanitize potentially-dangerous contents of the cell.

        Cell Types:
          raw:
            Sanitize literal HTML
          markdown:
            Sanitize literal HTML
          code:
            Sanitize outputs that could result in code execution
        """
        if cell.cell_type == 'raw':
            # Sanitize all raw cells anyway.
            # Only ones with the text/html mimetype should be emitted,
            # but err on the side of safety.
            cell.source = self.sanitize_html_tags(cell.source)
            return cell, resources
        elif cell.cell_type == 'markdown':
            cell.source = self.sanitize_html_tags(cell.source)
            return cell, resources
        elif cell.cell_type == 'code':
            cell.outputs = self.sanitize_code_outputs(cell.outputs)
            return cell, resources

    def sanitize_code_outputs(self, outputs):
        """
        Sanitize code cell outputs.

        Removes 'text/javascript' fields from display_data outputs, and
        runs `sanitize_html_tags` over 'text/html'.
        """
        for output in outputs:
            # These are always ascii, so nothing to escape.
            if output['output_type'] in ('stream', 'error'):
                continue
            data = output.data
            to_remove = []
            for key in data:
                if key in self.safe_output_keys:
                    continue
                elif key in self.sanitized_output_types:
                    self.log.info("Sanitizing %s" % key)
                    data[key] = self.sanitize_html_tags(data[key])
                else:
                    # Mark key for removal. (Python doesn't allow deletion of
                    # keys from a dict during iteration)
                    to_remove.append(key)
            for key in to_remove:
                self.log.info("Removing %s" % key)
                del data[key]
        return outputs

    def sanitize_html_tags(self, html_str):
        """
        Sanitize a string containing raw HTML tags.
        """
        return clean(
            html_str,
            tags=self.tags,
            attributes=self.attributes,
            styles=self.styles,
            strip=self.strip,
            strip_comments=self.strip_comments,
        )
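
Because SanitizeHTML follows the standard nbconvert Preprocessor interface, wiring it into an export is a single registration call. A usage sketch, assuming the class is importable and an untrusted.ipynb file exists:

# Usage sketch: run the sanitizer as part of an nbconvert HTML export.
import nbformat
from nbconvert import HTMLExporter

exporter = HTMLExporter()
exporter.register_preprocessor(SanitizeHTML, enabled=True)

nb = nbformat.read('untrusted.ipynb', as_version=4)
body, resources = exporter.from_notebook_node(nb)
# body now contains HTML whose html/markdown outputs were cleaned by Bleach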
Example #22
class View(HasTraits):
    """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.

    Don't use this class, use subclasses.

    Methods
    -------

    spin
        flushes incoming results and registration state changes;
        control methods also spin, and requesting `ids` ensures the view is up to date

    wait
        wait on one or more msg_ids

    execution methods
        apply
        legacy: execute, run

    data movement
        push, pull, scatter, gather

    query methods
        get_result, queue_status, purge_results, result_status

    control methods
        abort, shutdown

    """
    # flags
    block = Bool(False)
    track = Bool(False)
    targets = Any()

    history = List()
    outstanding = Set()
    results = Dict()
    client = Instance('ipyparallel.Client', allow_none=True)

    _socket = Instance('zmq.Socket', allow_none=True)
    _flag_names = List(['targets', 'block', 'track'])
    _in_sync_results = Bool(False)
    _targets = Any()
    _idents = Any()

    def __init__(self, client=None, socket=None, **flags):
        super(View, self).__init__(client=client, _socket=socket)
        self.results = client.results
        self.block = client.block
        self.executor = ViewExecutor(self)

        self.set_flags(**flags)

        assert self.__class__ is not View, "Don't use base View objects, use subclasses"

    def __repr__(self):
        strtargets = str(self.targets)
        if len(strtargets) > 16:
            strtargets = strtargets[:12]+'...]'
        return "<%s %s>"%(self.__class__.__name__, strtargets)
    
    def __len__(self):
        if isinstance(self.targets, list):
            return len(self.targets)
        elif isinstance(self.targets, int):
            return 1
        else:
            return len(self.client)
    
    def set_flags(self, **kwargs):
        """set my attribute flags by keyword.

        Views determine behavior with a few attributes (`block`, `track`, etc.).
        These attributes can be set all at once by name with this method.

        Parameters
        ----------

        block : bool
            whether to wait for results
        track : bool
            whether to create a MessageTracker to allow the user to
            safely edit arrays and buffers after non-copying
            sends.
        """
        for name, value in iteritems(kwargs):
            if name not in self._flag_names:
                raise KeyError("Invalid name: %r"%name)
            else:
                setattr(self, name, value)

    @contextmanager
    def temp_flags(self, **kwargs):
        """temporarily set flags, for use in `with` statements.

        See set_flags for permanent setting of flags

        Examples
        --------

        >>> view.track=False
        ...
        >>> with view.temp_flags(track=True):
        ...    ar = view.apply(dostuff, my_big_array)
        ...    ar.tracker.wait() # wait for send to finish
        >>> view.track
        False

        """
        # preflight: save flags, and set temporaries
        saved_flags = {}
        for f in self._flag_names:
            saved_flags[f] = getattr(self, f)
        self.set_flags(**kwargs)
        # yield to the with-statement block
        try:
            yield
        finally:
            # postflight: restore saved flags
            self.set_flags(**saved_flags)


    #----------------------------------------------------------------
    # apply
    #----------------------------------------------------------------

    def _sync_results(self):
        """to be called by @sync_results decorator
        
        after submitting any tasks.
        """
        delta = self.outstanding.difference(self.client.outstanding)
        completed = self.outstanding.intersection(delta)
        self.outstanding = self.outstanding.difference(completed)
    
    @sync_results
    @save_ids
    def _really_apply(self, f, args, kwargs, block=None, **options):
        """wrapper for client.send_apply_request"""
        raise NotImplementedError("Implement in subclasses")

    def apply(self, f, *args, **kwargs):
        """calls ``f(*args, **kwargs)`` on remote engines, returning the result.

        This method sets all apply flags via this View's attributes.

        Returns :class:`~ipyparallel.client.asyncresult.AsyncResult`
        instance if ``self.block`` is False, otherwise the return value of
        ``f(*args, **kwargs)``.
        """
        return self._really_apply(f, args, kwargs)

    def apply_async(self, f, *args, **kwargs):
        """calls ``f(*args, **kwargs)`` on remote engines in a nonblocking manner.

        Returns :class:`~ipyparallel.client.asyncresult.AsyncResult` instance.
        """
        return self._really_apply(f, args, kwargs, block=False)

    def apply_sync(self, f, *args, **kwargs):
        """calls ``f(*args, **kwargs)`` on remote engines in a blocking manner,
         returning the result.
        """
        return self._really_apply(f, args, kwargs, block=True)

    #----------------------------------------------------------------
    # wrappers for client and control methods
    #----------------------------------------------------------------
    @sync_results
    def spin(self):
        """spin the client, and sync"""
        self.client.spin()

    @sync_results
    def wait(self, jobs=None, timeout=-1):
        """waits on one or more `jobs`, for up to `timeout` seconds.

        Parameters
        ----------

        jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
                ints are indices to self.history
                strs are msg_ids
                default: wait on all outstanding messages
        timeout : float
                a time in seconds, after which to give up.
                default is -1, which means no timeout

        Returns
        -------

        True : when all msg_ids are done
        False : timeout reached, some msg_ids still outstanding
        """
        if jobs is None:
            jobs = self.history
        return self.client.wait(jobs, timeout)

    def abort(self, jobs=None, targets=None, block=None):
        """Abort jobs on my engines.

        Parameters
        ----------

        jobs : None, str, list of strs, optional
            if None: abort all jobs.
            else: abort specific msg_id(s).
        """
        block = block if block is not None else self.block
        targets = targets if targets is not None else self.targets
        jobs = jobs if jobs is not None else list(self.outstanding)
        
        return self.client.abort(jobs=jobs, targets=targets, block=block)

    def queue_status(self, targets=None, verbose=False):
        """Fetch the Queue status of my engines"""
        targets = targets if targets is not None else self.targets
        return self.client.queue_status(targets=targets, verbose=verbose)

    def purge_results(self, jobs=[], targets=[]):
        """Instruct the controller to forget specific results."""
        if targets is None or targets == 'all':
            targets = self.targets
        return self.client.purge_results(jobs=jobs, targets=targets)

    def shutdown(self, targets=None, restart=False, hub=False, block=None):
        """Terminates one or more engine processes, optionally including the hub.
        """
        block = self.block if block is None else block
        if targets is None or targets == 'all':
            targets = self.targets
        return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)

    def get_result(self, indices_or_msg_ids=None, block=None, owner=False):
        """return one or more results, specified by history index or msg_id.

        See :meth:`ipyparallel.client.client.Client.get_result` for details.
        """

        if indices_or_msg_ids is None:
            indices_or_msg_ids = -1
        if isinstance(indices_or_msg_ids, int):
            indices_or_msg_ids = self.history[indices_or_msg_ids]
        elif isinstance(indices_or_msg_ids, (list,tuple,set)):
            indices_or_msg_ids = list(indices_or_msg_ids)
            for i,index in enumerate(indices_or_msg_ids):
                if isinstance(index, int):
                    indices_or_msg_ids[i] = self.history[index]
        return self.client.get_result(indices_or_msg_ids, block=block, owner=owner)

    #-------------------------------------------------------------------
    # Map
    #-------------------------------------------------------------------

    @sync_results
    def map(self, f, *sequences, **kwargs):
        """override in subclasses"""
        raise NotImplementedError

    def map_async(self, f, *sequences, **kwargs):
        """Parallel version of builtin :func:`python:map`, using this view's engines.

        This is equivalent to ``map(..., block=False)``.

        See `self.map` for details.
        """
        if 'block' in kwargs:
            raise TypeError("map_async doesn't take a `block` keyword argument.")
        kwargs['block'] = False
        return self.map(f,*sequences,**kwargs)

    def map_sync(self, f, *sequences, **kwargs):
        """Parallel version of builtin :func:`python:map`, using this view's engines.

        This is equivalent to ``map(..., block=True)``.

        See `self.map` for details.
        """
        if 'block' in kwargs:
            raise TypeError("map_sync doesn't take a `block` keyword argument.")
        kwargs['block'] = True
        return self.map(f,*sequences,**kwargs)

    def imap(self, f, *sequences, **kwargs):
        """Parallel version of :func:`itertools.imap`.

        See `self.map` for details.

        """

        return iter(self.map_async(f,*sequences, **kwargs))

    #-------------------------------------------------------------------
    # Decorators
    #-------------------------------------------------------------------

    def remote(self, block=None, **flags):
        """Decorator for making a RemoteFunction"""
        block = self.block if block is None else block
        return remote(self, block=block, **flags)

    def parallel(self, dist='b', block=None, **flags):
        """Decorator for making a ParallelFunction"""
        block = self.block if block is None else block
        return parallel(self, dist=dist, block=block, **flags)
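
A usage sketch against ipyparallel's public client API (an ipcontroller and engines are assumed to be running already):

# Usage sketch; assumes a running ipyparallel cluster.
import ipyparallel as ipp

rc = ipp.Client()
view = rc[:]                       # a DirectView, one of the View subclasses

ar = view.apply(pow, 2, 10)        # returns an AsyncResult while view.block is False
print(ar.get())                    # one result per engine, e.g. [1024, 1024]

with view.temp_flags(block=True):  # flags are restored on exit, per temp_flags above
    print(view.apply(pow, 3, 2))   # blocking call returns the results directly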
Example #23
class BaseIPythonApplication(Application):

    name = u'ipython'
    description = Unicode(u'IPython: an enhanced interactive Python shell.')
    version = Unicode(release.version)

    aliases = base_aliases
    flags = base_flags
    classes = List([ProfileDir])

    # enable `load_subconfig('cfg.py', profile='name')`
    python_config_loader_class = ProfileAwareConfigLoader

    # Track whether the config_file has changed,
    # because some logic happens only if we aren't using the default.
    config_file_specified = Set()

    config_file_name = Unicode()

    @default('config_file_name')
    def _config_file_name_default(self):
        return self.name.replace('-', '_') + u'_config.py'

    @observe('config_file_name')
    def _config_file_name_changed(self, change):
        if change['new'] != change['old']:
            self.config_file_specified.add(change['new'])

    # The directory that contains IPython's builtin profiles.
    builtin_profile_dir = Unicode(
        os.path.join(get_ipython_package_dir(), u'config', u'profile',
                     u'default'))

    config_file_paths = List(Unicode())

    @default('config_file_paths')
    def _config_file_paths_default(self):
        return [os.getcwd()]

    extra_config_file = Unicode(help="""Path to an extra config file to load.
    
    If specified, load this config file in addition to any other IPython config.
    """).tag(config=True)

    @observe('extra_config_file')
    def _extra_config_file_changed(self, change):
        old = change['old']
        new = change['new']
        try:
            self.config_files.remove(old)
        except ValueError:
            pass
        self.config_file_specified.add(new)
        self.config_files.append(new)

    profile = Unicode(u'default',
                      help="""The IPython profile to use.""").tag(config=True)

    @observe('profile')
    def _profile_changed(self, change):
        self.builtin_profile_dir = os.path.join(get_ipython_package_dir(),
                                                u'config', u'profile',
                                                change['new'])

    ipython_dir = Unicode(help="""
        The name of the IPython directory. This directory is used for logging
        configuration (through profiles), history storage, etc. The default
        is usually $HOME/.ipython. This option can also be specified through
        the environment variable IPYTHONDIR.
        """).tag(config=True)

    @default('ipython_dir')
    def _ipython_dir_default(self):
        d = get_ipython_dir()
        self._ipython_dir_changed({
            'name': 'ipython_dir',
            'old': d,
            'new': d,
        })
        return d

    _in_init_profile_dir = False
    profile_dir = Instance(ProfileDir, allow_none=True)

    @default('profile_dir')
    def _profile_dir_default(self):
        # avoid recursion
        if self._in_init_profile_dir:
            return
        # profile_dir requested early, force initialization
        self.init_profile_dir()
        return self.profile_dir

    overwrite = Bool(
        False,
        help="""Whether to overwrite existing config files when copying"""
    ).tag(config=True)
    auto_create = Bool(
        False,
        help="""Whether to create profile dir if it doesn't exist""").tag(
            config=True)

    config_files = List(Unicode())

    @default('config_files')
    def _config_files_default(self):
        return [self.config_file_name]

    copy_config_files = Bool(
        False,
        help="""Whether to install the default config files into the profile dir.
        If a new profile is being created, and IPython contains config files for that
        profile, then they will be staged into the new directory.  Otherwise,
        default config files will be automatically generated.
        """).tag(config=True)

    verbose_crash = Bool(
        False,
        help=
        """Create a massive crash report when IPython encounters what may be an
        internal error.  The default is to append a short message to the
        usual traceback""").tag(config=True)

    # The class to use as the crash handler.
    crash_handler_class = Type(crashhandler.CrashHandler)

    @catch_config_error
    def __init__(self, **kwargs):
        super(BaseIPythonApplication, self).__init__(**kwargs)
        # ensure current working directory exists
        try:
            os.getcwd()
        except OSError:
            # exit if cwd doesn't exist
            self.log.error("Current working directory doesn't exist.")
            self.exit(1)

    #-------------------------------------------------------------------------
    # Various stages of Application creation
    #-------------------------------------------------------------------------

    deprecated_subcommands = {}

    def initialize_subcommand(self, subc, argv=None):
        if subc in self.deprecated_subcommands:
            self.log.warning(
                "Subcommand `ipython {sub}` is deprecated and will be removed "
                "in future versions.".format(sub=subc))
            self.log.warning("You likely want to use `jupyter {sub}` in the "
                             "future".format(sub=subc))
        return super(BaseIPythonApplication,
                     self).initialize_subcommand(subc, argv)

    def init_crash_handler(self):
        """Create a crash handler, typically setting sys.excepthook to it."""
        self.crash_handler = self.crash_handler_class(self)
        sys.excepthook = self.excepthook

        def unset_crashhandler():
            sys.excepthook = sys.__excepthook__

        atexit.register(unset_crashhandler)

    def excepthook(self, etype, evalue, tb):
        """this is sys.excepthook after init_crashhandler

        set self.verbose_crash=True to use our full crashhandler, instead of
        a regular traceback with a short message (crash_handler_lite)
        """

        if self.verbose_crash:
            return self.crash_handler(etype, evalue, tb)
        else:
            return crashhandler.crash_handler_lite(etype, evalue, tb)

    @observe('ipython_dir')
    def _ipython_dir_changed(self, change):
        old = change['old']
        new = change['new']
        if old is not Undefined:
            str_old = os.path.abspath(old)
            if str_old in sys.path:
                sys.path.remove(str_old)
        str_path = os.path.abspath(new)
        sys.path.append(str_path)
        ensure_dir_exists(new)
        readme = os.path.join(new, 'README')
        readme_src = os.path.join(get_ipython_package_dir(), u'config',
                                  u'profile', 'README')
        if not os.path.exists(readme) and os.path.exists(readme_src):
            shutil.copy(readme_src, readme)
        for d in ('extensions', 'nbextensions'):
            path = os.path.join(new, d)
            try:
                ensure_dir_exists(path)
            except OSError as e:
                # this will not be EEXIST
                self.log.error("couldn't create path %s: %s", path, e)
        self.log.debug("IPYTHONDIR set to: %s" % new)

    def load_config_file(self, suppress_errors=IPYTHON_SUPPRESS_CONFIG_ERRORS):
        """Load the config file.

        By default, errors in loading config are handled, and a warning
        printed on screen. For testing, the suppress_errors option is set
        to False, so errors will make tests fail.

        The default value of `suppress_errors` is `None`, in which case the
        behavior defaults to that of `traitlets.Application`.

        The default value can be set:
           - to `False` by setting the 'IPYTHON_SUPPRESS_CONFIG_ERRORS' environment variable to '0' or 'false' (case insensitive).
           - to `True` by setting the 'IPYTHON_SUPPRESS_CONFIG_ERRORS' environment variable to '1' or 'true' (case insensitive).
           - to `None` by setting the 'IPYTHON_SUPPRESS_CONFIG_ERRORS' environment variable to '' (empty string) or leaving it unset.

        Any other value is invalid, and will make IPython exit with a non-zero return code.
        """

        self.log.debug("Searching path %s for config files",
                       self.config_file_paths)
        base_config = 'ipython_config.py'
        self.log.debug("Attempting to load config file: %s" % base_config)
        try:
            if suppress_errors is not None:
                old_value = Application.raise_config_file_errors
                Application.raise_config_file_errors = not suppress_errors
            Application.load_config_file(self,
                                         base_config,
                                         path=self.config_file_paths)
        except ConfigFileNotFound:
            # ignore errors loading parent
            self.log.debug("Config file %s not found", base_config)
        if suppress_errors is not None:
            Application.raise_config_file_errors = old_value

        for config_file_name in self.config_files:
            if not config_file_name or config_file_name == base_config:
                continue
            self.log.debug("Attempting to load config file: %s" %
                           self.config_file_name)
            try:
                Application.load_config_file(self,
                                             config_file_name,
                                             path=self.config_file_paths)
            except ConfigFileNotFound:
                # Only warn if the default config file was NOT being used.
                if config_file_name in self.config_file_specified:
                    msg = self.log.warning
                else:
                    msg = self.log.debug
                msg("Config file not found, skipping: %s", config_file_name)
            except Exception:
                # For testing purposes.
                if not suppress_errors:
                    raise
                self.log.warning("Error loading config file: %s" %
                                 self.config_file_name,
                                 exc_info=True)

    def init_profile_dir(self):
        """initialize the profile dir"""
        self._in_init_profile_dir = True
        if self.profile_dir is not None:
            # already ran
            return
        if 'ProfileDir.location' not in self.config:
            # location not specified, find by profile name
            try:
                p = ProfileDir.find_profile_dir_by_name(
                    self.ipython_dir, self.profile, self.config)
            except ProfileDirError:
                # not found, maybe create it (always create default profile)
                if self.auto_create or self.profile == 'default':
                    try:
                        p = ProfileDir.create_profile_dir_by_name(
                            self.ipython_dir, self.profile, self.config)
                    except ProfileDirError:
                        self.log.fatal("Could not create profile: %r" %
                                       self.profile)
                        self.exit(1)
                    else:
                        self.log.info("Created profile dir: %r" % p.location)
                else:
                    self.log.fatal("Profile %r not found." % self.profile)
                    self.exit(1)
            else:
                self.log.debug(f"Using existing profile dir: {p.location!r}")
        else:
            location = self.config.ProfileDir.location
            # location is fully specified
            try:
                p = ProfileDir.find_profile_dir(location, self.config)
            except ProfileDirError:
                # not found, maybe create it
                if self.auto_create:
                    try:
                        p = ProfileDir.create_profile_dir(
                            location, self.config)
                    except ProfileDirError:
                        self.log.fatal(
                            "Could not create profile directory: %r" %
                            location)
                        self.exit(1)
                    else:
                        self.log.debug("Creating new profile dir: %r" %
                                       location)
                else:
                    self.log.fatal("Profile directory %r not found." %
                                   location)
                    self.exit(1)
            else:
                self.log.debug(f"Using existing profile dir: {p.location!r}")
            # if profile_dir is specified explicitly, set profile name
            dir_name = os.path.basename(p.location)
            if dir_name.startswith('profile_'):
                self.profile = dir_name[8:]

        self.profile_dir = p
        self.config_file_paths.append(p.location)
        self._in_init_profile_dir = False

    def init_config_files(self):
        """[optionally] copy default config files into profile dir."""
        self.config_file_paths.extend(ENV_CONFIG_DIRS)
        self.config_file_paths.extend(SYSTEM_CONFIG_DIRS)
        # copy config files
        path = Path(self.builtin_profile_dir)
        if self.copy_config_files:
            src = self.profile

            cfg = self.config_file_name
            if path and (path / cfg).exists():
                self.log.warning(
                    "Staging %r from %s into %r [overwrite=%s]" %
                    (cfg, src, self.profile_dir.location, self.overwrite))
                self.profile_dir.copy_config_file(cfg,
                                                  path=path,
                                                  overwrite=self.overwrite)
            else:
                self.stage_default_config_file()
        else:
            # Still stage *bundled* config files, but not generated ones
            # This is necessary for `ipython profile=sympy` to load the profile
            # on the first go
            files = path.glob("*.py")
            for fullpath in files:
                cfg = fullpath.name
                if self.profile_dir.copy_config_file(cfg,
                                                     path=path,
                                                     overwrite=False):
                    # file was copied
                    self.log.warning(
                        "Staging bundled %s from %s into %r" %
                        (cfg, self.profile, self.profile_dir.location))

    def stage_default_config_file(self):
        """auto generate default config file, and stage it into the profile."""
        s = self.generate_config_file()
        config_file = Path(self.profile_dir.location) / self.config_file_name
        if self.overwrite or not config_file.exists():
            self.log.warning("Generating default config file: %r" %
                             (config_file))
            config_file.write_text(s)

    @catch_config_error
    def initialize(self, argv=None):
        # don't hook up crash handler before parsing command-line
        self.parse_command_line(argv)
        self.init_crash_handler()
        if self.subapp is not None:
            # stop here if subapp is taking over
            return
        # save a copy of CLI config to re-load after config files
        # so that it has highest priority
        cl_config = deepcopy(self.config)
        self.init_profile_dir()
        self.init_config_files()
        self.load_config_file()
        # enforce cl-opts override configfile opts:
        self.update_config(cl_config)
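
The name trait drives the config-file machinery above: config_file_name defaults to <name>_config.py with dashes mapped to underscores. A minimal sketch showing that wiring (MyApp is hypothetical; initialize() will touch the profile dir under IPYTHONDIR):

# Minimal sketch; MyApp is a hypothetical subclass.
from IPython.core.application import BaseIPythonApplication

class MyApp(BaseIPythonApplication):
    name = 'my-app'

app = MyApp()
print(app.config_file_name)  # 'my_app_config.py', via _config_file_name_default
app.initialize([])           # parses argv, then inits the profile dir and loads config files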
Example #24
class LanguageServerSession(LoggingConfigurable):
    """Manage a session for a connection to a language server"""

    language_server = Unicode(help="the language server implementation name")
    spec = Schema(LANGUAGE_SERVER_SPEC)

    # run-time specifics
    process = Instance(subprocess.Popen,
                       help="the language server subprocess",
                       allow_none=True)
    writer = Instance(stdio.LspStdIoWriter,
                      help="the JSON-RPC writer",
                      allow_none=True)
    reader = Instance(stdio.LspStdIoReader,
                      help="the JSON-RPC reader",
                      allow_none=True)
    from_lsp = Instance(Queue,
                        help="a queue for string messages from the server",
                        allow_none=True)
    to_lsp = Instance(Queue,
                      help="a queue for string message to the server",
                      allow_none=True)
    handlers = Set(
        trait=Instance(WebSocketHandler),
        default_value=[],
        help="the currently subscribed websockets",
    )
    status = UseEnum(SessionStatus, default_value=SessionStatus.NOT_STARTED)
    last_handler_message_at = Instance(datetime, allow_none=True)
    last_server_message_at = Instance(datetime, allow_none=True)

    _tasks = None

    _skip_serialize = ["argv", "debug_argv"]

    def __init__(self, *args, **kwargs):
        """set up the required traitlets and exit behavior for a session"""
        super().__init__(*args, **kwargs)
        atexit.register(self.stop)

    def __repr__(self):  # pragma: no cover
        return ("<LanguageServerSession("
                "language_server={language_server}, argv={argv})>").format(
                    language_server=self.language_server, **self.spec)

    def to_json(self):
        return dict(
            handler_count=len(self.handlers),
            status=self.status.value,
            last_server_message_at=self.last_server_message_at.isoformat()
            if self.last_server_message_at else None,
            last_handler_message_at=self.last_handler_message_at.isoformat()
            if self.last_handler_message_at else None,
            spec={
                k: v
                for k, v in self.spec.items() if k not in SKIP_JSON_SPEC
            },
        )

    def initialize(self):
        """(re)initialize a language server session"""
        self.stop()
        self.status = SessionStatus.STARTING
        self.init_queues()
        self.init_process()
        self.init_writer()
        self.init_reader()

        loop = asyncio.get_event_loop()
        self._tasks = [
            loop.create_task(coro()) for coro in
            [self._read_lsp, self._write_lsp, self._broadcast_from_lsp]
        ]

        self.status = SessionStatus.STARTED

    def stop(self):
        """clean up all of the state of the session"""

        self.status = SessionStatus.STOPPING

        if self.process:
            self.process.terminate()
            self.process = None
        if self.reader:
            self.reader.close()
            self.reader = None
        if self.writer:
            self.writer.close()
            self.writer = None

        if self._tasks:
            for task in self._tasks:
                task.cancel()

        self.status = SessionStatus.STOPPED

    @observe("handlers")
    def _on_handlers(self, change: Bunch):
        """re-initialize if someone starts listening, or stop if nobody is"""
        if change["new"] and not self.process:
            self.initialize()
        elif not change["new"] and self.process:
            self.stop()

    def write(self, message):
        """wrapper around the write queue to keep it mostly internal"""
        self.last_handler_message_at = self.now()
        IOLoop.current().add_callback(self.to_lsp.put_nowait, message)

    def now(self):
        return datetime.now(timezone.utc)

    def init_process(self):
        """start the language server subprocess"""
        self.process = subprocess.Popen(
            self.spec["argv"],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            env=self.substitute_env(self.spec.get("env", {}), os.environ),
        )

    def init_queues(self):
        """create the queues"""
        self.from_lsp = Queue()
        self.to_lsp = Queue()

    def init_reader(self):
        """create the stdout reader (from the language server)"""
        self.reader = stdio.LspStdIoReader(stream=self.process.stdout,
                                           queue=self.from_lsp,
                                           parent=self)

    def init_writer(self):
        """create the stdin writer (to the language server)"""
        self.writer = stdio.LspStdIoWriter(stream=self.process.stdin,
                                           queue=self.to_lsp,
                                           parent=self)

    def substitute_env(self, env, base):
        # Start from the full parent environment, then overlay spec-provided
        # values, expanding $VAR references from `base`.
        final_env = copy(os.environ)

        for key, value in env.items():
            final_env.update(
                {key: string.Template(value).safe_substitute(base)})

        return final_env

    async def _read_lsp(self):
        await self.reader.read()

    async def _write_lsp(self):
        await self.writer.write()

    async def _broadcast_from_lsp(self):
        """loop for reading messages from the queue of messages from the language
        server
        """
        async for message in self.from_lsp:
            self.last_server_message_at = self.now()
            await self.parent.on_server_message(message, self)
            self.from_lsp.task_done()
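
The env handling in substitute_env above is plain string.Template substitution over a base mapping. A standalone sketch of the same pattern (the LS_LOG variable is illustrative):

# Standalone sketch of the substitute_env pattern.
import os
import string

def substitute_env(env, base):
    final_env = dict(os.environ)  # start from the full parent environment
    for key, value in env.items():
        # expand $VAR references in spec-provided values using `base`
        final_env[key] = string.Template(value).safe_substitute(base)
    return final_env

expanded = substitute_env({'LS_LOG': '$HOME/ls.log'}, os.environ)
print(expanded['LS_LOG'])  # e.g. /home/user/ls.log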
Example #25
class EnterpriseGatewayApp(KernelGatewayApp):
    """Application that provisions Jupyter kernels and proxies HTTP/Websocket
    traffic to the kernels.

    - reads command line and environment variable settings
    - initializes managers and routes
    - creates a Tornado HTTP server
    - starts the Tornado event loop
    """
    name = 'jupyter-enterprise-gateway'
    version = __version__
    description = """
        Jupyter Enterprise Gateway

        Provisions remote Jupyter kernels and proxies HTTP/Websocket traffic
        to them.
    """

    # Remote hosts
    remote_hosts_env = 'EG_REMOTE_HOSTS'
    remote_hosts_default_value = 'localhost'
    remote_hosts = List(
        default_value=[remote_hosts_default_value],
        config=True,
        help=
        """Bracketed comma-separated list of hosts on which DistributedProcessProxy kernels will be launched
          e.g., ['host1','host2']. (EG_REMOTE_HOSTS env var - non-bracketed, just comma-separated)"""
    )

    @default('remote_hosts')
    def remote_hosts_default(self):
        return os.getenv(self.remote_hosts_env,
                         self.remote_hosts_default_value).split(',')

    # Yarn endpoint
    yarn_endpoint_env = 'EG_YARN_ENDPOINT'
    yarn_endpoint_default_value = 'http://*****:*****'  # value masked in the source
    # The Unicode trait declaration for `yarn_endpoint` was swallowed by the
    # masking above; it parallels `conductor_endpoint` below.
    yarn_endpoint = Unicode(yarn_endpoint_default_value, config=True)

    @default('yarn_endpoint')
    def yarn_endpoint_default(self):
        return os.getenv(self.yarn_endpoint_env,
                         self.yarn_endpoint_default_value)

    yarn_endpoint_security_enabled_env = 'EG_YARN_ENDPOINT_SECURITY_ENABLED'
    yarn_endpoint_security_enabled_default_value = False
    yarn_endpoint_security_enabled = Bool(
        yarn_endpoint_security_enabled_default_value,
        config=True,
        help=
        """Is YARN Kerberos/SPNEGO Security enabled (True/False). (EG_YARN_ENDPOINT_SECURITY_ENABLED env var)"""
    )

    @default('yarn_endpoint_security_enabled')
    def yarn_endpoint_security_enabled_default(self):
        return bool(
            os.getenv(self.yarn_endpoint_security_enabled_env,
                      self.yarn_endpoint_security_enabled_default_value))

    # Conductor endpoint
    conductor_endpoint_env = 'EG_CONDUCTOR_ENDPOINT'
    conductor_endpoint_default_value = None
    conductor_endpoint = Unicode(
        conductor_endpoint_default_value,
        config=True,
        help=
        """The http url for accessing the Conductor REST API. (EG_CONDUCTOR_ENDPOINT env var)"""
    )

    @default('conductor_endpoint')
    def conductor_endpoint_default(self):
        return os.getenv(self.conductor_endpoint_env,
                         self.conductor_endpoint_default_value)

    _log_formatter_cls = LogFormatter

    @default('log_format')
    def _default_log_format(self):
        """override default log format to include milliseconds"""
        return u"%(color)s[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s]%(end_color)s %(message)s"

    # Impersonation enabled
    impersonation_enabled_env = 'EG_IMPERSONATION_ENABLED'
    impersonation_enabled = Bool(
        False,
        config=True,
        help=
        """Indicates whether impersonation will be performed during kernel launch. 
        (EG_IMPERSONATION_ENABLED env var)""")

    @default('impersonation_enabled')
    def impersonation_enabled_default(self):
        return bool(
            os.getenv(self.impersonation_enabled_env, 'false').lower() ==
            'true')

    # Unauthorized users
    unauthorized_users_env = 'EG_UNAUTHORIZED_USERS'
    unauthorized_users_default_value = 'root'
    unauthorized_users = Set(
        default_value={unauthorized_users_default_value},
        config=True,
        help=
        """Comma-separated list of user names (e.g., ['root','admin']) against which KERNEL_USERNAME
        will be compared.  Any match (case-sensitive) will prevent the kernel's launch
        and result in an HTTP 403 (Forbidden) error. (EG_UNAUTHORIZED_USERS env var - non-bracketed, just
        comma-separated)""")

    @default('unauthorized_users')
    def unauthorized_users_default(self):
        return os.getenv(self.unauthorized_users_env,
                         self.unauthorized_users_default_value).split(',')

    # Authorized users
    authorized_users_env = 'EG_AUTHORIZED_USERS'
    authorized_users = Set(
        config=True,
        help=
        """Comma-separated list of user names (e.g., ['bob','alice']) against which KERNEL_USERNAME
        will be compared.  Any match (case-sensitive) will allow the kernel's launch, otherwise an HTTP 403
        (Forbidden) error will be raised.  The set of unauthorized users takes precedence. This option should
        be used carefully as it can dramatically limit who can launch kernels.  (EG_AUTHORIZED_USERS
        env var - non-bracketed, just comma-separated)""")

    @default('authorized_users')
    def authorized_users_default(self):
        au_env = os.getenv(self.authorized_users_env)
        return au_env.split(',') if au_env is not None else []

    # Port range
    port_range_env = 'EG_PORT_RANGE'
    port_range_default_value = "0..0"
    port_range = Unicode(
        port_range_default_value,
        config=True,
        help=
        """Specifies the lower and upper port numbers from which ports are created.  The bounded values
        are separated by '..' (e.g., 33245..34245 specifies a range of 1000 ports to be randomly selected).
        A range of zero (e.g., 33245..33245 or 0..0) disables port-range enforcement.  (EG_PORT_RANGE env var)"""
    )

    @default('port_range')
    def port_range_default(self):
        return os.getenv(self.port_range_env, self.port_range_default_value)

    # Max Kernels per User
    max_kernels_per_user_env = 'EG_MAX_KERNELS_PER_USER'
    max_kernels_per_user_default_value = -1
    max_kernels_per_user = Integer(
        max_kernels_per_user_default_value,
        config=True,
        help=
        """Specifies the maximum number of kernels a user can have active simultaneously.  A value of -1 disables
        enforcement.  (EG_MAX_KERNELS_PER_USER env var)""")

    @default('max_kernels_per_user')
    def max_kernels_per_user_default(self):
        return int(
            os.getenv(self.max_kernels_per_user_env,
                      self.max_kernels_per_user_default_value))

    kernel_spec_manager = Instance(KernelSpecManager, allow_none=True)

    kernel_spec_manager_class = Type(default_value=KernelSpecManager,
                                     config=True,
                                     help="""
        The kernel spec manager class to use. Should be a subclass
        of `jupyter_client.kernelspec.KernelSpecManager`.
        """)

    kernel_manager_class = Type(klass=MappingKernelManager,
                                default_value=RemoteMappingKernelManager,
                                config=True,
                                help="""
        The kernel manager class to use. Should be a subclass
        of `notebook.services.kernels.MappingKernelManager`.
        """)

    def init_configurables(self):
        """Initializes all configurable objects including a kernel manager, kernel
        spec manager, session manager, and personality.

        Any kernel pool configured by the personality will be its responsibility
        to shut down.

        Optionally, loads a notebook and prespawns the configured number of
        kernels.
        """
        self.kernel_spec_manager = KernelSpecManager(parent=self)

        self.seed_notebook = None
        if self.seed_uri is not None:
            # Note: must be set before instantiating a SeedingMappingKernelManager
            self.seed_notebook = self._load_notebook(self.seed_uri)

        # Only pass a default kernel name when one is provided. Otherwise,
        # adopt whatever default the kernel manager wants to use.
        kwargs = {}
        if self.default_kernel_name:
            kwargs['default_kernel_name'] = self.default_kernel_name

        self.kernel_spec_manager = self.kernel_spec_manager_class(parent=self)

        self.kernel_manager = self.kernel_manager_class(
            parent=self,
            log=self.log,
            connection_dir=self.runtime_dir,
            kernel_spec_manager=self.kernel_spec_manager,
            **kwargs)

        # Detect older version of notebook
        func = getattr(self.kernel_manager, 'initialize_culler', None)
        if not func:
            self.log.warning(
                "Older version of Notebook detected - idle kernels will not be culled.  "
                "Culling requires Notebook >= 5.1.0.")

        self.session_manager = SessionManager(
            log=self.log, kernel_manager=self.kernel_manager)

        self.kernel_session_manager = KernelSessionManager(
            log=self.log,
            kernel_manager=self.kernel_manager,
            config=self.config,  # required to get command-line options visible
            **kwargs)
        # Attempt to start persisted sessions
        self.kernel_session_manager.start_sessions()

        self.contents_manager = None

        if self.prespawn_count:
            if self.max_kernels and self.prespawn_count > self.max_kernels:
                raise RuntimeError(
                    'cannot prespawn {}; more than max kernels {}'.format(
                        self.prespawn_count, self.max_kernels))

        api_module = self._load_api_module(self.api)
        func = getattr(api_module, 'create_personality')
        self.personality = func(parent=self, log=self.log)

        self.personality.init_configurables()

    def init_webapp(self):
        super(EnterpriseGatewayApp, self).init_webapp()

        # As of Notebook 5.6, remote kernels are prevented: https://github.com/jupyter/notebook/pull/3714/ unless
        # 'allow_remote_access' is enabled.  Since this is the entire purpose of EG, we'll unconditionally set that
        # here.  Because this is a dictionary, we shouldn't have to worry about older versions as this will be ignored.
        self.web_app.settings['allow_remote_access'] = True

    def start(self):
        """Starts an IO loop for the application. """

        # Note that we *intentionally* reference the KernelGatewayApp so that we bypass
        # its start() logic and just call that of JKG's superclass.
        super(KernelGatewayApp, self).start()

        self.log.info(
            'Jupyter Enterprise Gateway {} is available at http{}://{}:{}'.
            format(EnterpriseGatewayApp.version, 's' if self.keyfile else '',
                   self.ip, self.port))
        # If impersonation is enabled, issue a warning message if the gateway user is not in unauthorized_users.
        if self.impersonation_enabled:
            gateway_user = getpass.getuser()
            if gateway_user.lower() not in self.unauthorized_users:
                self.log.warning(
                    "Impersonation is enabled and gateway user '{}' is NOT specified in the set of "
                    "unauthorized users!  Kernels may execute as that user with elevated privileges."
                    .format(gateway_user))

        self.io_loop = ioloop.IOLoop.current()

        signal.signal(signal.SIGHUP, signal.SIG_IGN)

        signal.signal(signal.SIGTERM, self._signal_stop)

        try:
            self.io_loop.start()
        except KeyboardInterrupt:
            self.log.info("Interrupted...")
            # Ignore further interrupts (ctrl-c)
            signal.signal(signal.SIGINT, signal.SIG_IGN)
        finally:
            self.shutdown()

    def stop(self):
        """
        Stops the HTTP server and IO loop associated with the application.
        """
        def _stop():
            self.http_server.stop()
            self.io_loop.stop()

        self.io_loop.add_callback(_stop)

    def _signal_stop(self, sig, frame):
        self.log.info("Received signal to terminate Enterprise Gateway.")
        self.io_loop.stop()
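
Nearly every trait above follows the same shape: a static default constant, an env-var name, and a @default generator that prefers the env var. A distilled sketch of that pattern (the class name and values are illustrative):

# Distilled sketch of the env-fallback pattern used throughout EnterpriseGatewayApp.
import os
from traitlets import Unicode, default
from traitlets.config import Configurable

class PortRangeExample(Configurable):
    port_range_env = 'EG_PORT_RANGE'
    port_range_default_value = '0..0'  # '0..0' disables port-range enforcement
    port_range = Unicode(port_range_default_value, config=True,
                         help="Lower..upper port bounds, e.g. 33245..34245")

    @default('port_range')
    def _port_range_default(self):
        # config-file/CLI values win; otherwise the env var; otherwise the constant
        return os.getenv(self.port_range_env, self.port_range_default_value)

os.environ['EG_PORT_RANGE'] = '33245..34245'
print(PortRangeExample().port_range)  # '33245..34245'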
Example #26
class Widget(LoggingConfigurable):
    #-------------------------------------------------------------------------
    # Class attributes
    #-------------------------------------------------------------------------
    _widget_construction_callback = None
    widgets = {}
    widget_types = {}

    @staticmethod
    def on_widget_constructed(callback):
        """Registers a callback to be called when a widget is constructed.

        The callback must have the following signature:
        callback(widget)"""
        Widget._widget_construction_callback = callback

    @staticmethod
    def _call_widget_constructed(widget):
        """Static method, called when a widget is constructed."""
        if Widget._widget_construction_callback is not None and callable(
                Widget._widget_construction_callback):
            Widget._widget_construction_callback(widget)

    @staticmethod
    def handle_comm_opened(comm, msg):
        """Static method, called when a widget is constructed."""
        class_name = str(msg['content']['data']['widget_class'])
        if class_name in Widget.widget_types:
            widget_class = Widget.widget_types[class_name]
        else:
            widget_class = import_item(class_name)
        widget = widget_class(comm=comm)

    @staticmethod
    def get_manager_state(drop_defaults=False):
        return dict(
            version_major=1,
            version_minor=0,
            state={
                k: {
                    'model_name':
                    Widget.widgets[k]._model_name,
                    'model_module':
                    Widget.widgets[k]._model_module,
                    'model_module_version':
                    Widget.widgets[k]._model_module_version,
                    'state':
                    Widget.widgets[k].get_state(drop_defaults=drop_defaults)
                }
                for k in Widget.widgets
            })

    def get_view_spec(self):
        return dict(version_major=1, version_minor=0, model_id=self._model_id)

    #-------------------------------------------------------------------------
    # Traits
    #-------------------------------------------------------------------------
    _model_module = Unicode(
        'jupyter-js-widgets',
        help="A JavaScript module name in which to find _model_name.").tag(
            sync=True)
    _model_name = Unicode(
        'WidgetModel',
        help="Name of the model object in the front-end.").tag(sync=True)
    _model_module_version = Unicode(
        '*', help="A semver requirement for the model module version.").tag(
            sync=True)
    _view_module = Unicode(
        None,
        allow_none=True,
        help="A JavaScript module in which to find _view_name.").tag(sync=True)
    _view_name = Unicode(None,
                         allow_none=True,
                         help="Name of the view object.").tag(sync=True)
    _view_module_version = Unicode(
        '*', help="A semver requirement for the view module version.").tag(
            sync=True)
    comm = Instance('ipykernel.comm.Comm', allow_none=True)

    msg_throttle = Int(
        1,
        help=
        """Maximum number of msgs the front-end can send before receiving an idle msg from the back-end."""
    ).tag(sync=True)

    keys = List()

    def _keys_default(self):
        return [name for name in self.traits(sync=True)]

    _property_lock = Dict()
    _holding_sync = False
    _states_to_send = Set()
    _display_callbacks = Instance(CallbackDispatcher, ())
    _msg_callbacks = Instance(CallbackDispatcher, ())

    #-------------------------------------------------------------------------
    # (Con/de)structor
    #-------------------------------------------------------------------------
    def __init__(self, **kwargs):
        """Public constructor"""
        self._model_id = kwargs.pop('model_id', None)
        super(Widget, self).__init__(**kwargs)

        Widget._call_widget_constructed(self)
        self.open()

    def __del__(self):
        """Object disposal"""
        self.close()

    #-------------------------------------------------------------------------
    # Properties
    #-------------------------------------------------------------------------

    def open(self):
        """Open a comm to the frontend if one isn't already open."""
        if self.comm is None:
            state, buffer_keys, buffers = self._split_state_buffers(
                self.get_state())

            args = dict(target_name='jupyter.widget', data=state)
            if self._model_id is not None:
                args['comm_id'] = self._model_id

            self.comm = Comm(**args)
            if buffers:
                # FIXME: workaround ipykernel missing binary message support in open-on-init
                # send state with binary elements as second message
                self.send_state()

    @observe('comm')
    def _comm_changed(self, change):
        """Called when the comm is changed."""
        if change['new'] is None:
            return
        self._model_id = self.model_id

        self.comm.on_msg(self._handle_msg)
        Widget.widgets[self.model_id] = self

    @property
    def model_id(self):
        """Gets the model id of this widget.

        The widget's comm is opened during construction, so this is
        available unless the widget has been closed."""
        return self.comm.comm_id

    #-------------------------------------------------------------------------
    # Methods
    #-------------------------------------------------------------------------

    def close(self):
        """Close method.

        Closes the underlying comm.
        When the comm is closed, all of the widget views are automatically
        removed from the front-end."""
        if self.comm is not None:
            Widget.widgets.pop(self.model_id, None)
            self.comm.close()
            self.comm = None
            self._ipython_display_ = None

    def _split_state_buffers(self, state):
        """Return (state_without_buffers, buffer_keys, buffers) for binary message parts"""
        buffer_keys, buffers = [], []
        for k, v in list(state.items()):
            if isinstance(v, _binary_types):
                state.pop(k)
                buffers.append(v)
                buffer_keys.append(k)
        return state, buffer_keys, buffers

    def send_state(self, key=None):
        """Sends the widget state, or a piece of it, to the front-end.

        Parameters
        ----------
        key : unicode, or iterable (optional)
            A single property's name or iterable of property names to sync with the front-end.
        """
        state = self.get_state(key=key)
        state, buffer_keys, buffers = self._split_state_buffers(state)
        msg = {'method': 'update', 'state': state, 'buffers': buffer_keys}
        self._send(msg, buffers=buffers)

    def get_state(self, key=None, drop_defaults=False):
        """Gets the widget state, or a piece of it.

        Parameters
        ----------
        key : unicode or iterable (optional)
            A single property's name or iterable of property names to get.

        Returns
        -------
        state : dict
            dict of trait values, keyed by trait name
        """
        if key is None:
            keys = self.keys
        elif isinstance(key, string_types):
            keys = [key]
        elif isinstance(key, collections.Iterable):
            keys = key
        else:
            raise ValueError(
                "key must be a string, an iterable of keys, or None")
        state = {}
        traits = self.traits()
        for k in keys:
            to_json = self.trait_metadata(k, 'to_json', self._trait_to_json)
            value = to_json(getattr(self, k), self)
            if not PY3 and isinstance(traits[k], Bytes) and isinstance(
                    value, bytes):
                value = memoryview(value)
            if not drop_defaults or value != traits[k].default_value:
                state[k] = value
        return state

    def set_state(self, sync_data):
        """Called when a state is received from the front-end."""
        # The order of these context managers is important. Properties must
        # be locked when the hold_trait_notification context manager is
        # released and notifications are fired.
        with self._lock_property(**sync_data), self.hold_trait_notifications():
            for name in sync_data:
                if name in self.keys:
                    from_json = self.trait_metadata(name, 'from_json',
                                                    self._trait_from_json)
                    self.set_trait(name, from_json(sync_data[name], self))

    def send(self, content, buffers=None):
        """Sends a custom msg to the widget model in the front-end.

        Parameters
        ----------
        content : dict
            Content of the message to send.
        buffers : list of binary buffers
            Binary buffers to send with message
        """
        self._send({"method": "custom", "content": content}, buffers=buffers)

    def on_msg(self, callback, remove=False):
        """(Un)Register a custom msg receive callback.

        Parameters
        ----------
        callback: callable
            callback will be passed three arguments when a message arrives::

                callback(widget, content, buffers)

        remove: bool
            True if the callback should be unregistered."""
        self._msg_callbacks.register_callback(callback, remove=remove)

    def on_displayed(self, callback, remove=False):
        """(Un)Register a widget displayed callback.

        Parameters
        ----------
        callback: method handler
            Must have a signature of::

                callback(widget, **kwargs)

            kwargs from display are passed through without modification.
        remove: bool
            True if the callback should be unregistered."""
        self._display_callbacks.register_callback(callback, remove=remove)

    def add_traits(self, **traits):
        """Dynamically add trait attributes to the Widget."""
        super(Widget, self).add_traits(**traits)
        for name, trait in traits.items():
            if trait.get_metadata('sync'):
                self.keys.append(name)
                self.send_state(name)

    def notify_change(self, change):
        """Called when a property has changed."""
        # Send the state to the frontend before the user-registered callbacks
        # are called.
        name = change['name']
        if self.comm is not None and self.comm.kernel is not None:
            # Make sure this isn't information that the front-end just sent us.
            if name in self.keys and self._should_send_property(
                    name, change['new']):
                # Send new state to front-end
                self.send_state(key=name)
        LoggingConfigurable.notify_change(self, change)

    #-------------------------------------------------------------------------
    # Support methods
    #-------------------------------------------------------------------------
    @contextmanager
    def _lock_property(self, **properties):
        """Lock a property-value pair.

        The value should be the JSON state of the property.

        NOTE: This, in addition to the single lock for all state changes, is
        flawed.  In the future we may want to look into buffering state changes
        back to the front-end."""
        self._property_lock = properties
        try:
            yield
        finally:
            self._property_lock = {}

    @contextmanager
    def hold_sync(self):
        """Hold syncing any state until the outermost context manager exits"""
        if self._holding_sync:
            yield
        else:
            try:
                self._holding_sync = True
                yield
            finally:
                self._holding_sync = False
                self.send_state(self._states_to_send)
                self._states_to_send.clear()
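
    # Usage sketch (assumed widget instance `w`): batch several trait
    # assignments into a single sync message instead of one per change:
    #
    #     with w.hold_sync():
    #         w.value = 5
    #         w.description = 'updated'
    #     # one 'update' message covering both traits is sent on exit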

    def _should_send_property(self, key, value):
        """Check the property lock (property_lock)"""
        to_json = self.trait_metadata(key, 'to_json', self._trait_to_json)
        if (key in self._property_lock
                and to_json(value, self) == self._property_lock[key]):
            return False
        elif self._holding_sync:
            self._states_to_send.add(key)
            return False
        else:
            return True

    # Event handlers
    @_show_traceback
    def _handle_msg(self, msg):
        """Called when a msg is received from the front-end"""
        data = msg['content']['data']
        method = data['method']

        # Handle backbone sync methods CREATE, PATCH, and UPDATE all in one.
        if method == 'backbone':
            if 'sync_data' in data:
                # get binary buffers too
                sync_data = data['sync_data']
                for i, k in enumerate(data.get('buffer_keys', [])):
                    sync_data[k] = msg['buffers'][i]
                self.set_state(sync_data)  # handles all methods

        # Handle a state request.
        elif method == 'request_state':
            self.send_state()

        # Handle a custom msg from the front-end.
        elif method == 'custom':
            if 'content' in data:
                self._handle_custom_msg(data['content'], msg['buffers'])

        # Catch remainder.
        else:
            self.log.error(
                'Unknown front-end to back-end widget msg with method "%s"' %
                method)

    def _handle_custom_msg(self, content, buffers):
        """Called when a custom msg is received."""
        self._msg_callbacks(self, content, buffers)

    def _handle_displayed(self, **kwargs):
        """Called when a view has been displayed for this widget instance"""
        self._display_callbacks(self, **kwargs)

    @staticmethod
    def _trait_to_json(x, self):
        """Convert a trait value to json."""
        return x

    @staticmethod
    def _trait_from_json(x, self):
        """Convert json values to objects."""
        return x

    def _ipython_display_(self, **kwargs):
        """Called when `IPython.display.display` is called on the widget."""
        def loud_error(message):
            self.log.warning(message)
            sys.stderr.write('%s\n' % message)

        # Show view.
        if self._view_name is not None:
            validated = Widget._version_validated

            # Before the user tries to display a widget, validate that the
            # widget front-end is what is expected.
            if validated is None:
                loud_error('Widget Javascript not detected.  It may not be '
                           'installed or enabled properly.')
            elif not validated:
                msg = ('The installed widget Javascript is the wrong version.'
                       ' It must satisfy the semver range %s.' %
                       __frontend_version__)
                if Widget._version_frontend:
                    msg += ' The widget Javascript is version %s.' % Widget._version_frontend
                loud_error(msg)

            # TODO: delete this sending of a comm message when the display statement
            # below works. Then add a 'text/plain' mimetype to the dictionary below.
            self._send({"method": "display"})

            # The 'application/vnd.jupyter.widget-view+json' mimetype has not been registered yet.
            # See the registration process and naming convention at
            # http://tools.ietf.org/html/rfc6838
            # and the currently registered mimetypes at
            # http://www.iana.org/assignments/media-types/media-types.xhtml.
            # We don't have a 'text/plain' entry, so this display message
            # will be invisible in the current notebook.
            data = {
                'application/vnd.jupyter.widget-view+json': {
                    'model_id': self._model_id
                }
            }
            display(data, raw=True)

            self._handle_displayed(**kwargs)

    def _send(self, msg, buffers=None):
        """Sends a message to the model in the front-end."""
        if self.comm is not None and self.comm.kernel is not None:
            self.comm.send(data=msg, buffers=buffers)
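
A minimal usage sketch of the API above, assuming an ipywidgets-style front-end is installed; the class name EchoWidget, its view name, and its traits are hypothetical, not part of the example:

from traitlets import Int, Unicode

class EchoWidget(Widget):
    _view_name = Unicode('EchoView').tag(sync=True)  # hypothetical view
    value = Int(0).tag(sync=True)                    # synced with front-end

w = EchoWidget()        # __init__ opens the comm via open()
w.value = 3             # notify_change() pushes the new state
w.on_msg(lambda widget, content, buffers: print(content))  # custom msgs
w.send({'op': 'ping'})  # back-end -> front-end custom message
w.close()               # closes the comm and removes front-end views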
Example #27
class Kernel(SingletonConfigurable):

    #---------------------------------------------------------------------------
    # Kernel interface
    #---------------------------------------------------------------------------

    # attribute to override with a GUI
    eventloop = Any(None)

    @observe('eventloop')
    def _update_eventloop(self, change):
        """schedule call to eventloop from IOLoop"""
        loop = ioloop.IOLoop.current()
        if change.new is not None:
            loop.add_callback(self.enter_eventloop)

    session = Instance(Session, allow_none=True)
    profile_dir = Instance('IPython.core.profiledir.ProfileDir',
                           allow_none=True)
    shell_streams = List()
    control_stream = Instance(ZMQStream, allow_none=True)
    iopub_socket = Any()
    iopub_thread = Any()
    stdin_socket = Any()
    log = Instance(logging.Logger, allow_none=True)

    # identities:
    int_id = Integer(-1)
    ident = Unicode()

    @default('ident')
    def _default_ident(self):
        return unicode_type(uuid.uuid4())

    # This should be overridden by wrapper kernels that implement any real
    # language.
    language_info = {}

    # any links that should go in the help menu
    help_links = List()

    # Private interface

    _darwin_app_nap = Bool(
        True,
        help="""Whether to use appnope for compatibility with OS X App Nap.

        Only affects OS X >= 10.9.
        """).tag(config=True)

    # track associations with current request
    _allow_stdin = Bool(False)
    _parent_header = Dict()
    _parent_ident = Any(b'')
    # Time to sleep after flushing the stdout/err buffers in each execute
    # cycle.  While this introduces a hard limit on the minimal latency of the
    # execute cycle, it helps prevent output synchronization problems for
    # clients.
    # Units are in seconds.  The minimum zmq latency on local host is probably
    # ~150 microseconds, set this to 500us for now.  We may need to increase it
    # a little if it's not enough after more interactive testing.
    _execute_sleep = Float(0.0005).tag(config=True)

    # Frequency of the kernel's event loop.
    # Units are in seconds, kernel subclasses for GUI toolkits may need to
    # adapt to milliseconds.
    _poll_interval = Float(0.01).tag(config=True)

    stop_on_error_timeout = Float(
        0.1,
        config=True,
        help="""time (in seconds) to wait for messages to arrive
        when aborting queued requests after an error.

        Requests that arrive within this window after an error
        will be cancelled.

        Increase in the event of unusually slow network
        causing significant delays,
        which can manifest as e.g. "Run all" in a notebook
        aborting some, but not all, messages after an error.
        """)

    # If the shutdown was requested over the network, we leave here the
    # necessary reply message so it can be sent by our registered atexit
    # handler.  This ensures that the reply is only sent to clients truly at
    # the end of our shutdown process (which happens after the underlying
    # IPython shell's own shutdown).
    _shutdown_message = None

    # This is a dict of port number that the kernel is listening on. It is set
    # by record_ports and used by connect_request.
    _recorded_ports = Dict()

    # set of aborted msg_ids
    aborted = Set()

    # Track execution count here. For IPython, we override this to use the
    # execution count we store in the shell.
    execution_count = 0

    msg_types = [
        'execute_request',
        'complete_request',
        'inspect_request',
        'history_request',
        'comm_info_request',
        'kernel_info_request',
        'connect_request',
        'shutdown_request',
        'is_complete_request',
        # deprecated:
        'apply_request',
    ]
    # add deprecated ipyparallel control messages
    control_msg_types = msg_types + ['clear_request', 'abort_request']

    def __init__(self, **kwargs):
        super(Kernel, self).__init__(**kwargs)
        # Build dict of handlers for message types
        self.shell_handlers = {}
        for msg_type in self.msg_types:
            self.shell_handlers[msg_type] = getattr(self, msg_type)

        self.control_handlers = {}
        for msg_type in self.control_msg_types:
            self.control_handlers[msg_type] = getattr(self, msg_type)

    @gen.coroutine
    def dispatch_control(self, msg):
        """dispatch control requests"""
        idents, msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.deserialize(msg, content=True, copy=False)
        except Exception:
            self.log.error("Invalid Control Message", exc_info=True)
            return

        self.log.debug("Control received: %s", msg)

        # Set the parent message for side effects.
        self.set_parent(idents, msg)
        self._publish_status(u'busy')
        if self._aborting:
            self._send_abort_reply(self.control_stream, msg, idents)
            self._publish_status(u'idle')
            return

        header = msg['header']
        msg_type = header['msg_type']

        handler = self.control_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r", msg_type)
        else:
            try:
                yield gen.maybe_future(
                    handler(self.control_stream, idents, msg))
            except Exception:
                self.log.error("Exception in control handler:", exc_info=True)

        sys.stdout.flush()
        sys.stderr.flush()
        self._publish_status(u'idle')
        # flush to ensure reply is sent
        self.control_stream.flush(zmq.POLLOUT)

    def should_handle(self, stream, msg, idents):
        """Check whether a shell-channel message should be handled

        Allows subclasses to prevent handling of certain messages (e.g. aborted requests).
        """
        msg_id = msg['header']['msg_id']
        if msg_id in self.aborted:
            msg_type = msg['header']['msg_type']
            # is it safe to assume a msg_id will not be resubmitted?
            self.aborted.remove(msg_id)
            self._send_abort_reply(stream, msg, idents)
            return False
        return True

    @gen.coroutine
    def dispatch_shell(self, stream, msg):
        """dispatch shell requests"""
        idents, msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.deserialize(msg, content=True, copy=False)
        except Exception:
            self.log.error("Invalid Message", exc_info=True)
            return

        # Set the parent message for side effects.
        self.set_parent(idents, msg)
        self._publish_status(u'busy')

        if self._aborting:
            self._send_abort_reply(stream, msg, idents)
            self._publish_status(u'idle')
            # flush to ensure reply is sent before
            # handling the next request
            stream.flush(zmq.POLLOUT)
            return

        msg_type = msg['header']['msg_type']

        # Print some info about this message and leave a '--->' marker, so it's
        # easier to trace visually the message chain when debugging.  Each
        # handler prints its message at the end.
        self.log.debug('\n*** MESSAGE TYPE:%s***', msg_type)
        self.log.debug('   Content: %s\n   --->\n   ', msg['content'])

        if not self.should_handle(stream, msg, idents):
            return

        handler = self.shell_handlers.get(msg_type, None)
        if handler is None:
            self.log.warning("Unknown message type: %r", msg_type)
        else:
            self.log.debug("%s: %s", msg_type, msg)
            try:
                self.pre_handler_hook()
            except Exception:
                self.log.debug("Unable to signal in pre_handler_hook:",
                               exc_info=True)
            try:
                yield gen.maybe_future(handler(stream, idents, msg))
            except Exception:
                self.log.error("Exception in message handler:", exc_info=True)
            finally:
                try:
                    self.post_handler_hook()
                except Exception:
                    self.log.debug("Unable to signal in post_handler_hook:",
                                   exc_info=True)

        sys.stdout.flush()
        sys.stderr.flush()
        self._publish_status(u'idle')
        # flush to ensure reply is sent before
        # handling the next request
        stream.flush(zmq.POLLOUT)

    def pre_handler_hook(self):
        """Hook to execute before calling message handler"""
        # ensure default_int_handler during handler call
        self.saved_sigint_handler = signal(SIGINT, default_int_handler)

    def post_handler_hook(self):
        """Hook to execute after calling message handler"""
        signal(SIGINT, self.saved_sigint_handler)

    def enter_eventloop(self):
        """enter eventloop"""
        self.log.info("Entering eventloop %s", self.eventloop)
        # record handle, so we can check when this changes
        eventloop = self.eventloop
        if eventloop is None:
            self.log.info("Exiting as there is no eventloop")
            return

        def advance_eventloop():
            # check if eventloop changed:
            if self.eventloop is not eventloop:
                self.log.info("exiting eventloop %s", eventloop)
                return
            if self.msg_queue.qsize():
                self.log.debug("Delaying eventloop due to waiting messages")
                # still messages to process, make the eventloop wait
                schedule_next()
                return
            self.log.debug("Advancing eventloop %s", eventloop)
            try:
                eventloop(self)
            except KeyboardInterrupt:
                # Ctrl-C shouldn't crash the kernel
                self.log.error("KeyboardInterrupt caught in kernel")
            if self.eventloop is eventloop:
                # schedule advance again
                schedule_next()

        def schedule_next():
            """Schedule the next advance of the eventloop"""
            # flush the eventloop every so often,
            # giving us a chance to handle messages in the meantime
            self.log.debug("Scheduling eventloop advance")
            self.io_loop.call_later(1, advance_eventloop)

        # begin polling the eventloop
        schedule_next()
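
    # Sketch of the hook this enables (hypothetical toolkit `gui`): the
    # `eventloop` attribute is any callable taking the kernel instance, and
    # assigning it triggers _update_eventloop above:
    #
    #     def loop_gui(kernel):
    #         gui.process_events()  # run one slice of the GUI's event loop
    #
    #     kernel.eventloop = loop_gui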

    @gen.coroutine
    def do_one_iteration(self):
        """Process a single shell message

        Any pending control messages will be flushed as well

        .. versionchanged:: 5
            This is now a coroutine
        """
        # flush messages off of shell streams into the message queue
        for stream in self.shell_streams:
            stream.flush()
        # process all messages higher priority than shell (control),
        # and at most one shell message per iteration
        priority = 0
        while priority is not None and priority < SHELL_PRIORITY:
            priority = yield self.process_one(wait=False)

    @gen.coroutine
    def process_one(self, wait=True):
        """Process one request

        Returns priority of the message handled.
        Returns None if no message was handled.
        """
        if wait:
            priority, t, dispatch, args = yield self.msg_queue.get()
        else:
            try:
                priority, t, dispatch, args = self.msg_queue.get_nowait()
            except QueueEmpty:
                return None
        yield gen.maybe_future(dispatch(*args))
        return priority

    @gen.coroutine
    def dispatch_queue(self):
        """Coroutine to preserve order of message handling

        Ensures that only one message is processing at a time,
        even when the handler is async
        """

        while True:
            # ensure control stream is flushed before processing shell messages
            if self.control_stream:
                self.control_stream.flush()
            # receive the next message and handle it
            try:
                yield self.process_one()
            except Exception:
                self.log.exception("Error in message handler")

    _message_counter = Any(help="""Monotonic counter of messages

        Ensures messages of the same priority are handled in arrival order.
        """, )

    @default('_message_counter')
    def _message_counter_default(self):
        return itertools.count()

    def schedule_dispatch(self, priority, dispatch, *args):
        """schedule a message for dispatch"""
        idx = next(self._message_counter)
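        # `idx` increases monotonically, so messages of equal priority keep
        # arrival order and the PriorityQueue never has to compare the
        # non-comparable `dispatch` callables when breaking ties.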

        self.msg_queue.put_nowait((
            priority,
            idx,
            dispatch,
            args,
        ))
        # ensure the eventloop wakes up
        self.io_loop.add_callback(lambda: None)

    def start(self):
        """register dispatchers for streams"""
        self.io_loop = ioloop.IOLoop.current()
        self.msg_queue = PriorityQueue()
        self.io_loop.add_callback(self.dispatch_queue)

        if self.control_stream:
            self.control_stream.on_recv(
                partial(
                    self.schedule_dispatch,
                    CONTROL_PRIORITY,
                    self.dispatch_control,
                ),
                copy=False,
            )

        for s in self.shell_streams:
            if s is self.control_stream:
                continue
            s.on_recv(
                partial(
                    self.schedule_dispatch,
                    SHELL_PRIORITY,
                    self.dispatch_shell,
                    s,
                ),
                copy=False,
            )

        # publish idle status
        self._publish_status('starting')

    def record_ports(self, ports):
        """Record the ports that this kernel is using.

        The creator of the Kernel instance must call this methods if they
        want the :meth:`connect_request` method to return the port numbers.
        """
        self._recorded_ports = ports

    #---------------------------------------------------------------------------
    # Kernel request handlers
    #---------------------------------------------------------------------------

    def _publish_execute_input(self, code, parent, execution_count):
        """Publish the code request on the iopub stream."""

        self.session.send(self.iopub_socket,
                          u'execute_input', {
                              u'code': code,
                              u'execution_count': execution_count
                          },
                          parent=parent,
                          ident=self._topic('execute_input'))

    def _publish_status(self, status, parent=None):
        """send status (busy/idle) on IOPub"""
        self.session.send(
            self.iopub_socket,
            u'status',
            {u'execution_state': status},
            parent=parent or self._parent_header,
            ident=self._topic('status'),
        )

    def set_parent(self, ident, parent):
        """Set the current parent_header

        Side effects (IOPub messages) and replies are associated with
        the request that caused them via the parent_header.

        The parent identity is used to route input_request messages
        on the stdin channel.
        """
        self._parent_ident = ident
        self._parent_header = parent

    def send_response(self,
                      stream,
                      msg_or_type,
                      content=None,
                      ident=None,
                      buffers=None,
                      track=False,
                      header=None,
                      metadata=None):
        """Send a response to the message we're currently processing.

        This accepts all the parameters of :meth:`jupyter_client.session.Session.send`
        except ``parent``.

        This relies on :meth:`set_parent` having been called for the current
        message.
        """
        return self.session.send(stream, msg_or_type, content,
                                 self._parent_header, ident, buffers, track,
                                 header, metadata)

    def init_metadata(self, parent):
        """Initialize metadata.

        Run at the beginning of execution requests.
        """
        # FIXME: `started` is part of ipyparallel
        # Remove for ipykernel 5.0
        return {
            'started': now(),
        }

    def finish_metadata(self, parent, metadata, reply_content):
        """Finish populating metadata.

        Run after completing an execution request.
        """
        return metadata

    @gen.coroutine
    def execute_request(self, stream, ident, parent):
        """handle an execute_request"""

        try:
            content = parent[u'content']
            code = py3compat.cast_unicode_py2(content[u'code'])
            silent = content[u'silent']
            store_history = content.get(u'store_history', not silent)
            user_expressions = content.get('user_expressions', {})
            allow_stdin = content.get('allow_stdin', False)
        except Exception:
            self.log.error("Got bad msg: ")
            self.log.error("%s", parent)
            return

        stop_on_error = content.get('stop_on_error', True)

        metadata = self.init_metadata(parent)

        # Re-broadcast our input for the benefit of listening clients, and
        # start computing output
        if not silent:
            self.execution_count += 1
            self._publish_execute_input(code, parent, self.execution_count)

        reply_content = yield gen.maybe_future(
            self.do_execute(
                code,
                silent,
                store_history,
                user_expressions,
                allow_stdin,
            ))

        # Flush output before sending the reply.
        sys.stdout.flush()
        sys.stderr.flush()
        # FIXME: on rare occasions, the flush doesn't seem to make it to the
        # clients... This seems to mitigate the problem, but we definitely need
        # to better understand what's going on.
        if self._execute_sleep:
            time.sleep(self._execute_sleep)

        # Send the reply.
        reply_content = json_clean(reply_content)
        metadata = self.finish_metadata(parent, metadata, reply_content)

        reply_msg = self.session.send(stream,
                                      u'execute_reply',
                                      reply_content,
                                      parent,
                                      metadata=metadata,
                                      ident=ident)

        self.log.debug("%s", reply_msg)

        if not silent and reply_msg['content'][
                'status'] == u'error' and stop_on_error:
            yield self._abort_queues()

    def do_execute(self,
                   code,
                   silent,
                   store_history=True,
                   user_expressions=None,
                   allow_stdin=False):
        """Execute user code. Must be overridden by subclasses.
        """
        raise NotImplementedError

    @gen.coroutine
    def complete_request(self, stream, ident, parent):
        content = parent['content']
        code = content['code']
        cursor_pos = content['cursor_pos']

        matches = yield gen.maybe_future(self.do_complete(code, cursor_pos))
        matches = json_clean(matches)
        completion_msg = self.session.send(stream, 'complete_reply', matches,
                                           parent, ident)
        self.log.debug("%s", completion_msg)

    def do_complete(self, code, cursor_pos):
        """Override in subclasses to find completions.
        """
        return {
            'matches': [],
            'cursor_end': cursor_pos,
            'cursor_start': cursor_pos,
            'metadata': {},
            'status': 'ok'
        }

    @gen.coroutine
    def inspect_request(self, stream, ident, parent):
        content = parent['content']

        reply_content = yield gen.maybe_future(
            self.do_inspect(
                content['code'],
                content['cursor_pos'],
                content.get('detail_level', 0),
            ))
        # Before we send this object over, we scrub it for JSON usage
        reply_content = json_clean(reply_content)
        msg = self.session.send(stream, 'inspect_reply', reply_content, parent,
                                ident)
        self.log.debug("%s", msg)

    def do_inspect(self, code, cursor_pos, detail_level=0):
        """Override in subclasses to allow introspection.
        """
        return {'status': 'ok', 'data': {}, 'metadata': {}, 'found': False}

    @gen.coroutine
    def history_request(self, stream, ident, parent):
        content = parent['content']

        reply_content = yield gen.maybe_future(self.do_history(**content))

        reply_content = json_clean(reply_content)
        msg = self.session.send(stream, 'history_reply', reply_content, parent,
                                ident)
        self.log.debug("%s", msg)

    def do_history(self,
                   hist_access_type,
                   output,
                   raw,
                   session=None,
                   start=None,
                   stop=None,
                   n=None,
                   pattern=None,
                   unique=False):
        """Override in subclasses to access history.
        """
        return {'status': 'ok', 'history': []}

    def connect_request(self, stream, ident, parent):
        if self._recorded_ports is not None:
            content = self._recorded_ports.copy()
        else:
            content = {}
        content['status'] = 'ok'
        msg = self.session.send(stream, 'connect_reply', content, parent,
                                ident)
        self.log.debug("%s", msg)

    @property
    def kernel_info(self):
        return {
            'protocol_version': kernel_protocol_version,
            'implementation': self.implementation,
            'implementation_version': self.implementation_version,
            'language_info': self.language_info,
            'banner': self.banner,
            'help_links': self.help_links,
        }

    def kernel_info_request(self, stream, ident, parent):
        content = {'status': 'ok'}
        content.update(self.kernel_info)
        msg = self.session.send(stream, 'kernel_info_reply', content, parent,
                                ident)
        self.log.debug("%s", msg)

    def comm_info_request(self, stream, ident, parent):
        content = parent['content']
        target_name = content.get('target_name', None)

        # Should this be moved to ipkernel?
        if hasattr(self, 'comm_manager'):
            comms = {
                k: dict(target_name=v.target_name)
                for (k, v) in self.comm_manager.comms.items()
                if v.target_name == target_name or target_name is None
            }
        else:
            comms = {}
        reply_content = dict(comms=comms, status='ok')
        msg = self.session.send(stream, 'comm_info_reply', reply_content,
                                parent, ident)
        self.log.debug("%s", msg)

    @gen.coroutine
    def shutdown_request(self, stream, ident, parent):
        content = yield gen.maybe_future(
            self.do_shutdown(parent['content']['restart']))
        self.session.send(stream,
                          u'shutdown_reply',
                          content,
                          parent,
                          ident=ident)
        # same content, but different msg_id for broadcasting on IOPub
        self._shutdown_message = self.session.msg(u'shutdown_reply', content,
                                                  parent)

        self._at_shutdown()
        # call sys.exit after a short delay
        loop = ioloop.IOLoop.current()
        loop.add_timeout(time.time() + 0.1, loop.stop)

    def do_shutdown(self, restart):
        """Override in subclasses to do things when the frontend shuts down the
        kernel.
        """
        return {'status': 'ok', 'restart': restart}

    @gen.coroutine
    def is_complete_request(self, stream, ident, parent):
        content = parent['content']
        code = content['code']

        reply_content = yield gen.maybe_future(self.do_is_complete(code))
        reply_content = json_clean(reply_content)
        reply_msg = self.session.send(stream, 'is_complete_reply',
                                      reply_content, parent, ident)
        self.log.debug("%s", reply_msg)

    def do_is_complete(self, code):
        """Override in subclasses to report whether ``code`` is a complete
        block of input.
        """
        return {
            'status': 'unknown',
        }

    #---------------------------------------------------------------------------
    # Engine methods (DEPRECATED)
    #---------------------------------------------------------------------------

    def apply_request(self, stream, ident, parent):
        self.log.warning(
            "apply_request is deprecated in kernel_base, moving to ipyparallel."
        )
        try:
            content = parent[u'content']
            bufs = parent[u'buffers']
            msg_id = parent['header']['msg_id']
        except Exception:
            self.log.error("Got bad msg: %s", parent, exc_info=True)
            return

        md = self.init_metadata(parent)

        reply_content, result_buf = self.do_apply(content, bufs, msg_id, md)

        # flush i/o
        sys.stdout.flush()
        sys.stderr.flush()

        md = self.finish_metadata(parent, md, reply_content)

        self.session.send(stream,
                          u'apply_reply',
                          reply_content,
                          parent=parent,
                          ident=ident,
                          buffers=result_buf,
                          metadata=md)

    def do_apply(self, content, bufs, msg_id, reply_metadata):
        """DEPRECATED"""
        raise NotImplementedError

    #---------------------------------------------------------------------------
    # Control messages (DEPRECATED)
    #---------------------------------------------------------------------------

    def abort_request(self, stream, ident, parent):
        """abort a specific msg by id"""
        self.log.warning(
            "abort_request is deprecated in kernel_base. It is only part of IPython parallel"
        )
        msg_ids = parent['content'].get('msg_ids', None)
        if isinstance(msg_ids, string_types):
            msg_ids = [msg_ids]
        if not msg_ids:
            self._abort_queues()
        else:
            for mid in msg_ids:
                self.aborted.add(str(mid))

        content = dict(status='ok')
        reply_msg = self.session.send(stream,
                                      'abort_reply',
                                      content=content,
                                      parent=parent,
                                      ident=ident)
        self.log.debug("%s", reply_msg)

    def clear_request(self, stream, idents, parent):
        """Clear our namespace."""
        self.log.warning(
            "clear_request is deprecated in kernel_base. It is only part of IPython parallel"
        )
        content = self.do_clear()
        self.session.send(stream,
                          'clear_reply',
                          ident=idents,
                          parent=parent,
                          content=content)

    def do_clear(self):
        """DEPRECATED since 4.0.3"""
        raise NotImplementedError

    #---------------------------------------------------------------------------
    # Protected interface
    #---------------------------------------------------------------------------

    def _topic(self, topic):
        """prefixed topic for IOPub messages"""
        base = "kernel.%s" % self.ident

        return py3compat.cast_bytes("%s.%s" % (base, topic))

    _aborting = Bool(False)

    @gen.coroutine
    def _abort_queues(self):
        for stream in self.shell_streams:
            stream.flush()
        self._aborting = True

        self.schedule_dispatch(
            ABORT_PRIORITY,
            self._dispatch_abort,
        )

    @gen.coroutine
    def _dispatch_abort(self):
        self.log.info("Finishing abort")
        yield gen.sleep(self.stop_on_error_timeout)
        self._aborting = False

    def _send_abort_reply(self, stream, msg, idents):
        """Send a reply to an aborted request"""
        self.log.info("Aborting:")
        self.log.info("%s", msg)
        reply_type = msg['header']['msg_type'].rsplit('_', 1)[0] + '_reply'
        status = {'status': 'aborted'}
        md = {'engine': self.ident}
        md.update(status)
        self.session.send(
            stream,
            reply_type,
            metadata=md,
            content=status,
            parent=msg,
            ident=idents,
        )

    def _no_raw_input(self):
        """Raise StdinNotImplementedError if active frontend doesn't support
        stdin."""
        raise StdinNotImplementedError("raw_input was called, but this "
                                       "frontend does not support stdin.")

    def getpass(self, prompt='', stream=None):
        """Forward getpass to frontends

        Raises
        ------
        StdinNotImplementedError if active frontend doesn't support stdin.
        """
        if not self._allow_stdin:
            raise StdinNotImplementedError(
                "getpass was called, but this frontend does not support input requests."
            )
        if stream is not None:
            import warnings
            warnings.warn(
                "The `stream` parameter of `getpass.getpass` will have no effect when using ipykernel",
                UserWarning,
                stacklevel=2)
        return self._input_request(
            prompt,
            self._parent_ident,
            self._parent_header,
            password=True,
        )

    def raw_input(self, prompt=''):
        """Forward raw_input to frontends

        Raises
        ------
        StdinNotImplementedError if active frontend doesn't support stdin.
        """
        if not self._allow_stdin:
            raise StdinNotImplementedError(
                "raw_input was called, but this frontend does not support input requests."
            )
        return self._input_request(
            str(prompt),
            self._parent_ident,
            self._parent_header,
            password=False,
        )

    def _input_request(self, prompt, ident, parent, password=False):
        # Flush output before making the request.
        sys.stderr.flush()
        sys.stdout.flush()
        # flush the stdin socket, to purge stale replies
        while True:
            try:
                self.stdin_socket.recv_multipart(zmq.NOBLOCK)
            except zmq.ZMQError as e:
                if e.errno == zmq.EAGAIN:
                    break
                else:
                    raise

        # Send the input request.
        content = json_clean(dict(prompt=prompt, password=password))
        self.session.send(self.stdin_socket,
                          u'input_request',
                          content,
                          parent,
                          ident=ident)

        # Await a response.
        while True:
            try:
                ident, reply = self.session.recv(self.stdin_socket, 0)
            except Exception:
                self.log.warning("Invalid Message:", exc_info=True)
            except KeyboardInterrupt:
                # re-raise KeyboardInterrupt, to truncate traceback
                raise KeyboardInterrupt("Interrupted by user") from None
            else:
                break
        try:
            value = py3compat.unicode_to_str(reply['content']['value'])
        except Exception:
            self.log.error("Bad input_reply: %s", parent)
            value = ''
        if value == '\x04':
            # EOF
            raise EOFError
        return value

    def _at_shutdown(self):
        """Actions taken at shutdown by the kernel, called by python's atexit.
        """
        if self._shutdown_message is not None:
            self.session.send(self.iopub_socket,
                              self._shutdown_message,
                              ident=self._topic('shutdown'))
            self.log.debug("%s", self._shutdown_message)
        for stream in self.shell_streams:
            stream.flush(zmq.POLLOUT)
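
A minimal sketch of a wrapper kernel built on the base class above, following the familiar echo-kernel pattern; the metadata strings are placeholders:

class EchoKernel(Kernel):
    implementation = 'echo'
    implementation_version = '0.1'
    banner = 'Echo kernel - echoes input back as output'
    language_info = {
        'name': 'echo',
        'mimetype': 'text/plain',
        'file_extension': '.txt',
    }

    def do_execute(self, code, silent, store_history=True,
                   user_expressions=None, allow_stdin=False):
        if not silent:
            # stream the submitted code straight back on IOPub
            self.send_response(self.iopub_socket, 'stream',
                               {'name': 'stdout', 'text': code})
        return {
            'status': 'ok',
            'execution_count': self.execution_count,
            'payload': [],
            'user_expressions': {},
        }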
Example #28
class GitLabOAuthenticator(OAuthenticator):

    login_service = "GitLab"

    client_id_env = 'GITLAB_CLIENT_ID'
    client_secret_env = 'GITLAB_CLIENT_SECRET'
    login_handler = GitLabLoginHandler

    gitlab_group_whitelist = Set(
        config=True,
        help="Automatically whitelist members of selected groups",
    )

    @gen.coroutine
    def authenticate(self, handler, data=None):
        code = handler.get_argument("code")
        # TODO: Configure the curl_httpclient for tornado
        http_client = AsyncHTTPClient()

        # Exchange the OAuth code for a GitLab Access Token
        #
        # See: https://github.com/gitlabhq/gitlabhq/blob/master/doc/api/oauth2.md

        # GitLab specifies a POST request yet requires URL parameters
        params = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            code=code,
            grant_type="authorization_code",
            redirect_uri=self.get_callback_url(handler),
        )

        validate_server_cert = self.validate_server_cert

        url = url_concat("%s/oauth/token" % GITLAB_HOST, params)

        req = HTTPRequest(
            url,
            method="POST",
            headers={"Accept": "application/json"},
            validate_cert=validate_server_cert,
            body=''  # Body is required for a POST...
        )

        resp = yield http_client.fetch(req)
        resp_json = json.loads(resp.body.decode('utf8', 'replace'))

        access_token = resp_json['access_token']

        # Determine who the logged in user is
        req = HTTPRequest("%s/user" % GITLAB_API,
                          method="GET",
                          validate_cert=validate_server_cert,
                          headers=_api_headers(access_token))
        resp = yield http_client.fetch(req)
        resp_json = json.loads(resp.body.decode('utf8', 'replace'))

        username = resp_json["username"]
        user_id = resp_json["id"]
        is_admin = resp_json.get("is_admin", False)

        # Check if user is a member of any whitelisted groups.
        # This check is performed here, as it requires `access_token`.
        if self.gitlab_group_whitelist:
            user_in_group = yield self._check_group_whitelist(
                username, user_id, is_admin, access_token)
            if not user_in_group:
                self.log.warning("%s not in group whitelist", username)
                return None
        return {
            'name': username,
            'auth_state': {
                'access_token': access_token,
                'gitlab_user': resp_json,
            }
        }

    @gen.coroutine
    def _check_group_whitelist(self, username, user_id, is_admin,
                               access_token):
        http_client = AsyncHTTPClient()
        headers = _api_headers(access_token)
        if is_admin:
            # For admins, /groups returns *all* groups. As a workaround
            # we check if we are a member of each group in the whitelist
            for group in map(url_escape, self.gitlab_group_whitelist):
                url = "%s/groups/%s/members/%d" % (GITLAB_API, group, user_id)
                req = HTTPRequest(url, method="GET", headers=headers)
                resp = yield http_client.fetch(req, raise_error=False)
                if resp.code == 200:
                    return True  # user _is_ in group
        else:
            # For regular users we get all the groups to which they have access
            # and check if any of these are in the whitelisted groups
            next_page = url_concat("%s/groups" % GITLAB_API,
                                   dict(all_available=True))
            while next_page:
                req = HTTPRequest(next_page, method="GET", headers=headers)
                resp = yield http_client.fetch(req)
                resp_json = json.loads(resp.body.decode('utf8', 'replace'))
                next_page = next_page_from_links(resp)
                user_groups = set(entry["path"] for entry in resp_json)
                # check if any of the organizations seen thus far are in whitelist
                if len(self.gitlab_group_whitelist & user_groups) > 0:
                    return True
            return False
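
Finally, a sketch of wiring this authenticator into a JupyterHub deployment via jupyterhub_config.py; all values shown are placeholders, and the id and secret may instead come from the GITLAB_CLIENT_ID / GITLAB_CLIENT_SECRET environment variables named above:

c.JupyterHub.authenticator_class = GitLabOAuthenticator
c.GitLabOAuthenticator.client_id = 'my-client-id'
c.GitLabOAuthenticator.client_secret = 'my-client-secret'
c.GitLabOAuthenticator.gitlab_group_whitelist = {'my-team', 'another-group'}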