Esempio n. 1
0
class OpenIDConnectFrontend(FrontendModule):
    """
    A OpenID Connect frontend module
    """

    def __init__(self, auth_req_callback_func, internal_attributes, conf, base_url, name):
        self._validate_config(conf)
        super().__init__(auth_req_callback_func, internal_attributes, base_url, name)

        self.config = conf
        self.signing_key = RSAKey(key=rsa_load(conf["signing_key_path"]), use="sig", alg="RS256")

    def _create_provider(self, endpoint_baseurl):
        response_types_supported = self.config["provider"].get("response_types_supported", ["id_token"])
        subject_types_supported = self.config["provider"].get("subject_types_supported", ["pairwise"])
        scopes_supported = self.config["provider"].get("scopes_supported", ["openid"])
        capabilities = {
            "issuer": self.base_url,
            "authorization_endpoint": "{}/{}".format(endpoint_baseurl, AuthorizationEndpoint.url),
            "jwks_uri": "{}/jwks".format(endpoint_baseurl),
            "response_types_supported": response_types_supported,
            "id_token_signing_alg_values_supported": [self.signing_key.alg],
            "response_modes_supported": ["fragment", "query"],
            "subject_types_supported": subject_types_supported,
            "claim_types_supported": ["normal"],
            "claims_parameter_supported": True,
            "claims_supported": [
                attribute_map["openid"][0]
                for attribute_map in self.internal_attributes["attributes"].values()
                if "openid" in attribute_map
            ],
            "request_parameter_supported": False,
            "request_uri_parameter_supported": False,
            "scopes_supported": scopes_supported,
        }

        if self.config["provider"].get("client_registration_supported", False):
            capabilities["registration_endpoint"] = "{}/{}".format(endpoint_baseurl, RegistrationEndpoint.url)

        authz_state = self._init_authorization_state()
        db_uri = self.config.get("db_uri")
        cdb = MongoWrapper(db_uri, "satosa", "clients") if db_uri else {}
        self.user_db = MongoWrapper(db_uri, "satosa", "authz_codes") if db_uri else {}
        self.provider = Provider(self.signing_key, capabilities, authz_state, cdb, Userinfo(self.user_db))

    def _init_authorization_state(self):
        sub_hash_salt = self.config.get("sub_hash_salt", rndstr(16))
        db_uri = self.config.get("db_uri")
        if db_uri:
            authz_code_db = MongoWrapper(db_uri, "satosa", "authz_codes")
            access_token_db = MongoWrapper(db_uri, "satosa", "access_tokens")
            refresh_token_db = MongoWrapper(db_uri, "satosa", "refresh_tokens")
            sub_db = MongoWrapper(db_uri, "satosa", "subject_identifiers")
        else:
            authz_code_db = None
            access_token_db = None
            refresh_token_db = None
            sub_db = None

        token_lifetimes = {
            k: self.config["provider"][k]
            for k in [
                "authorization_code_lifetime",
                "access_token_lifetime",
                "refresh_token_lifetime",
                "refresh_token_threshold",
            ]
            if k in self.config["provider"]
        }
        return AuthorizationState(
            HashBasedSubjectIdentifierFactory(sub_hash_salt),
            authz_code_db,
            access_token_db,
            refresh_token_db,
            sub_db,
            **token_lifetimes
        )

    def handle_authn_response(self, context, internal_resp, extra_id_token_claims=None):
        """
        See super class method satosa.frontends.base.FrontendModule#handle_authn_response
        :type context: satosa.context.Context
        :type internal_response: satosa.internal_data.InternalResponse
        :rtype oic.utils.http_util.Response
        """

        auth_req = self._get_authn_request_from_state(context.state)

        attributes = self.converter.from_internal("openid", internal_resp.attributes)
        self.user_db[internal_resp.user_id] = {k: v[0] for k, v in attributes.items()}
        auth_resp = self.provider.authorize(auth_req, internal_resp.user_id, extra_id_token_claims)

        del context.state[self.name]
        http_response = auth_resp.request(auth_req["redirect_uri"], should_fragment_encode(auth_req))
        return SeeOther(http_response)

    def handle_backend_error(self, exception):
        """
        See super class satosa.frontends.base.FrontendModule
        :type exception: satosa.exception.SATOSAError
        :rtype: oic.utils.http_util.Response
        """
        auth_req = self._get_authn_request_from_state(exception.state)
        error_resp = AuthorizationErrorResponse(error="access_denied", error_description=exception.message)
        satosa_logging(logger, logging.DEBUG, exception.message, exception.state)
        return SeeOther(error_resp.request(auth_req["redirect_uri"], should_fragment_encode(auth_req)))

    def register_endpoints(self, backend_names):
        """
        See super class satosa.frontends.base.FrontendModule
        :type backend_names: list[str]
        :rtype: list[(str, ((satosa.context.Context, Any) -> satosa.response.Response, Any))]
        :raise ValueError: if more than one backend is configured
        """
        backend_name = None
        if len(backend_names) != 1:
            # only supports one backend since there currently is no way to publish multiple authorization endpoints
            # in configuration information and there is no other standard way of authorization_endpoint discovery
            # similar to SAML entity discovery
            # this can be circumvented with a custom RequestMicroService which handles the routing based on something
            # in the authentication request
            logger.warn(
                "More than one backend is configured, make sure to provide a custom routing micro service to "
                "determine which backend should be used per request."
            )
        else:
            backend_name = backend_names[0]

        endpoint_baseurl = "{}/{}".format(self.base_url, self.name)
        self._create_provider(endpoint_baseurl)

        provider_config = ("^.well-known/openid-configuration$", self.provider_config)
        jwks_uri = ("^{}/jwks$".format(self.name), self.jwks)

        if backend_name:
            # if there is only one backend, include its name in the path so the default routing can work
            auth_endpoint = "{}/{}/{}/{}".format(self.base_url, backend_name, self.name, AuthorizationEndpoint.url)
            self.provider.configuration_information["authorization_endpoint"] = auth_endpoint
            auth_path = urlparse(auth_endpoint).path.lstrip("/")
        else:
            auth_path = "{}/{}".format(self.name, AuthorizationEndpoint.url)
        authentication = ("^{}$".format(auth_path), self.handle_authn_request)
        url_map = [provider_config, jwks_uri, authentication]

        if any("code" in v for v in self.provider.configuration_information["response_types_supported"]):
            self.provider.configuration_information["token_endpoint"] = "{}/{}".format(
                endpoint_baseurl, TokenEndpoint.url
            )
            token_endpoint = ("^{}/{}".format(self.name, TokenEndpoint.url), self.token_endpoint)
            url_map.append(token_endpoint)

            self.provider.configuration_information["userinfo_endpoint"] = "{}/{}".format(
                endpoint_baseurl, UserinfoEndpoint.url
            )
            userinfo_endpoint = ("^{}/{}".format(self.name, UserinfoEndpoint.url), self.userinfo_endpoint)
            url_map.append(userinfo_endpoint)
        if "registration_endpoint" in self.provider.configuration_information:
            client_registration = ("^{}/{}".format(self.name, RegistrationEndpoint.url), self.client_registration)
            url_map.append(client_registration)

        return url_map

    def _validate_config(self, config):
        """
        Validates that all necessary config parameters are specified.
        :type config: dict[str, dict[str, Any] | str]
        :param config: the module config
        """
        if config is None:
            raise ValueError("OIDCFrontend conf can't be 'None'.")

        for k in {"signing_key_path", "provider"}:
            if k not in config:
                raise ValueError("Missing configuration parameter '{}' for OpenID Connect frontend.".format(k))

    def _get_authn_request_from_state(self, state):
        """
        Extract the clietns request stoed in the SATOSA state.
        :type state: satosa.state.State
        :rtype: oic.oic.message.AuthorizationRequest

        :param state: the current state
        :return: the parsed authentication request
        """
        return AuthorizationRequest().deserialize(state[self.name]["oidc_request"])

    def client_registration(self, context):
        """
        Handle the OIDC dynamic client registration.
        :type context: satosa.context.Context
        :rtype: oic.utils.http_util.Response

        :param context: the current context
        :return: HTTP response to the client
        """
        try:
            resp = self.provider.handle_client_registration_request(json.dumps(context.request))
            return Created(resp.to_json(), content="application/json")
        except InvalidClientRegistrationRequest as e:
            return BadRequest(e.to_json(), content="application/json")

    def provider_config(self, context):
        """
        Construct the provider configuration information (served at /.well-known/openid-configuration).
        :type context: satosa.context.Context
        :rtype: oic.utils.http_util.Response

        :param context: the current context
        :return: HTTP response to the client
        """
        return Response(self.provider.provider_configuration.to_json(), content="application/json")

    def _get_approved_attributes(self, provider_supported_claims, authn_req):
        requested_claims = list(scope2claims(authn_req["scope"]).keys())
        if "claims" in authn_req:
            for k in ["id_token", "userinfo"]:
                if k in authn_req["claims"]:
                    requested_claims.extend(authn_req["claims"][k].keys())
        return set(provider_supported_claims).intersection(set(requested_claims))

    def _handle_authn_request(self, context):
        """
        Parse and verify the authentication request into an internal request.
        :type context: satosa.context.Context
        :rtype: internal_data.InternalRequest

        :param context: the current context
        :return: the internal request
        """
        request = urlencode(context.request)
        satosa_logging(logger, logging.DEBUG, "Authn req from client: {}".format(request), context.state)

        try:
            authn_req = self.provider.parse_authentication_request(request)
        except InvalidAuthenticationRequest as e:
            satosa_logging(logger, logging.ERROR, "Error in authn req: {}".format(str(e)), context.state)
            error_url = e.to_error_url()

            if error_url:
                return SeeOther(error_url)
            else:
                return BadRequest("Something went wrong: {}".format(str(e)))

        client_id = authn_req["client_id"]
        context.state[self.name] = {"oidc_request": request}
        hash_type = oidc_subject_type_to_hash_type(self.provider.clients[client_id].get("subject_type", "pairwise"))
        client_name = self.provider.clients[client_id].get("client_name")
        if client_name:
            # TODO should process client names for all languages, see OIDC Registration, Section 2.1
            requester_name = [{"lang": "en", "text": client_name}]
        else:
            requester_name = None
        internal_req = InternalRequest(hash_type, client_id, requester_name)

        internal_req.approved_attributes = self.converter.to_internal_filter(
            "openid",
            self._get_approved_attributes(self.provider.configuration_information["claims_supported"], authn_req),
        )
        return internal_req

    def handle_authn_request(self, context):
        """
        Handle an authentication request and pass it on to the backend.
        :type context: satosa.context.Context
        :rtype: oic.utils.http_util.Response

        :param context: the current context
        :return: HTTP response to the client
        """
        internal_req = self._handle_authn_request(context)
        if not isinstance(internal_req, InternalRequest):
            return internal_req
        return self.auth_req_callback_func(context, internal_req)

    def jwks(self, context):
        """
        Construct the JWKS document (served at /jwks).
        :type context: satosa.context.Context
        :rtype: oic.utils.http_util.Response

        :param context: the current context
        :return: HTTP response to the client
        """
        return Response(json.dumps(self.provider.jwks), content="application/json")

    def token_endpoint(self, context):
        """
        Handle token requests (served at /token).
        :type context: satosa.context.Context
        :rtype: oic.utils.http_util.Response

        :param context: the current context
        :return: HTTP response to the client
        """
        headers = {"Authorization": context.request_authorization}
        try:
            response = self.provider.handle_token_request(urlencode(context.request), headers)
            return Response(response.to_json(), content="application/json")
        except InvalidClientAuthentication as e:
            logger.debug("invalid client authentication at token endpoint", exc_info=True)
            error_resp = TokenErrorResponse(error="invalid_client", error_description=str(e))
            response = Unauthorized(
                error_resp.to_json(), headers=[("WWW-Authenticate", "Basic")], content="application/json"
            )
            return response
        except OAuthError as e:
            logger.debug("invalid request: %s", str(e), exc_info=True)
            error_resp = TokenErrorResponse(error=e.oauth_error, error_description=str(e))
            return BadRequest(error_resp.to_json(), content="application/json")

    def userinfo_endpoint(self, context):
        headers = {"Authorization": context.request_authorization}

        try:
            response = self.provider.handle_userinfo_request(urlencode(context.request), headers)
            return Response(response.to_json(), content="application/json")
        except (BearerTokenError, InvalidAccessToken) as e:
            error_resp = UserInfoErrorResponse(error="invalid_token", error_description=str(e))
            response = Unauthorized(
                error_resp.to_json(),
                headers=[("WWW-Authenticate", AccessToken.BEARER_TOKEN_TYPE)],
                content="application/json",
            )
            return response
Esempio n. 2
0
class OpenIDConnectFrontend(FrontendModule):
    """
    A OpenID Connect frontend module
    """
    def __init__(self, auth_req_callback_func, internal_attributes, conf,
                 base_url, name):
        self._validate_config(conf)
        super().__init__(auth_req_callback_func, internal_attributes, base_url,
                         name)

        self.config = conf
        self.signing_key = RSAKey(key=rsa_load(conf["signing_key_path"]),
                                  use="sig",
                                  alg="RS256")

    def _create_provider(self, endpoint_baseurl):
        response_types_supported = self.config["provider"].get(
            "response_types_supported", ["id_token"])
        subject_types_supported = self.config["provider"].get(
            "subject_types_supported", ["pairwise"])
        scopes_supported = self.config["provider"].get("scopes_supported",
                                                       ["openid"])
        capabilities = {
            "issuer":
            self.base_url,
            "authorization_endpoint":
            "{}/{}".format(endpoint_baseurl, AuthorizationEndpoint.url),
            "jwks_uri":
            "{}/jwks".format(endpoint_baseurl),
            "response_types_supported":
            response_types_supported,
            "id_token_signing_alg_values_supported": [self.signing_key.alg],
            "response_modes_supported": ["fragment", "query"],
            "subject_types_supported":
            subject_types_supported,
            "claim_types_supported": ["normal"],
            "claims_parameter_supported":
            True,
            "claims_supported": [
                attribute_map["openid"][0] for attribute_map in
                self.internal_attributes["attributes"].values()
                if "openid" in attribute_map
            ],
            "request_parameter_supported":
            False,
            "request_uri_parameter_supported":
            False,
            "scopes_supported":
            scopes_supported
        }

        if 'code' in response_types_supported:
            capabilities["token_endpoint"] = "{}/{}".format(
                endpoint_baseurl, TokenEndpoint.url)

        if self.config["provider"].get("client_registration_supported", False):
            capabilities["registration_endpoint"] = "{}/{}".format(
                endpoint_baseurl, RegistrationEndpoint.url)

        authz_state = self._init_authorization_state()
        db_uri = self.config.get("db_uri")
        cdb_file = self.config.get("client_db_path")
        if db_uri:
            cdb = MongoWrapper(db_uri, "satosa", "clients")
        elif cdb_file:
            with open(cdb_file) as f:
                cdb = json.loads(f.read())
        else:
            cdb = {}
        self.user_db = MongoWrapper(db_uri, "satosa",
                                    "authz_codes") if db_uri else {}
        self.provider = Provider(self.signing_key, capabilities, authz_state,
                                 cdb, Userinfo(self.user_db))

    def _init_authorization_state(self):
        sub_hash_salt = self.config.get("sub_hash_salt", rndstr(16))
        db_uri = self.config.get("db_uri")
        if db_uri:
            authz_code_db = MongoWrapper(db_uri, "satosa", "authz_codes")
            access_token_db = MongoWrapper(db_uri, "satosa", "access_tokens")
            refresh_token_db = MongoWrapper(db_uri, "satosa", "refresh_tokens")
            sub_db = MongoWrapper(db_uri, "satosa", "subject_identifiers")
        else:
            authz_code_db = None
            access_token_db = None
            refresh_token_db = None
            sub_db = None

        token_lifetimes = {
            k: self.config["provider"][k]
            for k in [
                "authorization_code_lifetime", "access_token_lifetime",
                "refresh_token_lifetime", "refresh_token_threshold"
            ] if k in self.config["provider"]
        }
        return AuthorizationState(
            HashBasedSubjectIdentifierFactory(sub_hash_salt), authz_code_db,
            access_token_db, refresh_token_db, sub_db, **token_lifetimes)

    def handle_authn_response(self,
                              context,
                              internal_resp,
                              extra_id_token_claims=None):
        """
        See super class method satosa.frontends.base.FrontendModule#handle_authn_response
        :type context: satosa.context.Context
        :type internal_response: satosa.internal.InternalData
        :rtype oic.utils.http_util.Response
        """

        auth_req = self._get_authn_request_from_state(context.state)

        attributes = self.converter.from_internal("openid",
                                                  internal_resp.attributes)
        self.user_db[internal_resp.subject_id] = {
            k: v[0]
            for k, v in attributes.items()
        }
        auth_resp = self.provider.authorize(auth_req, internal_resp.subject_id,
                                            extra_id_token_claims)

        del context.state[self.name]
        http_response = auth_resp.request(auth_req["redirect_uri"],
                                          should_fragment_encode(auth_req))
        return SeeOther(http_response)

    def handle_backend_error(self, exception):
        """
        See super class satosa.frontends.base.FrontendModule
        :type exception: satosa.exception.SATOSAError
        :rtype: oic.utils.http_util.Response
        """
        auth_req = self._get_authn_request_from_state(exception.state)
        # If the client sent us a state parameter, we should reflect it back according to the spec
        if 'state' in auth_req:
            error_resp = AuthorizationErrorResponse(
                error="access_denied",
                error_description=exception.message,
                state=auth_req['state'])
        else:
            error_resp = AuthorizationErrorResponse(
                error="access_denied", error_description=exception.message)
        satosa_logging(logger, logging.DEBUG, exception.message,
                       exception.state)
        return SeeOther(
            error_resp.request(auth_req["redirect_uri"],
                               should_fragment_encode(auth_req)))

    def register_endpoints(self, backend_names):
        """
        See super class satosa.frontends.base.FrontendModule
        :type backend_names: list[str]
        :rtype: list[(str, ((satosa.context.Context, Any) -> satosa.response.Response, Any))]
        :raise ValueError: if more than one backend is configured
        """
        backend_name = None
        if len(backend_names) != 1:
            # only supports one backend since there currently is no way to publish multiple authorization endpoints
            # in configuration information and there is no other standard way of authorization_endpoint discovery
            # similar to SAML entity discovery
            # this can be circumvented with a custom RequestMicroService which handles the routing based on something
            # in the authentication request
            logger.warn(
                "More than one backend is configured, make sure to provide a custom routing micro service to "
                "determine which backend should be used per request.")
        else:
            backend_name = backend_names[0]

        endpoint_baseurl = "{}/{}".format(self.base_url, self.name)
        self._create_provider(endpoint_baseurl)

        provider_config = ("^.well-known/openid-configuration$",
                           self.provider_config)
        jwks_uri = ("^{}/jwks$".format(self.name), self.jwks)

        if backend_name:
            # if there is only one backend, include its name in the path so the default routing can work
            auth_endpoint = "{}/{}/{}/{}".format(self.base_url, backend_name,
                                                 self.name,
                                                 AuthorizationEndpoint.url)
            self.provider.configuration_information[
                "authorization_endpoint"] = auth_endpoint
            auth_path = urlparse(auth_endpoint).path.lstrip("/")
        else:
            auth_path = "{}/{}".format(self.name, AuthorizationEndpoint.url)
        authentication = ("^{}$".format(auth_path), self.handle_authn_request)
        url_map = [provider_config, jwks_uri, authentication]

        if any("code" in v for v in self.provider.
               configuration_information["response_types_supported"]):
            self.provider.configuration_information[
                "token_endpoint"] = "{}/{}".format(endpoint_baseurl,
                                                   TokenEndpoint.url)
            token_endpoint = ("^{}/{}".format(self.name, TokenEndpoint.url),
                              self.token_endpoint)
            url_map.append(token_endpoint)

            self.provider.configuration_information[
                "userinfo_endpoint"] = "{}/{}".format(endpoint_baseurl,
                                                      UserinfoEndpoint.url)
            userinfo_endpoint = ("^{}/{}".format(self.name,
                                                 UserinfoEndpoint.url),
                                 self.userinfo_endpoint)
            url_map.append(userinfo_endpoint)
        if "registration_endpoint" in self.provider.configuration_information:
            client_registration = ("^{}/{}".format(self.name,
                                                   RegistrationEndpoint.url),
                                   self.client_registration)
            url_map.append(client_registration)

        return url_map

    def _validate_config(self, config):
        """
        Validates that all necessary config parameters are specified.
        :type config: dict[str, dict[str, Any] | str]
        :param config: the module config
        """
        if config is None:
            raise ValueError("OIDCFrontend conf can't be 'None'.")

        for k in {"signing_key_path", "provider"}:
            if k not in config:
                raise ValueError(
                    "Missing configuration parameter '{}' for OpenID Connect frontend."
                    .format(k))

    def _get_authn_request_from_state(self, state):
        """
        Extract the clietns request stoed in the SATOSA state.
        :type state: satosa.state.State
        :rtype: oic.oic.message.AuthorizationRequest

        :param state: the current state
        :return: the parsed authentication request
        """
        return AuthorizationRequest().deserialize(
            state[self.name]["oidc_request"])

    def client_registration(self, context):
        """
        Handle the OIDC dynamic client registration.
        :type context: satosa.context.Context
        :rtype: oic.utils.http_util.Response

        :param context: the current context
        :return: HTTP response to the client
        """
        try:
            resp = self.provider.handle_client_registration_request(
                json.dumps(context.request))
            return Created(resp.to_json(), content="application/json")
        except InvalidClientRegistrationRequest as e:
            return BadRequest(e.to_json(), content="application/json")

    def provider_config(self, context):
        """
        Construct the provider configuration information (served at /.well-known/openid-configuration).
        :type context: satosa.context.Context
        :rtype: oic.utils.http_util.Response

        :param context: the current context
        :return: HTTP response to the client
        """
        return Response(self.provider.provider_configuration.to_json(),
                        content="application/json")

    def _get_approved_attributes(self, provider_supported_claims, authn_req):
        requested_claims = list(scope2claims(authn_req["scope"]).keys())
        if "claims" in authn_req:
            for k in ["id_token", "userinfo"]:
                if k in authn_req["claims"]:
                    requested_claims.extend(authn_req["claims"][k].keys())
        return set(provider_supported_claims).intersection(
            set(requested_claims))

    def _handle_authn_request(self, context):
        """
        Parse and verify the authentication request into an internal request.
        :type context: satosa.context.Context
        :rtype: satosa.internal.InternalData

        :param context: the current context
        :return: the internal request
        """
        request = urlencode(context.request)
        satosa_logging(logger, logging.DEBUG,
                       "Authn req from client: {}".format(request),
                       context.state)

        try:
            authn_req = self.provider.parse_authentication_request(request)
        except InvalidAuthenticationRequest as e:
            satosa_logging(logger, logging.ERROR,
                           "Error in authn req: {}".format(str(e)),
                           context.state)
            error_url = e.to_error_url()

            if error_url:
                return SeeOther(error_url)
            else:
                return BadRequest("Something went wrong: {}".format(str(e)))

        client_id = authn_req["client_id"]
        context.state[self.name] = {"oidc_request": request}
        subject_type = self.provider.clients[client_id].get(
            "subject_type", "pairwise")
        client_name = self.provider.clients[client_id].get("client_name")
        if client_name:
            # TODO should process client names for all languages, see OIDC Registration, Section 2.1
            requester_name = [{"lang": "en", "text": client_name}]
        else:
            requester_name = None
        internal_req = InternalData(
            subject_type=subject_type,
            requester=client_id,
            requester_name=requester_name,
        )

        internal_req.attributes = self.converter.to_internal_filter(
            "openid",
            self._get_approved_attributes(
                self.provider.configuration_information["claims_supported"],
                authn_req))
        return internal_req

    def handle_authn_request(self, context):
        """
        Handle an authentication request and pass it on to the backend.
        :type context: satosa.context.Context
        :rtype: oic.utils.http_util.Response

        :param context: the current context
        :return: HTTP response to the client
        """
        internal_req = self._handle_authn_request(context)
        if not isinstance(internal_req, InternalData):
            return internal_req
        return self.auth_req_callback_func(context, internal_req)

    def jwks(self, context):
        """
        Construct the JWKS document (served at /jwks).
        :type context: satosa.context.Context
        :rtype: oic.utils.http_util.Response

        :param context: the current context
        :return: HTTP response to the client
        """
        return Response(json.dumps(self.provider.jwks),
                        content="application/json")

    def token_endpoint(self, context):
        """
        Handle token requests (served at /token).
        :type context: satosa.context.Context
        :rtype: oic.utils.http_util.Response

        :param context: the current context
        :return: HTTP response to the client
        """
        headers = {"Authorization": context.request_authorization}
        try:
            response = self.provider.handle_token_request(
                urlencode(context.request), headers)
            return Response(response.to_json(), content="application/json")
        except InvalidClientAuthentication as e:
            logger.debug('invalid client authentication at token endpoint',
                         exc_info=True)
            error_resp = TokenErrorResponse(error='invalid_client',
                                            error_description=str(e))
            response = Unauthorized(error_resp.to_json(),
                                    headers=[("WWW-Authenticate", "Basic")],
                                    content="application/json")
            return response
        except OAuthError as e:
            logger.debug('invalid request: %s', str(e), exc_info=True)
            error_resp = TokenErrorResponse(error=e.oauth_error,
                                            error_description=str(e))
            return BadRequest(error_resp.to_json(), content="application/json")

    def userinfo_endpoint(self, context):
        headers = {"Authorization": context.request_authorization}

        try:
            response = self.provider.handle_userinfo_request(
                urlencode(context.request), headers)
            return Response(response.to_json(), content="application/json")
        except (BearerTokenError, InvalidAccessToken) as e:
            error_resp = UserInfoErrorResponse(error='invalid_token',
                                               error_description=str(e))
            response = Unauthorized(error_resp.to_json(),
                                    headers=[("WWW-Authenticate",
                                              AccessToken.BEARER_TOKEN_TYPE)],
                                    content="application/json")
            return response
Esempio n. 3
0
class OIDCPlugin(KosekiPlugin):
    def config(self) -> dict:
        return {}

    def plugin_enable(self) -> None:
        # pylint: disable=attribute-defined-outside-init
        self.signing_key = RSAKey(key=rsa_load("signing_key.pem"),
                                  use="sig",
                                  alg="RS256")
        # pylint: disable=attribute-defined-outside-init
        self.oidc_configuration_information = {
            "issuer":
            self.app.config["URL_BASE"].replace("http://", "https://"),
            "authorization_endpoint":
            self.app.config["URL_BASE"] + "/oidc/authorization",
            "jwks_uri":
            self.app.config["URL_BASE"] + "/oidc/jwks",
            "token_endpoint":
            self.app.config["URL_BASE"] + "/oidc/token",
            "userinfo_endpoint":
            self.app.config["URL_BASE"] + "/oidc/userinfo",
            "response_types_supported": ["code"],
            "id_token_signing_alg_values_supported": [self.signing_key.alg],
            "response_modes_supported": ["fragment", "query"],
            "subject_types_supported": ["public", "pairwise"],
            "grant_types_supported": ["authorization_code", "implicit"],
            "claim_types_supported": ["normal"],
            "claims_parameter_supported":
            True,
            "claims_supported":
            ["sub", "name", "given_name", "family_name", "email", "profile"],
            "scopes_supported": ["openid", "email", "profile"]
        }
        subject_id_factory = HashBasedSubjectIdentifierFactory(
            self.app.config["SECRET_KEY"])
        # pylint: disable=attribute-defined-outside-init
        self.provider = Provider(
            self.signing_key, self.oidc_configuration_information,
            AuthorizationState(
                subject_id_factory,
                SQLWrapper(self.storage, 'authz_codes'),
                SQLWrapper(self.storage, 'access_tokens'),
                SQLWrapper(self.storage, 'refresh_tokens'),
                SQLWrapper(self.storage, 'subject_identifiers', True),
            ), SQLWrapper(self.storage, 'clients'),
            Userinfo(PersonWrapper(self.storage)))

    def create_blueprint(self) -> Blueprint:
        blueprint: Blueprint = Blueprint("oidc",
                                         __name__,
                                         template_folder="./templates")
        blueprint.add_url_rule("/.well-known/openid-configuration",
                               None,
                               self.oidc_config,
                               methods=["GET"])
        blueprint.add_url_rule("/oidc/jwks",
                               None,
                               self.oidc_jwks,
                               methods=["GET"])
        blueprint.add_url_rule("/oidc/authorization",
                               None,
                               self.auth.require_session(
                                   self.oidc_authorization, None),
                               methods=["GET"])
        blueprint.add_url_rule("/oidc/token",
                               None,
                               self.oidc_token,
                               methods=["GET", "POST"])
        blueprint.add_url_rule("/oidc/userinfo",
                               None,
                               self.oidc_userinfo,
                               methods=["GET", "POST"])
        return blueprint

    def oidc_config(self) -> Union[str, Response]:
        return Response(self.provider.provider_configuration.to_json(),
                        mimetype="application/json")

    def oidc_jwks(self) -> Union[str, Response]:
        return jsonify(self.provider.jwks)

    # require_session
    def oidc_authorization(self) -> Union[str, Response]:
        person: Person = self.storage.session.query(Person).filter_by(
            uid=self.util.current_user()).scalar()
        if person.state != "active":
            logging.warning(
                "User {} tried to log in via OIDC without membership!",
                person.uid)
            self.util.alert(
                KosekiAlert(
                    KosekiAlertType.DANGER,
                    "Missing membership!",
                    "OpenID Connect login was successful, but active " +
                    "membership is required to log in to 3rd party services.",
                ))
            return render_template("oidc.html")

        try:
            # Override claims to always only and at least return email for CF
            authn_req = AuthorizationRequest().from_dict({
                **request.args, "claims": {
                    "id_token": {
                        "email": {
                            "essential": True
                        }
                    }
                }
            })
            logging.debug(authn_req)
            authn_req.verify()
            if authn_req["client_id"] not in self.provider.clients:
                self.util.alert(
                    KosekiAlert(
                        KosekiAlertType.DANGER,
                        "Invalid Client",
                        "Unknown Client ID. All applications must be pre-registered.",
                    ))
                return render_template("oidc.html")
            client = self.provider.clients[authn_req["client_id"]]
            if client["redirect_uri"] != authn_req["redirect_uri"]:
                self.util.alert(
                    KosekiAlert(
                        KosekiAlertType.DANGER,
                        "Incorrect Redirect URI",
                        "Unauthorized redirect URI for this client.",
                    ))
                return render_template("oidc.html")

            authn_response = self.provider.authorize(
                authn_req, str(self.util.current_user()))
            return_url = authn_response.request(
                authn_req["redirect_uri"], should_fragment_encode(authn_req))
            return redirect(return_url)
        except Exception as err:  # pylint: disable=broad-except
            logging.error(err)
            self.util.alert(
                KosekiAlert(
                    KosekiAlertType.DANGER,
                    "OIDC Error",
                    "Error: {}".format(str(err)),
                ))
            return render_template("oidc.html")

    def oidc_token(self) -> Union[str, Response]:
        print(request)
        try:
            token_response = self.provider.handle_token_request(
                request.get_data().decode("utf-8"), request.headers)
            return Response(token_response.to_json(),
                            mimetype="application/json")
        except InvalidClientAuthentication as err:
            error_resp = TokenErrorResponse(error="invalid_client",
                                            error_description=str(err))
            http_response = Response(error_resp.to_json(),
                                     status=401,
                                     mimetype="application/json")
            http_response.headers["WWW-Authenticate"] = "Basic"
            logging.error(error_resp.to_dict())
            return http_response
        except OAuthError as err:
            error_resp = TokenErrorResponse(error=err.oauth_error,
                                            error_description=str(err))
            return Response(error_resp.to_json(),
                            status=400,
                            mimetype="application/json")

    def oidc_userinfo(self) -> Union[str, Response]:
        try:
            response = self.provider.handle_userinfo_request(
                request.get_data().decode("utf-8"), request.headers)
            return Response(response.to_json(), mimetype="application/json")
        except (BearerTokenError, InvalidAccessToken) as err:
            error_resp = UserInfoErrorResponse(error="invalid_token",
                                               error_description=str(err))
            http_response = Response(error_resp.to_json(),
                                     status=401,
                                     mimetype="application/json")
            http_response.headers[
                "WWW-Authenticate"] = AccessToken.BEARER_TOKEN_TYPE
            logging.error(error_resp.to_dict())
            return http_response