예제 #1
0
    def __init__(self, hs: "HomeServer"):
        """Capture the OIDC settings from the homeserver config and grab services.

        Args:
            hs: the homeserver; its ``config`` supplies the OIDC client and
                provider settings, and it provides the HTTP client, auth and
                registration handlers, datastore and clock used later.
        """
        config = hs.config

        # Client-side settings.
        self._callback_url = config.oidc_callback_url  # type: str
        self._scopes = config.oidc_scopes  # type: List[str]
        auth_method = config.oidc_client_auth_method  # type: str
        self._client_auth = ClientAuth(
            config.oidc_client_id, config.oidc_client_secret, auth_method,
        )  # type: ClientAuth
        self._client_auth_method = auth_method

        # Provider endpoints, as configured.
        self._provider_metadata = OpenIDProviderMetadata(
            issuer=config.oidc_issuer,
            authorization_endpoint=config.oidc_authorization_endpoint,
            token_endpoint=config.oidc_token_endpoint,
            userinfo_endpoint=config.oidc_userinfo_endpoint,
            jwks_uri=config.oidc_jwks_uri,
        )  # type: OpenIDProviderMetadata
        self._provider_needs_discovery = config.oidc_discover  # type: bool
        mapping_provider_cls = config.oidc_user_mapping_provider_class
        self._user_mapping_provider = mapping_provider_cls(
            config.oidc_user_mapping_provider_config
        )  # type: OidcMappingProvider
        self._skip_verification = config.oidc_skip_verification  # type: bool

        # Homeserver services.
        self._http_client = hs.get_proxied_http_client()
        self._auth_handler = hs.get_auth_handler()
        self._registration_handler = hs.get_registration_handler()
        self._datastore = hs.get_datastore()
        self._clock = hs.get_clock()

        self._hostname = hs.hostname  # type: str
        self._server_name = config.server_name  # type: str
        self._macaroon_secret_key = config.macaroon_secret_key
        self._error_template = config.sso_error_template

        # identifier for the external_ids table
        self._auth_provider_id = "oidc"
예제 #2
0
 def test_validate_request_object_signing_alg_values_supported(self):
     """Validate ``request_object_signing_alg_values_supported``.

     A JSON array containing ``none`` and ``RS256`` is accepted; a list
     lacking them is rejected with a "SHOULD support" error message.
     """
     self._call_validate_array(
         'request_object_signing_alg_values_supported', ['none', 'RS256'])
     metadata = OpenIDProviderMetadata(
         {'request_object_signing_alg_values_supported': ['RS512']})
     with self.assertRaises(ValueError) as cm:
         metadata.validate_request_object_signing_alg_values_supported()
     # Check the message after the ``with`` block: the call above raises, so
     # any statement placed after it inside the block would never run.
     self.assertIn('SHOULD support none and RS256', str(cm.exception))
예제 #3
0
 def test_validate_id_token_signing_alg_values_supported(self):
     """Validate ``id_token_signing_alg_values_supported``.

     The field is required and must be a JSON array; a list that does not
     include ``RS256`` is rejected with an error mentioning ``RS256``.
     """
     self._call_validate_array(
         'id_token_signing_alg_values_supported',
         ['RS256'],
         required=True,
     )
     metadata = OpenIDProviderMetadata(
         {'id_token_signing_alg_values_supported': ['none']})
     with self.assertRaises(ValueError) as cm:
         metadata.validate_id_token_signing_alg_values_supported()
     # Check the message after the ``with`` block: the call above raises, so
     # any statement placed after it inside the block would never run.
     self.assertIn('RS256', str(cm.exception))
예제 #4
0
    def __init__(
        self,
        hs: "HomeServer",
        token_generator: "OidcSessionTokenGenerator",
        provider: OidcProviderConfig,
    ):
        """Set up the handler for one configured OIDC identity provider.

        Args:
            hs: the homeserver, used for its config and shared services.
            token_generator: session-token generator; stored for later use.
            provider: configuration of the identity provider handled here.

        As its final step the handler registers itself with the SSO handler,
        after every attribute has been populated.
        """
        self._store = hs.get_datastore()
        self._token_generator = token_generator
        self._callback_url = hs.config.oidc_callback_url  # type: str

        # Client/request settings for this provider.
        auth_method = provider.client_auth_method
        self._scopes = provider.scopes
        self._user_profile_method = provider.user_profile_method
        self._client_auth = ClientAuth(
            provider.client_id, provider.client_secret, auth_method,
        )  # type: ClientAuth
        self._client_auth_method = auth_method

        # Provider endpoints, as configured.
        self._provider_metadata = OpenIDProviderMetadata(
            issuer=provider.issuer,
            authorization_endpoint=provider.authorization_endpoint,
            token_endpoint=provider.token_endpoint,
            userinfo_endpoint=provider.userinfo_endpoint,
            jwks_uri=provider.jwks_uri,
        )  # type: OpenIDProviderMetadata
        self._provider_needs_discovery = provider.discover
        mapping_provider_cls = provider.user_mapping_provider_class
        self._user_mapping_provider = mapping_provider_cls(
            provider.user_mapping_provider_config
        )
        self._skip_verification = provider.skip_verification
        self._allow_existing_users = provider.allow_existing_users

        self._http_client = hs.get_proxied_http_client()
        self._server_name = hs.config.server_name  # type: str

        # Public identity-provider details: the identifier used in the
        # external_ids table, the user-facing name, the MXC URI of its icon
        # and an optional brand identifier.
        self.idp_id = provider.idp_id
        self.idp_name = provider.idp_name
        self.idp_icon = provider.idp_icon
        self.idp_brand = provider.idp_brand

        self._sso_handler = hs.get_sso_handler()
        self._sso_handler.register_identity_provider(self)
예제 #5
0
    def _call_validate_boolean(self, key, default_value=False):
        """Exercise ``metadata.validate_<key>()`` for a boolean metadata field.

        Checks that an absent value validates and reads back as
        ``default_value``, that a non-boolean value is rejected with a
        "MUST be boolean" error, and that ``True`` is accepted.

        Args:
            key: metadata field name; ``validate_<key>`` is the method called.
            default_value: the value expected when the field is absent.
        """
        def _validate(metadata):
            getattr(metadata, 'validate_' + key)()

        metadata = OpenIDProviderMetadata()
        _validate(metadata)
        self.assertEqual(getattr(metadata, key), default_value)

        metadata = OpenIDProviderMetadata({key: 'str'})
        with self.assertRaises(ValueError) as cm:
            _validate(metadata)
        # Must be outside the ``with`` block: nothing after the raising call
        # inside it would ever run, so the message would go unchecked.
        self.assertIn('MUST be boolean', str(cm.exception))

        metadata = OpenIDProviderMetadata({key: True})
        _validate(metadata)
예제 #6
0
    def _call_validate_array(self, key, valid_value, required=False):
        """Exercise ``metadata.validate_<key>()`` for a JSON-array metadata field.

        Args:
            key: metadata field name; ``validate_<key>`` is the method called.
            valid_value: a value the validator must accept.
            required: when true, an absent field must fail with
                ``"<key>" is required``; otherwise absence must validate.
        """
        method_name = 'validate_' + key

        def run_validator(md):
            getattr(md, method_name)()

        # Field absent.
        empty = OpenIDProviderMetadata()
        if not required:
            run_validator(empty)
        else:
            with self.assertRaises(ValueError) as ctx:
                run_validator(empty)
            self.assertEqual('"{}" is required'.format(key), str(ctx.exception))

        # A plain string is not a JSON array.
        not_an_array = OpenIDProviderMetadata({key: 'foo'})
        with self.assertRaises(ValueError) as ctx:
            run_validator(not_an_array)
        self.assertIn('JSON array', str(ctx.exception))

        # A well-formed value validates cleanly.
        run_validator(OpenIDProviderMetadata({key: valid_value}))
예제 #7
0
class OidcHandler:
    """Handles requests related to the OpenID Connect login flow.
    """

    def __init__(self, hs: HomeServer):
        """Read the OIDC settings from the homeserver config and grab services.

        Args:
            hs: the homeserver; its ``config`` supplies the OIDC client and
                provider settings and the SSO template directory, and it
                provides the HTTP client, handlers, datastore and clock.
        """
        config = hs.config

        # Client-side settings.
        self._callback_url = config.oidc_callback_url  # type: str
        self._scopes = config.oidc_scopes  # type: List[str]
        auth_method = config.oidc_client_auth_method  # type: str
        self._client_auth = ClientAuth(
            config.oidc_client_id, config.oidc_client_secret, auth_method,
        )  # type: ClientAuth
        self._client_auth_method = auth_method

        # Provider endpoints, as configured (possibly extended by discovery
        # in ``load_metadata``).
        self._provider_metadata = OpenIDProviderMetadata(
            issuer=config.oidc_issuer,
            authorization_endpoint=config.oidc_authorization_endpoint,
            token_endpoint=config.oidc_token_endpoint,
            userinfo_endpoint=config.oidc_userinfo_endpoint,
            jwks_uri=config.oidc_jwks_uri,
        )  # type: OpenIDProviderMetadata
        self._provider_needs_discovery = config.oidc_discover  # type: bool
        mapping_provider_cls = config.oidc_user_mapping_provider_class
        self._user_mapping_provider = mapping_provider_cls(
            config.oidc_user_mapping_provider_config
        )  # type: OidcMappingProvider
        self._skip_verification = config.oidc_skip_verification  # type: bool

        # Homeserver services.
        self._http_client = hs.get_proxied_http_client()
        self._auth_handler = hs.get_auth_handler()
        self._registration_handler = hs.get_registration_handler()
        self._datastore = hs.get_datastore()
        self._clock = hs.get_clock()

        self._hostname = hs.hostname  # type: str
        self._server_name = config.server_name  # type: str
        self._macaroon_secret_key = config.macaroon_secret_key
        self._error_template = load_jinja2_templates(
            config.sso_template_dir, ["sso_error.html"]
        )[0]

        # identifier for the external_ids table
        self._auth_provider_id = "oidc"

    def _render_error(
        self, request, error: str, error_description: Optional[str] = None
    ) -> None:
        """Answer ``request`` with the rendered SSO error page and HTTP 400.

        This is how OIDC failures are shown to the user; the page comes from
        ``synapse/res/templates/sso_error.html``.

        Args:
            request: the browser request to respond to.
            error: a short technical error code, e.g. a well-known
                OAuth2/OIDC error type such as ``invalid_request`` or
                ``access_denied``.
            error_description: an optional human-readable explanation.
        """
        respond_with_html(
            request,
            400,
            self._error_template.render(
                error=error, error_description=error_description
            ),
        )

    def _validate_metadata(self):
        """Check the currently loaded provider metadata for obvious problems.

        The check is deliberately partial; only the following are examined:

          - ``issuer``, ``authorization_endpoint`` and ``token_endpoint``
          - ``token_endpoint_auth_methods_supported``, if present: it must
            contain the configured client auth method
          - ``response_types_supported``, if present: it must contain "code"
          - ``userinfo_endpoint`` when user info is fetched from that
            endpoint; otherwise an inline ``jwks`` or a valid ``jwks_uri``

        Every check is skipped when ``_skip_verification`` is ``True``.

        Raises:
            ValueError: if something in the provider is not valid
        """
        # Skip verification to allow non-compliant providers (e.g. issuers
        # not running on a secure origin)
        if self._skip_verification is True:
            return

        meta = self._provider_metadata
        for mandatory_check in (
            meta.validate_issuer,
            meta.validate_authorization_endpoint,
            meta.validate_token_endpoint,
        ):
            mandatory_check()

        if meta.get("token_endpoint_auth_methods_supported") is not None:
            meta.validate_token_endpoint_auth_methods_supported()
            supported_methods = meta["token_endpoint_auth_methods_supported"]
            if self._client_auth_method not in supported_methods:
                raise ValueError(
                    '"{auth_method}" not in "token_endpoint_auth_methods_supported" ({supported!r})'.format(
                        auth_method=self._client_auth_method,
                        supported=supported_methods,
                    )
                )

        if meta.get("response_types_supported") is not None:
            meta.validate_response_types_supported()
            response_types = meta["response_types_supported"]
            if "code" not in response_types:
                raise ValueError(
                    '"code" not in "response_types_supported" (%r)'
                    % (response_types,)
                )

        if self._uses_userinfo:
            # Without the "openid" scope the user's details come from the
            # userinfo endpoint, so the provider has to expose one.
            if meta.get("userinfo_endpoint") is None:
                raise ValueError(
                    'provider has no "userinfo_endpoint", even though it is required because the "openid" scope is not requested'
                )
            return

        # Otherwise the ID token is used, which needs a key set to be
        # validated: either inline "jwks" or a valid "jwks_uri".
        if meta.get("jwks") is not None:
            return
        if meta.get("jwks_uri") is None:
            raise ValueError('"jwks_uri" must be set')
        meta.validate_jwks_uri()

    @property
    def _uses_userinfo(self) -> bool:
        """Whether user details must be fetched from the ``userinfo_endpoint``.

        When the requested scopes include ``openid`` the provider should hand
        back an ID token carrying the user's details; otherwise they have to
        be fetched from the ``userinfo_endpoint`` using the ``access_token``.
        """
        # Maybe that should be user-configurable and not inferred?
        requests_openid = "openid" in self._scopes
        return not requests_openid

    async def load_metadata(self) -> OpenIDProviderMetadata:
        """Return the provider metadata, running discovery first if needed.

        When discovery is enabled, the issuer's well-known document is fetched
        once; its values overwrite the configured ones and the result is kept.
        The metadata is validated on every call.

        Raises:
            ValueError: if something in the provider is not valid

        Returns:
            The provider's metadata.
        """
        # FIXME: should there be a lock here?
        if self._provider_needs_discovery:
            discovery_url = get_well_known_url(
                self._provider_metadata["issuer"], external=True
            )
            discovered = await self._http_client.get_json(discovery_url)
            # TODO: maybe update the other way around to let user override some values?
            self._provider_metadata.update(discovered)
            self._provider_needs_discovery = False

        self._validate_metadata()
        return self._provider_metadata

    async def load_jwks(self, force: bool = False) -> JWKS:
        """Return the JSON Web Key Set used to verify ID token signatures.

        When the ``userinfo_endpoint`` is used there is no ID token to verify,
        so an empty key set is returned. Otherwise the key set already present
        in the provider metadata is used (it may have been supplied by the
        discovery document or cached by an earlier call), unless ``force`` is
        set, in which case it is re-fetched from ``jwks_uri`` and cached.

        Args:
            force: Force reloading the keys.

        Returns:
            The key set

            Looks like this::

                {
                    'keys': [
                        {
                            'kid': 'abcdef',
                            'kty': 'RSA',
                            'alg': 'RS256',
                            'use': 'sig',
                            'e': 'XXXX',
                            'n': 'XXXX',
                        }
                    ]
                }

        Raises:
            RuntimeError: if keys must be fetched but there is no ``jwks_uri``.
        """
        if self._uses_userinfo:
            # We're not using jwt signing, return an empty jwk set
            return {"keys": []}

        metadata = await self.load_metadata()
        cached_keys = metadata.get("jwks")
        if not force and cached_keys is not None:
            return cached_keys

        jwks_uri = metadata.get("jwks_uri")
        if not jwks_uri:
            raise RuntimeError('Missing "jwks_uri" in metadata')

        fetched_keys = await self._http_client.get_json(jwks_uri)
        # Cache the key set alongside the rest of the provider metadata.
        self._provider_metadata["jwks"] = fetched_keys
        return fetched_keys

    async def _exchange_code(self, code: str) -> Token:
        """Exchange an authorization code for a token.

        This calls the ``token_endpoint`` with the authorization code we
        received in the callback to exchange it for a token. The call uses the
        ``ClientAuth`` to authenticate with the client with its ID and secret.

        See:
           https://tools.ietf.org/html/rfc6749#section-3.2
           https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint

        Args:
            code: The authorization code we got from the callback.

        Returns:
            A dict containing various tokens.

            May look like this::

                {
                    'token_type': 'bearer',
                    'access_token': 'abcdef',
                    'expires_in': 3599,
                    'id_token': 'ghijkl',
                    'refresh_token': 'mnopqr',
                }

        Raises:
            OidcError: when the ``token_endpoint`` returned an error, or
                responded with a 5xx or 4xx status.
            ValueError: if a non-5xx response body is not valid UTF-8 JSON.
        """
        metadata = await self.load_metadata()
        token_endpoint = metadata.get("token_endpoint")
        headers = {
            "Content-Type": "application/x-www-form-urlencoded",
            "User-Agent": self._http_client.user_agent,
            "Accept": "application/json",
        }

        args = {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": self._callback_url,
        }
        body = urlencode(args, True)

        # Fill the body/headers with credentials
        uri, headers, body = self._client_auth.prepare(
            method="POST", uri=token_endpoint, headers=headers, body=body
        )
        # The HTTP client takes a list of values for each header name.
        headers = {k: [v] for (k, v) in headers.items()}

        # Do the actual request
        # We're not using the SimpleHttpClient util methods as we don't want to
        # check the HTTP status code and we do the body encoding ourself.
        response = await self._http_client.request(
            method="POST", uri=uri, data=body.encode("utf-8"), headers=headers,
        )

        # This is used in multiple error messages below
        status = "{code} {phrase}".format(
            code=response.code, phrase=response.phrase.decode("utf-8")
        )

        resp_body = await make_deferred_yieldable(readBody(response))

        if response.code >= 500:
            # In case of a server error, we should first try to decode the body
            # and check for an error field. If not, we respond with a generic
            # error message.
            try:
                resp = json.loads(resp_body.decode("utf-8"))
                error = resp["error"]
                description = resp.get("error_description", error)
            except (ValueError, KeyError, TypeError):
                # ValueError: undecodable or non-JSON body; KeyError: no
                # "error" field; TypeError: valid JSON that is not an object.
                error = "server_error"
                # This must be a plain string (a stray trailing comma here
                # would silently turn it into a 1-tuple).
                description = (
                    'Authorization server responded with a "{status}" error '
                    "while exchanging the authorization code."
                ).format(status=status)

            raise OidcError(error, description)

        # Since it is a not a 5xx code, body should be a valid JSON. It will
        # raise if not.
        resp = json.loads(resp_body.decode("utf-8"))

        if "error" in resp:
            error = resp["error"]
            # In case the authorization server responded with an error field,
            # it should be a 4xx code. If not, log it but don't do anything
            # special and report the original error message.
            if response.code < 400:
                logger.debug(
                    "Invalid response from the authorization server: "
                    'responded with a "%s" but body has an error field: %r',
                    status,
                    resp["error"],
                )

            description = resp.get("error_description", error)
            raise OidcError(error, description)

        # Now, this should not be an error. According to RFC6749 sec 5.1, it
        # should be a 200 code. We're a bit more flexible than that, and will
        # only throw on a 4xx code.
        if response.code >= 400:
            description = (
                'Authorization server responded with a "{status}" error '
                'but did not include an "error" field in its response.'.format(
                    status=status
                )
            )
            logger.warning(description)
            # Body was still valid JSON. Might be useful to log it for debugging.
            logger.warning("Code exchange response: %r", resp)
            raise OidcError("server_error", description)

        return resp

    async def _fetch_userinfo(self, token: Token) -> UserInfo:
        """Query the provider's ``userinfo_endpoint`` for the user's details.

        Args:
            token: the token given by the ``token_endpoint``; its
                ``access_token`` field is sent as a bearer token.

        Returns:
            UserInfo: an object representing the user.
        """
        metadata = await self.load_metadata()
        endpoint = metadata["userinfo_endpoint"]
        bearer = "Bearer {}".format(token["access_token"])

        payload = await self._http_client.get_json(
            endpoint, headers={"Authorization": [bearer]},
        )
        return UserInfo(payload)

    async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
        """Decode and validate the token's ``id_token`` into a UserInfo.

        Args:
            token: the token given by the ``token_endpoint``.
                Must include an ``id_token`` field.
            nonce: the nonce value originally sent in the initial authorization
                request. This value should match the one inside the token.

        Returns:
            An object representing the user.
        """
        metadata = await self.load_metadata()

        claims_params = {
            "nonce": nonce,
            "client_id": self._client_auth.client_id,
        }
        if "access_token" not in token:
            claims_cls = ImplicitIDToken
        else:
            # With an `access_token`, the `id_token` should carry an `at_hash`
            # claim that can be checked against it.
            claims_params["access_token"] = token["access_token"]
            claims_cls = CodeIDToken

        jwt = JsonWebToken(
            metadata.get("id_token_signing_alg_values_supported", ["RS256"])
        )
        claim_options = {"iss": {"values": [metadata["issuer"]]}}

        def decode_with(key_set):
            return jwt.decode(
                token["id_token"],
                key=key_set,
                claims_cls=claims_cls,
                claims_options=claim_options,
                claims_params=claims_params,
            )

        # Use the cached keys first; on a decode error, force the key set to be
        # reloaded and try exactly once more.
        key_set = await self.load_jwks()
        try:
            claims = decode_with(key_set)
        except ValueError:
            logger.info("Reloading JWKS after decode error")
            claims = decode_with(await self.load_jwks(force=True))

        claims.validate(leeway=120)  # allows 2 min of clock skew
        return UserInfo(claims)

    async def handle_redirect_request(
        self,
        request: SynapseRequest,
        client_redirect_url: bytes,
        ui_auth_session_id: Optional[str] = None,
    ) -> str:
        """Handle an incoming request to /login/sso/redirect

        Builds the URL of the provider's authorization endpoint with these
        query parameters:

          - ``client_id``: the configured client ID
          - ``response_type``: ``code``
          - ``redirect_uri``: the callback URL (``{base url}/_synapse/oidc/callback``)
          - ``scope``: the configured scopes
          - ``state``: a freshly generated random token
          - ``nonce``: a freshly generated random token

        It also sets an HTTP-only session cookie (scoped to ``/_synapse/oidc``,
        max age one hour) containing a signed macaroon with the state, nonce,
        client_redirect_url and optional UI auth session ID; these are checked
        when the browser comes back from the provider.

        Args:
            request: the incoming request from the browser.
                We'll respond to it with a redirect and a cookie.
            client_redirect_url: the URL that we should redirect the client to
                when everything is done
            ui_auth_session_id: The session ID of the ongoing UI Auth (or
                None if this is a login).

        Returns:
            The redirect URL to the authorization endpoint.
        """
        state, nonce = generate_token(), generate_token()

        session_cookie = self._generate_oidc_session_token(
            state=state,
            nonce=nonce,
            client_redirect_url=client_redirect_url.decode(),
            ui_auth_session_id=ui_auth_session_id,
        )
        request.addCookie(
            SESSION_COOKIE_NAME,
            session_cookie,
            path="/_synapse/oidc",
            max_age="3600",
            httpOnly=True,
            sameSite="lax",
        )

        metadata = await self.load_metadata()
        return prepare_grant_uri(
            metadata.get("authorization_endpoint"),
            client_id=self._client_auth.client_id,
            response_type="code",
            redirect_uri=self._callback_url,
            scope=self._scopes,
            state=state,
            nonce=nonce,
        )

    async def handle_oidc_callback(self, request: SynapseRequest) -> None:
        """Handle an incoming request to /_synapse/oidc/callback

        Since we might want to display OIDC-related errors in a user-friendly
        way, we don't raise SynapseError from here. Instead, we call
        ``self._render_error`` which displays an HTML page for the error.

        Most of the OpenID Connect logic happens here:

          - first, we check if there was any error returned by the provider and
            display it
          - then we fetch the session cookie, decode and verify it
          - the ``state`` query parameter should match with the one stored in the
            session cookie
          - once we known this session is legit, exchange the code with the
            provider using the ``token_endpoint`` (see ``_exchange_code``)
          - once we have the token, use it to either extract the UserInfo from
            the ``id_token`` (``_parse_id_token``), or use the ``access_token``
            to fetch UserInfo from the ``userinfo_endpoint``
            (``_fetch_userinfo``)
          - map those UserInfo to a Matrix user (``_map_userinfo_to_user``) and
            finish the login

        Every failure path renders an error page and returns early; on success
        the auth handler completes either the UI auth session or the login.

        Args:
            request: the incoming request from the browser.
        """

        # The provider might redirect with an error.
        # In that case, just display it as-is.
        if b"error" in request.args:
            # error response from the auth server. see:
            #  https://tools.ietf.org/html/rfc6749#section-4.1.2.1
            #  https://openid.net/specs/openid-connect-core-1_0.html#AuthError
            error = request.args[b"error"][0].decode()
            description = request.args.get(b"error_description", [b""])[0].decode()

            # Most of the errors returned by the provider could be due by
            # either the provider misbehaving or Synapse being misconfigured.
            # The only exception of that is "access_denied", where the user
            # probably cancelled the login flow. In other cases, log those errors.
            if error != "access_denied":
                logger.error("Error from the OIDC provider: %s %s", error, description)

            self._render_error(request, error, description)
            return

        # otherwise, it is presumably a successful response. see:
        #   https://tools.ietf.org/html/rfc6749#section-4.1.2

        # Fetch the session cookie
        session = request.getCookie(SESSION_COOKIE_NAME)  # type: Optional[bytes]
        if session is None:
            logger.info("No session cookie found")
            self._render_error(request, "missing_session", "No session cookie found")
            return

        # Remove the cookie. There is a good chance that if the callback failed
        # once, it will fail next time and the code will already be exchanged.
        # Removing it early avoids spamming the provider with token requests.
        request.addCookie(
            SESSION_COOKIE_NAME,
            b"",
            path="/_synapse/oidc",
            expires="Thu, Jan 01 1970 00:00:00 UTC",
            httpOnly=True,
            sameSite="lax",
        )

        # Check for the state query parameter
        if b"state" not in request.args:
            logger.info("State parameter is missing")
            self._render_error(request, "invalid_request", "State parameter is missing")
            return

        # Only the first value is used if the parameter is repeated.
        state = request.args[b"state"][0].decode()

        # Deserialize the session token and verify it.
        # The macaroon must carry a caveat for exactly this `state`; the values
        # stored when the flow was started are extracted from it.
        # NOTE(review): only these two pymacaroons exception types are caught;
        # confirm that failures of the other caveats (e.g. expiry) surface as
        # one of them, otherwise they escape as an unhandled error.
        try:
            (
                nonce,
                client_redirect_url,
                ui_auth_session_id,
            ) = self._verify_oidc_session_token(session, state)
        except MacaroonDeserializationException as e:
            logger.exception("Invalid session")
            self._render_error(request, "invalid_session", str(e))
            return
        except MacaroonInvalidSignatureException as e:
            logger.exception("Could not verify session")
            self._render_error(request, "mismatching_session", str(e))
            return

        # Exchange the code with the provider
        if b"code" not in request.args:
            logger.info("Code parameter is missing")
            self._render_error(request, "invalid_request", "Code parameter is missing")
            return

        logger.debug("Exchanging code")
        code = request.args[b"code"][0].decode()
        try:
            token = await self._exchange_code(code)
        except OidcError as e:
            logger.exception("Could not exchange code")
            self._render_error(request, e.error, e.error_description)
            return

        logger.debug("Successfully obtained OAuth2 access token")

        # Now that we have a token, get the userinfo, either by decoding the
        # `id_token` or by fetching the `userinfo_endpoint`.
        if self._uses_userinfo:
            logger.debug("Fetching userinfo")
            try:
                userinfo = await self._fetch_userinfo(token)
            except Exception as e:
                logger.exception("Could not fetch userinfo")
                self._render_error(request, "fetch_error", str(e))
                return
        else:
            logger.debug("Extracting userinfo from id_token")
            try:
                # The nonce from the session cookie is checked against the
                # one inside the ID token.
                userinfo = await self._parse_id_token(token, nonce=nonce)
            except Exception as e:
                logger.exception("Invalid id_token")
                self._render_error(request, "invalid_token", str(e))
                return

        # Call the mapper to register/login the user
        try:
            user_id = await self._map_userinfo_to_user(userinfo, token)
        except MappingException as e:
            logger.exception("Could not map user")
            self._render_error(request, "mapping_error", str(e))
            return

        # and finally complete the login
        # A UI auth session ID in the cookie means the flow was started for
        # user-interactive auth rather than for a plain login.
        if ui_auth_session_id:
            await self._auth_handler.complete_sso_ui_auth(
                user_id, ui_auth_session_id, request
            )
        else:
            await self._auth_handler.complete_sso_login(
                user_id, request, client_redirect_url
            )

    def _generate_oidc_session_token(
        self,
        state: str,
        nonce: str,
        client_redirect_url: str,
        ui_auth_session_id: Optional[str],
        duration_in_ms: int = (60 * 60 * 1000),
    ) -> str:
        """Build a signed macaroon recording the parameters of an OIDC session.

        When an authorization flow starts, a random state and nonce are handed
        to the provider and must be checked when the browser comes back. The
        token also records where to send the client afterwards.

        Caveats are added in this order: ``gen = 1``, ``type = session``,
        ``state``, ``nonce``, ``client_redirect_url``, ``ui_auth_session_id``
        (only when given) and finally an expiry ``time < <ms timestamp>``.

        Args:
            state: The ``state`` parameter passed to the OIDC provider.
            nonce: The ``nonce`` parameter passed to the OIDC provider.
            client_redirect_url: The URL the client gave when it initiated the
                flow.
            ui_auth_session_id: The session ID of the ongoing UI Auth (or
                None if this is a login).
            duration_in_ms: How long the token stays valid, in milliseconds.
                Defaults to an hour.

        Returns:
            The serialized macaroon.
        """
        macaroon = pymacaroons.Macaroon(
            location=self._server_name, identifier="key", key=self._macaroon_secret_key,
        )

        caveats = [
            "gen = 1",
            "type = session",
            "state = %s" % (state,),
            "nonce = %s" % (nonce,),
            "client_redirect_url = %s" % (client_redirect_url,),
        ]
        if ui_auth_session_id:
            caveats.append("ui_auth_session_id = %s" % (ui_auth_session_id,))

        expires_at = self._clock.time_msec() + duration_in_ms
        caveats.append("time < %d" % (expires_at,))

        for caveat in caveats:
            macaroon.add_first_party_caveat(caveat)

        return macaroon.serialize()

    def _verify_oidc_session_token(
        self, session: bytes, state: str
    ) -> Tuple[str, str, Optional[str]]:
        """Check that an OIDC session token is genuine and unpack its data.

        The token must verify against this homeserver's macaroon secret key. It
        must carry ``gen = 1``, ``type = session`` and a ``state`` caveat equal
        to the state returned by the provider. Any ``time <`` caveat it carries
        must not have passed yet.

        Args:
            session: the serialized session token.
            state: the ``state`` value the OIDC provider sent back.

        Returns:
            A ``(nonce, client_redirect_url, ui_auth_session_id)`` tuple. The
            last element is None when the token has no UI Auth session ID.

        Raises:
            Whatever ``pymacaroons.Verifier.verify`` raises when the token fails
            verification, and ValueError (from ``_get_value_from_macaroon``) if
            the nonce or client_redirect_url caveat is missing.
        """
        token = pymacaroons.Macaroon.deserialize(session)

        verifier = pymacaroons.Verifier()
        for exact_caveat in ("gen = 1", "type = session", "state = %s" % (state,)):
            verifier.satisfy_exact(exact_caveat)

        # Caveats whose values are read back below accept any value. The UI
        # Auth session ID is only present for UI Auth flows; satisfying it
        # unconditionally is fine because satisfiers only apply to caveats
        # that are actually in the token.
        for prefix in ("nonce = ", "client_redirect_url = ", "ui_auth_session_id = "):
            verifier.satisfy_general(lambda caveat, p=prefix: caveat.startswith(p))
        verifier.satisfy_general(self._verify_expiry)

        verifier.verify(token, self._macaroon_secret_key)

        nonce = self._get_value_from_macaroon(token, "nonce")
        client_redirect_url = self._get_value_from_macaroon(
            token, "client_redirect_url"
        )

        ui_auth_session_id = None  # type: Optional[str]
        try:
            ui_auth_session_id = self._get_value_from_macaroon(
                token, "ui_auth_session_id"
            )
        except ValueError:
            # No UI Auth session ID: this token belongs to a plain login.
            pass

        return nonce, client_redirect_url, ui_auth_session_id

    def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
        """Return the value of the first ``<key> = <value>`` caveat in a macaroon.

        Only the first matching caveat is used, so caveats appended later can
        never override one written earlier.

        Args:
            macaroon: the token to read.
            key: the caveat name to look for.

        Returns:
            The text following ``"<key> = "`` in the first matching caveat.

        Raises:
            ValueError: if no caveat with that key is present.
        """
        marker = key + " = "
        values = (
            caveat.caveat_id[len(marker) :]
            for caveat in macaroon.caveats
            if caveat.caveat_id.startswith(marker)
        )
        found = next(values, None)
        if found is None:
            raise ValueError("No %s caveat in macaroon" % (key,))
        return found

    def _verify_expiry(self, caveat: str) -> bool:
        """Satisfier for ``time < <ms>`` caveats.

        Returns True only for a ``time <`` caveat whose timestamp (milliseconds)
        is still in the future according to ``self._clock``. Any other caveat
        is not satisfied here.
        """
        marker = "time < "
        if caveat.startswith(marker):
            deadline = int(caveat[len(marker) :])
            return self._clock.time_msec() < deadline
        return False

    async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str:
        """Maps a UserInfo object to a mxid.

        UserInfo should have a claim that uniquely identifies users. This claim
        is usually `sub`, but can be configured with `oidc_config.subject_claim`.
        It is then used as an `external_id`.

        If we don't find the user that way, we should register the user,
        mapping the localpart and the display name from the UserInfo.

        If a user already exists with the mxid we've mapped, raise an exception.

        Args:
            userinfo: an object representing the user
            token: a dict with the tokens obtained from the provider

        Raises:
            MappingException: if there was an error while mapping some properties

        Returns:
            The mxid of the user
        """
        try:
            # Some OIDC providers use integer IDs, but Synapse expects external
            # IDs to be strings: normalise before using it as a lookup key.
            remote_user_id = str(
                self._user_mapping_provider.get_remote_user_id(userinfo)
            )
        except Exception as e:
            raise MappingException(
                "Failed to extract subject from OIDC response: %s" % (e,)
            ) from e

        logger.info(
            "Looking for existing mapping for user %s:%s",
            self._auth_provider_id,
            remote_user_id,
        )

        registered_user_id = await self._datastore.get_user_by_external_id(
            self._auth_provider_id, remote_user_id,
        )

        if registered_user_id is not None:
            logger.info("Found existing mapping %s", registered_user_id)
            return registered_user_id

        # No existing mapping: ask the mapping provider for the attributes of
        # the new account.
        try:
            attributes = await self._user_mapping_provider.map_user_attributes(
                userinfo, token
            )
        except Exception as e:
            raise MappingException(
                "Could not extract user attributes from OIDC response: " + str(e)
            ) from e

        logger.debug(
            "Retrieved user attributes from user mapping provider: %r", attributes
        )

        if not attributes["localpart"]:
            raise MappingException("localpart is empty")

        localpart = map_username_to_mxid_localpart(attributes["localpart"])

        user_id = UserID(localpart, self._hostname)
        if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()):
            # This mxid is taken (compared case-insensitively)
            raise MappingException(
                "mxid '{}' is already taken".format(user_id.to_string())
            )

        # It's the first time this user is logging in and the mapped mxid was
        # not taken, register the user
        registered_user_id = await self._registration_handler.register_user(
            localpart=localpart, default_display_name=attributes["display_name"],
        )

        # Record the link so the next login finds the user via the lookup above.
        await self._datastore.record_user_external_id(
            self._auth_provider_id, remote_user_id, registered_user_id,
        )
        return registered_user_id
예제 #8
0
class OidcProvider:
    """Wraps the config for a single OIDC IdentityProvider

    Provides methods for handling redirect requests and callbacks via that particular
    IdP.
    """
    def __init__(
        self,
        hs: "HomeServer",
        token_generator: "OidcSessionTokenGenerator",
        provider: OidcProviderConfig,
    ):
        self._store = hs.get_datastore()

        self._token_generator = token_generator

        self._callback_url = hs.config.oidc_callback_url  # type: str

        self._scopes = provider.scopes
        self._user_profile_method = provider.user_profile_method
        self._client_auth = ClientAuth(
            provider.client_id,
            provider.client_secret,
            provider.client_auth_method,
        )  # type: ClientAuth
        self._client_auth_method = provider.client_auth_method
        self._provider_metadata = OpenIDProviderMetadata(
            issuer=provider.issuer,
            authorization_endpoint=provider.authorization_endpoint,
            token_endpoint=provider.token_endpoint,
            userinfo_endpoint=provider.userinfo_endpoint,
            jwks_uri=provider.jwks_uri,
        )  # type: OpenIDProviderMetadata
        self._provider_needs_discovery = provider.discover
        self._user_mapping_provider = provider.user_mapping_provider_class(
            provider.user_mapping_provider_config)
        self._skip_verification = provider.skip_verification
        self._allow_existing_users = provider.allow_existing_users

        self._http_client = hs.get_proxied_http_client()
        self._server_name = hs.config.server_name  # type: str

        # identifier for the external_ids table
        self.idp_id = provider.idp_id

        # user-facing name of this auth provider
        self.idp_name = provider.idp_name

        # MXC URI for icon for this auth provider
        self.idp_icon = provider.idp_icon

        # optional brand identifier for this auth provider
        self.idp_brand = provider.idp_brand

        self._sso_handler = hs.get_sso_handler()

        self._sso_handler.register_identity_provider(self)

    def _validate_metadata(self):
        """Verifies the provider metadata.

        This checks the validity of the currently loaded provider. Not
        everything is checked, only:

          - ``issuer``
          - ``authorization_endpoint``
          - ``token_endpoint``
          - ``response_types_supported`` (checks if "code" is in it)
          - ``jwks_uri``

        Raises:
            ValueError: if something in the provider is not valid
        """
        # Skip verification to allow non-compliant providers (e.g. issuers not running on a secure origin)
        if self._skip_verification is True:
            return

        m = self._provider_metadata
        m.validate_issuer()
        m.validate_authorization_endpoint()
        m.validate_token_endpoint()

        if m.get("token_endpoint_auth_methods_supported") is not None:
            m.validate_token_endpoint_auth_methods_supported()
            if (self._client_auth_method
                    not in m["token_endpoint_auth_methods_supported"]):
                raise ValueError(
                    '"{auth_method}" not in "token_endpoint_auth_methods_supported" ({supported!r})'
                    .format(
                        auth_method=self._client_auth_method,
                        supported=m["token_endpoint_auth_methods_supported"],
                    ))

        if m.get("response_types_supported") is not None:
            m.validate_response_types_supported()

            if "code" not in m["response_types_supported"]:
                raise ValueError(
                    '"code" not in "response_types_supported" (%r)' %
                    (m["response_types_supported"], ))

        # Ensure there's a userinfo endpoint to fetch from if it is required.
        if self._uses_userinfo:
            if m.get("userinfo_endpoint") is None:
                raise ValueError(
                    'provider has no "userinfo_endpoint", even though it is required'
                )
        else:
            # If we're not using userinfo, we need a valid jwks to validate the ID token
            if m.get("jwks") is None:
                if m.get("jwks_uri") is not None:
                    m.validate_jwks_uri()
                else:
                    raise ValueError('"jwks_uri" must be set')

    @property
    def _uses_userinfo(self) -> bool:
        """Returns True if the ``userinfo_endpoint`` should be used.

        This is based on the requested scopes: if the scopes include
        ``openid``, the provider should give use an ID token containing the
        user information. If not, we should fetch them using the
        ``access_token`` with the ``userinfo_endpoint``.
        """

        return ("openid" not in self._scopes
                or self._user_profile_method == "userinfo_endpoint")

    async def load_metadata(self) -> OpenIDProviderMetadata:
        """Load and validate the provider metadata.

        The values metadatas are discovered if ``oidc_config.discovery`` is
        ``True`` and then cached.

        Raises:
            ValueError: if something in the provider is not valid

        Returns:
            The provider's metadata.
        """
        # If we are using the OpenID Discovery documents, it needs to be loaded once
        # FIXME: should there be a lock here?
        if self._provider_needs_discovery:
            url = get_well_known_url(self._provider_metadata["issuer"],
                                     external=True)
            metadata_response = await self._http_client.get_json(url)
            # TODO: maybe update the other way around to let user override some values?
            self._provider_metadata.update(metadata_response)
            self._provider_needs_discovery = False

        self._validate_metadata()

        return self._provider_metadata

    async def load_jwks(self, force: bool = False) -> JWKS:
        """Load the JSON Web Key Set used to sign ID tokens.

        If we're not using the ``userinfo_endpoint``, user infos are extracted
        from the ID token, which is a JWT signed by keys given by the provider.
        The keys are then cached.

        Args:
            force: Force reloading the keys.

        Returns:
            The key set

            Looks like this::

                {
                    'keys': [
                        {
                            'kid': 'abcdef',
                            'kty': 'RSA',
                            'alg': 'RS256',
                            'use': 'sig',
                            'e': 'XXXX',
                            'n': 'XXXX',
                        }
                    ]
                }
        """
        if self._uses_userinfo:
            # We're not using jwt signing, return an empty jwk set
            return {"keys": []}

        # First check if the JWKS are loaded in the provider metadata.
        # It can happen either if the provider gives its JWKS in the discovery
        # document directly or if it was already loaded once.
        metadata = await self.load_metadata()
        jwk_set = metadata.get("jwks")
        if jwk_set is not None and not force:
            return jwk_set

        # Loading the JWKS using the `jwks_uri` metadata
        uri = metadata.get("jwks_uri")
        if not uri:
            raise RuntimeError('Missing "jwks_uri" in metadata')

        jwk_set = await self._http_client.get_json(uri)

        # Caching the JWKS in the provider's metadata
        self._provider_metadata["jwks"] = jwk_set
        return jwk_set

    async def _exchange_code(self, code: str) -> Token:
        """Exchange an authorization code for a token.

        This calls the ``token_endpoint`` with the authorization code we
        received in the callback to exchange it for a token. The call uses the
        ``ClientAuth`` to authenticate with the client with its ID and secret.

        See:
           https://tools.ietf.org/html/rfc6749#section-3.2
           https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint

        Args:
            code: The authorization code we got from the callback.

        Returns:
            A dict containing various tokens.

            May look like this::

                {
                    'token_type': 'bearer',
                    'access_token': 'abcdef',
                    'expires_in': 3599,
                    'id_token': 'ghijkl',
                    'refresh_token': 'mnopqr',
                }

        Raises:
            OidcError: when the ``token_endpoint`` returned an error.
        """
        metadata = await self.load_metadata()
        token_endpoint = metadata.get("token_endpoint")
        headers = {
            "Content-Type": "application/x-www-form-urlencoded",
            "User-Agent": self._http_client.user_agent,
            "Accept": "application/json",
        }

        args = {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": self._callback_url,
        }
        body = urlencode(args, True)

        # Fill the body/headers with credentials
        uri, headers, body = self._client_auth.prepare(method="POST",
                                                       uri=token_endpoint,
                                                       headers=headers,
                                                       body=body)
        headers = {k: [v] for (k, v) in headers.items()}

        # Do the actual request
        # We're not using the SimpleHttpClient util methods as we don't want to
        # check the HTTP status code and we do the body encoding ourself.
        response = await self._http_client.request(
            method="POST",
            uri=uri,
            data=body.encode("utf-8"),
            headers=headers,
        )

        # This is used in multiple error messages below
        status = "{code} {phrase}".format(
            code=response.code, phrase=response.phrase.decode("utf-8"))

        resp_body = await make_deferred_yieldable(readBody(response))

        if response.code >= 500:
            # In case of a server error, we should first try to decode the body
            # and check for an error field. If not, we respond with a generic
            # error message.
            try:
                resp = json_decoder.decode(resp_body.decode("utf-8"))
                error = resp["error"]
                description = resp.get("error_description", error)
            except (ValueError, KeyError):
                # Catch ValueError for the JSON decoding and KeyError for the "error" field
                error = "server_error"
                description = ((
                    'Authorization server responded with a "{status}" error '
                    "while exchanging the authorization code.").format(
                        status=status), )

            raise OidcError(error, description)

        # Since it is a not a 5xx code, body should be a valid JSON. It will
        # raise if not.
        resp = json_decoder.decode(resp_body.decode("utf-8"))

        if "error" in resp:
            error = resp["error"]
            # In case the authorization server responded with an error field,
            # it should be a 4xx code. If not, warn about it but don't do
            # anything special and report the original error message.
            if response.code < 400:
                logger.debug("Invalid response from the authorization server: "
                             'responded with a "{status}" '
                             "but body has an error field: {error!r}".format(
                                 status=status, error=resp["error"]))

            description = resp.get("error_description", error)
            raise OidcError(error, description)

        # Now, this should not be an error. According to RFC6749 sec 5.1, it
        # should be a 200 code. We're a bit more flexible than that, and will
        # only throw on a 4xx code.
        if response.code >= 400:
            description = (
                'Authorization server responded with a "{status}" error '
                'but did not include an "error" field in its response.'.format(
                    status=status))
            logger.warning(description)
            # Body was still valid JSON. Might be useful to log it for debugging.
            logger.warning(
                "Code exchange response: {resp!r}".format(resp=resp))
            raise OidcError("server_error", description)

        return resp

    async def _fetch_userinfo(self, token: Token) -> UserInfo:
        """Fetch user information from the ``userinfo_endpoint``.

        Args:
            token: the token given by the ``token_endpoint``.
                Must include an ``access_token`` field.

        Returns:
            UserInfo: an object representing the user.
        """
        metadata = await self.load_metadata()

        resp = await self._http_client.get_json(
            metadata["userinfo_endpoint"],
            headers={
                "Authorization": ["Bearer {}".format(token["access_token"])]
            },
        )

        return UserInfo(resp)

    async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
        """Return an instance of UserInfo from token's ``id_token``.

        Args:
            token: the token given by the ``token_endpoint``.
                Must include an ``id_token`` field.
            nonce: the nonce value originally sent in the initial authorization
                request. This value should match the one inside the token.

        Returns:
            An object representing the user.
        """
        metadata = await self.load_metadata()
        claims_params = {
            "nonce": nonce,
            "client_id": self._client_auth.client_id,
        }
        if "access_token" in token:
            # If we got an `access_token`, there should be an `at_hash` claim
            # in the `id_token` that we can check against.
            claims_params["access_token"] = token["access_token"]
            claims_cls = CodeIDToken
        else:
            claims_cls = ImplicitIDToken

        alg_values = metadata.get("id_token_signing_alg_values_supported",
                                  ["RS256"])

        jwt = JsonWebToken(alg_values)

        claim_options = {"iss": {"values": [metadata["issuer"]]}}

        # Try to decode the keys in cache first, then retry by forcing the keys
        # to be reloaded
        jwk_set = await self.load_jwks()
        try:
            claims = jwt.decode(
                token["id_token"],
                key=jwk_set,
                claims_cls=claims_cls,
                claims_options=claim_options,
                claims_params=claims_params,
            )
        except ValueError:
            logger.info("Reloading JWKS after decode error")
            jwk_set = await self.load_jwks(force=True
                                           )  # try reloading the jwks
            claims = jwt.decode(
                token["id_token"],
                key=jwk_set,
                claims_cls=claims_cls,
                claims_options=claim_options,
                claims_params=claims_params,
            )

        claims.validate(leeway=120)  # allows 2 min of clock skew
        return UserInfo(claims)

    async def handle_redirect_request(
        self,
        request: SynapseRequest,
        client_redirect_url: Optional[bytes],
        ui_auth_session_id: Optional[str] = None,
    ) -> str:
        """Handle an incoming request to /login/sso/redirect

        It returns a redirect to the authorization endpoint with a few
        parameters:

          - ``client_id``: the client ID set in ``oidc_config.client_id``
          - ``response_type``: ``code``
          - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/client/oidc/callback``
          - ``scope``: the list of scopes set in ``oidc_config.scopes``
          - ``state``: a random string
          - ``nonce``: a random string

        In addition generating a redirect URL, we are setting a cookie with
        a signed macaroon token containing the state, the nonce and the
        client_redirect_url params. Those are then checked when the client
        comes back from the provider.

        Args:
            request: the incoming request from the browser.
                We'll respond to it with a redirect and a cookie.
            client_redirect_url: the URL that we should redirect the client to
                when everything is done (or None for UI Auth)
            ui_auth_session_id: The session ID of the ongoing UI Auth (or
                None if this is a login).

        Returns:
            The redirect URL to the authorization endpoint.

        """

        state = generate_token()
        nonce = generate_token()

        if not client_redirect_url:
            client_redirect_url = b""

        cookie = self._token_generator.generate_oidc_session_token(
            state=state,
            session_data=OidcSessionData(
                idp_id=self.idp_id,
                nonce=nonce,
                client_redirect_url=client_redirect_url.decode(),
                ui_auth_session_id=ui_auth_session_id,
            ),
        )
        request.addCookie(
            SESSION_COOKIE_NAME,
            cookie,
            path="/_synapse/client/oidc",
            max_age="3600",
            httpOnly=True,
            sameSite="lax",
        )

        metadata = await self.load_metadata()
        authorization_endpoint = metadata.get("authorization_endpoint")
        return prepare_grant_uri(
            authorization_endpoint,
            client_id=self._client_auth.client_id,
            response_type="code",
            redirect_uri=self._callback_url,
            scope=self._scopes,
            state=state,
            nonce=nonce,
        )

    async def handle_oidc_callback(self, request: SynapseRequest,
                                   session_data: "OidcSessionData",
                                   code: str) -> None:
        """Handle an incoming request to /_synapse/client/oidc/callback

        By this time we have already validated the session on the synapse side, and
        now need to do the provider-specific operations. This includes:

          - exchange the code with the provider using the ``token_endpoint`` (see
            ``_exchange_code``)
          - once we have the token, use it to either extract the UserInfo from
            the ``id_token`` (``_parse_id_token``), or use the ``access_token``
            to fetch UserInfo from the ``userinfo_endpoint``
            (``_fetch_userinfo``)
          - map those UserInfo to a Matrix user (``_map_userinfo_to_user``) and
            finish the login

        Args:
            request: the incoming request from the browser.
            session_data: the session data, extracted from our cookie
            code: The authorization code we got from the callback.
        """
        # Exchange the code with the provider
        try:
            logger.debug("Exchanging code")
            token = await self._exchange_code(code)
        except OidcError as e:
            logger.exception("Could not exchange code")
            self._sso_handler.render_error(request, e.error,
                                           e.error_description)
            return

        logger.debug("Successfully obtained OAuth2 access token")

        # Now that we have a token, get the userinfo, either by decoding the
        # `id_token` or by fetching the `userinfo_endpoint`.
        if self._uses_userinfo:
            logger.debug("Fetching userinfo")
            try:
                userinfo = await self._fetch_userinfo(token)
            except Exception as e:
                logger.exception("Could not fetch userinfo")
                self._sso_handler.render_error(request, "fetch_error", str(e))
                return
        else:
            logger.debug("Extracting userinfo from id_token")
            try:
                userinfo = await self._parse_id_token(token,
                                                      nonce=session_data.nonce)
            except Exception as e:
                logger.exception("Invalid id_token")
                self._sso_handler.render_error(request, "invalid_token",
                                               str(e))
                return

        # first check if we're doing a UIA
        if session_data.ui_auth_session_id:
            try:
                remote_user_id = self._remote_id_from_userinfo(userinfo)
            except Exception as e:
                logger.exception("Could not extract remote user id")
                self._sso_handler.render_error(request, "mapping_error",
                                               str(e))
                return

            return await self._sso_handler.complete_sso_ui_auth_request(
                self.idp_id, remote_user_id, session_data.ui_auth_session_id,
                request)

        # otherwise, it's a login

        # Call the mapper to register/login the user
        try:
            await self._complete_oidc_login(userinfo, token, request,
                                            session_data.client_redirect_url)
        except MappingException as e:
            logger.exception("Could not map user")
            self._sso_handler.render_error(request, "mapping_error", str(e))

    async def _complete_oidc_login(
        self,
        userinfo: UserInfo,
        token: Token,
        request: SynapseRequest,
        client_redirect_url: str,
    ) -> None:
        """Given a UserInfo response, complete the login flow

        UserInfo should have a claim that uniquely identifies users. This claim
        is usually `sub`, but can be configured with `oidc_config.subject_claim`.
        It is then used as an `external_id`.

        If we don't find the user that way, we should register the user,
        mapping the localpart and the display name from the UserInfo.

        If a user already exists with the mxid we've mapped and allow_existing_users
        is disabled, raise an exception.

        Otherwise, render a redirect back to the client_redirect_url with a loginToken.

        Args:
            userinfo: an object representing the user
            token: a dict with the tokens obtained from the provider
            request: The request to respond to
            client_redirect_url: The redirect URL passed in by the client.

        Raises:
            MappingException: if there was an error while mapping some properties
        """
        try:
            remote_user_id = self._remote_id_from_userinfo(userinfo)
        except Exception as e:
            raise MappingException(
                "Failed to extract subject from OIDC response: %s" % (e, ))

        # Older mapping providers don't accept the `failures` argument, so we
        # try and detect support.
        mapper_signature = inspect.signature(
            self._user_mapping_provider.map_user_attributes)
        supports_failures = "failures" in mapper_signature.parameters

        async def oidc_response_to_user_attributes(
                failures: int) -> UserAttributes:
            """
            Call the mapping provider to map the OIDC userinfo and token to user attributes.

            This is backwards compatibility for abstraction for the SSO handler.
            """
            if supports_failures:
                attributes = await self._user_mapping_provider.map_user_attributes(
                    userinfo, token, failures)
            else:
                # If the mapping provider does not support processing failures,
                # do not continually generate the same Matrix ID since it will
                # continue to already be in use. Note that the error raised is
                # arbitrary and will get turned into a MappingException.
                if failures:
                    raise MappingException(
                        "Mapping provider does not support de-duplicating Matrix IDs"
                    )

                attributes = await self._user_mapping_provider.map_user_attributes(  # type: ignore
                    userinfo, token)

            return UserAttributes(**attributes)

        async def grandfather_existing_users() -> Optional[str]:
            if self._allow_existing_users:
                # If allowing existing users we want to generate a single localpart
                # and attempt to match it.
                attributes = await oidc_response_to_user_attributes(failures=0)

                user_id = UserID(attributes.localpart,
                                 self._server_name).to_string()
                users = await self._store.get_users_by_id_case_insensitive(
                    user_id)
                if users:
                    # If an existing matrix ID is returned, then use it.
                    if len(users) == 1:
                        previously_registered_user_id = next(iter(users))
                    elif user_id in users:
                        previously_registered_user_id = user_id
                    else:
                        # Do not attempt to continue generating Matrix IDs.
                        raise MappingException(
                            "Attempted to login as '{}' but it matches more than one user inexactly: {}"
                            .format(user_id, users))

                    return previously_registered_user_id

            return None

        # Mapping providers might not have get_extra_attributes: only call this
        # method if it exists.
        extra_attributes = None
        get_extra_attributes = getattr(self._user_mapping_provider,
                                       "get_extra_attributes", None)
        if get_extra_attributes:
            extra_attributes = await get_extra_attributes(userinfo, token)

        await self._sso_handler.complete_sso_login_request(
            self.idp_id,
            remote_user_id,
            request,
            client_redirect_url,
            oidc_response_to_user_attributes,
            grandfather_existing_users,
            extra_attributes,
        )

    def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
        """Return the OIDC provider's unique identifier for this user.

        The identifier is obtained from the configured user mapping provider
        and always coerced to ``str``: some OIDC providers hand back integer
        ids, whereas Synapse expects external ids to be strings.

        Args:
            userinfo: An object representing the user given by the OIDC provider
        Returns:
            remote user id, as a string
        """
        mapper = self._user_mapping_provider
        return str(mapper.get_remote_user_id(userinfo))
예제 #9
0
파일: api.py 프로젝트: fantix/authlib-gino
# Path (appended to config.JWT_ISSUER below) of the OIDC UserInfo endpoint.
USERINFO_ENDPOINT = "/userinfo"
# Path (appended to config.JWT_ISSUER below) where the JWK set is published.
JWKS_URI = "/.well-known/jwks.json"

# Bearer-token security scheme for OIDC-protected routes.
# NOTE(review): auto_error=False presumably makes the scheme yield nothing
# instead of raising when no token is sent, leaving the unauthenticated case
# to the caller — confirm against JWTBearer.
oidc_scheme = JWTBearer(
    AUTHORIZATION_ENDPOINT,
    TOKEN_ENDPOINT,
    scheme_name="oidc",
    scopes=SCOPES,
    auto_error=False,
)
# OpenID Provider discovery metadata. Every endpoint URL is the configured
# issuer (config.JWT_ISSUER) followed by the corresponding path constant.
metadata = OpenIDProviderMetadata(
    issuer=config.JWT_ISSUER,
    authorization_endpoint=config.JWT_ISSUER + AUTHORIZATION_ENDPOINT,
    token_endpoint=config.JWT_ISSUER + TOKEN_ENDPOINT,
    userinfo_endpoint=config.JWT_ISSUER + USERINFO_ENDPOINT,
    jwks_uri=config.JWT_ISSUER + JWKS_URI,
    registration_endpoint=None,
    scopes_supported=list(SCOPES.keys()),
    response_types_supported=["code"],
    subject_types_supported=["public"],
    id_token_signing_alg_values_supported=[config.JWT_ALGORITHM],
)
auth = AuthorizationServer(Client.get, save_token, metadata, ClientAuthentication)
# Publish the server's auth_methods as the supported token-endpoint
# client authentication methods.
metadata["token_endpoint_auth_methods_supported"] = auth.auth_methods
if config.DEBUG:
    # Debug-only: validate the discovery metadata once, at import time.
    from unittest.mock import patch

    # The metadata validators reject plain "http" URLs (e.g. jwks_uri), while a
    # debug issuer is often http. So validate a temporary copy in which these
    # URL fields are rewritten to https. patch.dict restores the original
    # values when the with-block exits, so the served metadata is unchanged.
    # NOTE(review): userinfo_endpoint is not rewritten here — presumably its
    # scheme is not checked by validate(); confirm if validation starts failing.
    mock = {}
    for key in ("issuer", "authorization_endpoint", "token_endpoint", "jwks_uri"):
        mock[key] = metadata[key].replace("http://", "https://")
    with patch.dict(metadata, mock):
        metadata.validate()
예제 #10
0
    def test_validate_jwks_uri(self):
        """jwks_uri is required, must use https, and accepts an https URL."""
        # required
        metadata = OpenIDProviderMetadata()
        with self.assertRaises(ValueError) as cm:
            metadata.validate_jwks_uri()
        # Check the message after leaving the context manager: inside it, any
        # statement following the raising call is never executed.
        self.assertEqual('"jwks_uri" is required', str(cm.exception))

        # must use the https scheme
        metadata = OpenIDProviderMetadata(
            {'jwks_uri': 'http://authlib.org/jwks.json'})
        with self.assertRaises(ValueError) as cm:
            metadata.validate_jwks_uri()
        self.assertIn('https', str(cm.exception))

        # a well-formed https URL passes validation
        metadata = OpenIDProviderMetadata(
            {'jwks_uri': 'https://authlib.org/jwks.json'})
        metadata.validate_jwks_uri()
예제 #11
0
 def _call_contains_invalid_value(self, key, invalid_value):
     """Assert that ``validate_<key>`` rejects ``invalid_value``.

     Builds metadata holding only ``{key: invalid_value}``, calls the
     matching ``validate_<key>`` method, and checks that it raises
     ValueError with the message '"<key>" contains invalid values'.
     """
     metadata = OpenIDProviderMetadata({key: invalid_value})
     with self.assertRaises(ValueError) as cm:
         getattr(metadata, 'validate_' + key)()
     # Check the message outside the context manager; inside it, this line
     # would be skipped once the validator raises.
     self.assertEqual('"{}" contains invalid values'.format(key),
                      str(cm.exception))
예제 #12
0
 def test_validate_claim_types_supported(self):
     self._call_validate_array('claim_types_supported', ['normal'])
     self._call_contains_invalid_value('claim_types_supported', ['invalid'])
     metadata = OpenIDProviderMetadata()
     self.assertEqual(metadata.claim_types_supported, ['normal'])