Example #1
0
    def start_authz(self, arguments):
        """Validate an OpenID Connect authorization request and dispatch it.

        Known parameters are collected from ``arguments``.  When a ``request``
        or ``request_uri`` request object is supplied, its JWT claims are
        merged in, after their signature is checked if the client registered a
        signing algorithm.  The merged request is validated against the
        registered client and ``self.cfg.supported_scopes``.  The requested
        userinfo claims are then derived from the scopes and the enabled
        extensions, and the request is stored in the transaction before
        redirecting.

        :param arguments: mapping of request parameter names to values
            (presumably the query arguments of the authorization request).
        :returns: the result of ``self._respond_error`` when the request is
            rejected, or the result of ``self._authz_stack_check`` when that
            check fails.
        :raises InvalidRequest: when the redirect_uri is not registered for
            the client.
        :raises cherrypy.HTTPRedirect: to the login page when the user must
            (re-)authenticate, otherwise to the Continue endpoint.
        """
        request_data = {
            'scope': [],
            'response_type': [],
            'client_id': None,
            'redirect_uri': None,
            'state': None,
            'response_mode': None,
            'nonce': None,
            'display': None,
            'prompt': [],
            'max_age': None,
            'ui_locales': None,
            'id_token_hint': None,
            'login_hint': None,
            'acr_values': None,
            'claims': '{}'
        }

        # Step 1: take every known parameter that is present (and non-empty)
        # in the query arguments.
        for data in request_data.keys():
            if arguments.get(data, None):
                request_data[data] = arguments[data]

        # 'prompt' defaults to [] above but is split as a string further
        # down; map "not provided" to None so the split step yields [].
        if request_data['prompt'] == []:
            request_data['prompt'] = None

        for required_arg in ['scope', 'response_type', 'client_id']:
            if request_data[required_arg] is None or \
                    len(request_data[required_arg]) == 0:
                return self._respond_error(
                    request_data, 'invalid_request',
                    'missing required argument %s' % required_arg)

        client = self.cfg.datastore.getClient(request_data['client_id'])
        if not client:
            return self._respond_error(request_data, 'unauthorized_client',
                                       'Unknown client ID')

        request_data['response_type'] = request_data.get('response_type',
                                                         '').split(' ')
        for rtype in request_data['response_type']:
            if rtype not in ['id_token', 'token', 'code']:
                return self._respond_error(
                    request_data, 'unsupported_response_type',
                    'response type %s is not supported' % rtype)

        # Every flow other than the pure code flow requires a nonce.
        if request_data['response_type'] != ['code'] and \
                not request_data['nonce']:
            return self._respond_error(request_data, 'invalid_request',
                                       'nonce missing in non-code flow')

        # Step 2: get any provided request or request_uri
        if 'request' in arguments or 'request_uri' in arguments:
            # This is a JWT-encoded request
            if 'request' in arguments and 'request_uri' in arguments:
                return self._respond_error(
                    request_data, 'invalid_request',
                    'both request and request_uri ' + 'provided')

            if 'request' in arguments:
                jwt_object = arguments['request']
            else:
                try:
                    # FIXME: MAY cache this at client registration time and
                    # cache permanently until client registration is changed.
                    # NOTE(review): this URL comes straight from the request
                    # and is fetched with no timeout or allow-list; confirm
                    # that is acceptable.
                    jwt_object = requests.get(arguments['request_uri']).text
                except Exception as ex:  # pylint: disable=broad-except
                    self.debug('Unable to get request: %s' % ex)
                    return self._respond_error(request_data, 'invalid_request',
                                               'unable to parse request_uri')

            jwt_request = None
            try:
                # FIXME: Implement decryption
                decoded = JWT(jwt=jwt_object)
                if client['request_object_signing_alg'] != 'none':
                    # Client told us we need to check signature
                    if decoded.token.jose_header['alg'] != \
                            client['request_object_signing_alg']:
                        raise Exception('Invalid algorithm used: %s' %
                                        decoded.token.jose_header['alg'])

                if client['request_object_signing_alg'] == 'none':
                    jwt_request = json.loads(decoded.token.objects['payload'])
                else:
                    # Verify against the client's registered keys, either
                    # stored inline ('jwks') or fetched from 'jwks_uri'.
                    if client['jwks']:
                        keys = json.loads(client['jwks'])
                    else:
                        keys = requests.get(client['jwks_uri']).json()
                    keyset = JWKSet()
                    for key in keys['keys']:
                        keyset.add(JWK(**key))
                    key = keyset.get_key(decoded.token.jose_header['kid'])
                    decoded = JWT(jwt=jwt_object, key=key)
                    jwt_request = json.loads(decoded.claims)

            except Exception as ex:  # pylint: disable=broad-except
                self.debug('Unable to parse request: %s' % ex)
                return self._respond_error(request_data, 'invalid_request',
                                           'unable to parse request')

            # The request object must not contradict the plain parameters.
            if 'response_type' in jwt_request:
                jwt_request['response_type'] = \
                    jwt_request['response_type'].split(' ')
                if jwt_request['response_type'] != \
                        request_data['response_type']:
                    return self._respond_error(request_data, 'invalid_request',
                                               'response_type does not match')

            if 'client_id' in jwt_request:
                if jwt_request['client_id'] != request_data['client_id']:
                    return self._respond_error(request_data, 'invalid_request',
                                               'client_id does not match')

            # Values from the request object take precedence.
            for data in request_data.keys():
                if data in jwt_request:
                    request_data[data] = jwt_request[data]

        # Split these options since they are space-separated lists
        for to_split in ['prompt', 'ui_locales', 'acr_values', 'scope']:
            if request_data[to_split] is not None:
                # We know better than pylint in this regard
                # pylint: disable=no-member
                request_data[to_split] = request_data[to_split].split(' ')
            else:
                request_data[to_split] = []

        # Start checking the request
        if request_data['redirect_uri'] is None:
            # redirect_uri may only be omitted when it is unambiguous.
            if len(client['redirect_uris']) != 1:
                return self._respond_error(request_data, 'invalid_request',
                                           'missing redirect_uri')
            else:
                request_data['redirect_uri'] = client['redirect_uris'][0]

        for scope in request_data['scope']:
            if scope not in self.cfg.supported_scopes:
                return self._respond_error(
                    request_data, 'invalid_scope',
                    'unknown scope %s requested' % scope)

        for response_type in request_data['response_type']:
            if response_type not in ['code', 'id_token', 'token']:
                return self._respond_error(
                    request_data, 'unsupported_response_type',
                    'response_type %s is unknown' % response_type)

        # Raised rather than reported via _respond_error: the redirect_uri
        # itself is not trusted at this point.
        if request_data['redirect_uri'] not in client['redirect_uris']:
            raise InvalidRequest('Invalid redirect_uri')

        # Build the "claims" values from scopes.  A query-string "claims" is a
        # JSON string, whereas one merged from a request object has already
        # been decoded, so accept both forms.
        claims = request_data['claims']
        if not isinstance(claims, dict):
            try:
                claims = json.loads(claims)
            except Exception as ex:  # pylint: disable=broad-except
                return self._respond_error(request_data, 'invalid_request',
                                           'claims malformed: %s' % ex)
        if not isinstance(claims, dict):
            return self._respond_error(request_data, 'invalid_request',
                                       'claims malformed: not a JSON object')
        for section in ('userinfo', 'id_token'):
            if claims.get(section) is None:
                claims[section] = {}
            elif not isinstance(claims[section], dict):
                return self._respond_error(
                    request_data, 'invalid_request',
                    'claims malformed: %s is not a JSON object' % section)
        request_data['claims'] = claims
        userinfo_claims = claims['userinfo']

        scopes_to_claim = {
            'profile': [
                'name', 'family_name', 'given_name', 'middle_name', 'nickname',
                'preferred_username', 'profile', 'picture', 'website',
                'gender', 'birthdate', 'zoneinfo', 'locale', 'updated_at'
            ],
            'email': ['email', 'email_verified'],
            'address': ['address'],
            'phone': ['phone_number', 'phone_number_verified']
        }
        for scope, scope_claims in scopes_to_claim.items():
            if scope in request_data['scope']:
                for claim in scope_claims:
                    # Keep any explicit per-claim request made by the client.
                    userinfo_claims.setdefault(claim, None)

        # Add claims from extensions
        for n, e in self.cfg.extensions.available().items():
            data = e.get_claims(request_data['scope'])
            self.debug('%s returned %s' % (n, repr(data)))
            for claim in data:
                userinfo_claims.setdefault(claim, None)

        # Store data so we can continue with the request
        us = UserSession()
        user = us.get_user()

        returl = '%s/%s/Continue?%s' % (self.basepath, URLROOT,
                                        self.trans.get_GET_arg())
        data = {
            'login_target': client.get('client_name', None),
            'login_return': returl,
            'openidc_stage': 'continue',
            'openidc_request': json.dumps(request_data)
        }

        if request_data['login_hint']:
            data['login_username'] = request_data['login_hint']

        if not data['login_target']:
            data['login_target'] = get_url_hostpart(
                request_data['redirect_uri'])

        # Decide what to do with the request
        if request_data['max_age'] is None:
            request_data['max_age'] = client.get('default_max_age', None)

        needs_auth = True
        if not user.is_anonymous:
            # NOTE(review): only the integer 0 is treated as "no limit" here;
            # a query-string '0' falls through to the age comparison below.
            # Confirm that difference is intended.
            if request_data['max_age'] in [None, 0]:
                needs_auth = False
            else:
                auth_time = us.get_user_attrs()['_auth_time']
                needs_auth = ((int(auth_time) + int(request_data['max_age']))
                              <= int(time.time()))

        if needs_auth or 'login' in request_data['prompt']:
            if 'none' in request_data['prompt']:
                # We were asked not to provide a UI. Answer with false.
                return self._respond_error(request_data, 'login_required',
                                           'user interface required')

            # Either the user wasn't logged in, or we were explicitly
            # asked to re-auth them. Let's do so!
            us.logout(user)

            # Let the user go to auth
            self.trans.store(data)
            redirect = '%s/login?%s' % (self.basepath,
                                        self.trans.get_GET_arg())
            self.debug('Redirecting: %s' % redirect)
            raise cherrypy.HTTPRedirect(redirect)

        # Return error if authz check fails
        authz_check_res = self._authz_stack_check(request_data,
                                                  client, user.name,
                                                  us.get_user_attrs())
        if authz_check_res:
            return authz_check_res

        self.trans.store(data)
        # The user was already signed on, and no request to re-assert its
        # identity. Let's forward directly to /Continue/
        self.debug('Redirecting: %s' % returl)
        raise cherrypy.HTTPRedirect(returl)
Example #2
0
class JWTValidator:
    """Verify JWTs against keys fetched from one or more JWKS endpoints.

    Keys are cached in a ``jwcrypto`` ``JWKSet``. They are refreshed on demand
    when a token references an unknown ``kid``, or periodically through
    :meth:`poll`. Use the instance as an async context manager so that the
    HTTP session opened for fetching keys is closed on exit.
    """

    def __init__(self,
                 jwks_urls: "Optional[Union[str, Collection[str]]]" = None):
        """
        :param jwks_urls: a single JWKS URL, a collection of URLs, or ``None``
            when there is no remote key source.
        """
        from jwcrypto.jwk import JWKSet

        self.jwks_urls = jwks_urls
        self.keys = JWKSet()
        self.session = None  # type: Optional[ClientSession]

    async def poll(self, poll_interval: timedelta = DEFAULT_POLL_INTERVAL):
        """
        Periodically check for new keys. This coroutine will NEVER terminate naturally, so it should
        not be awaited.
        """
        while True:
            await sleep(poll_interval.total_seconds())

            LOG.debug("JWKS poller polling for new keys...")
            # shield() keeps a key refresh from being cancelled half-way.
            await shield(self._load_new_keys())
            LOG.debug("...JWKS poll complete.")

    async def decode_claims(self, token: str) -> "JWTClaims":
        """Verify ``token`` against the key store and return its decoded claims.

        The token's ``kid`` header is looked up first, which refreshes the keys
        from the JWKS endpoints if that ``kid`` is unknown. Raises ``KeyError``
        if the token header has no ``kid``. jwcrypto raises if the token cannot
        be verified with the stored keys.
        """
        from jwcrypto.jwt import JWT

        LOG.debug("Verifying token: %r", token)
        # Parse without verification only to read the header's ``kid``.
        kid = JWT(jwt=token).token.jose_header["kid"]
        await self.get_key(kid)

        jwt = JWT(jwt=token, key=self.keys)
        return json.loads(jwt.claims)

    async def get_key(self, kid: str) -> "Optional[JWK]":
        """
        Retrieve a key for a given ``kid``. If the key could not be found, the keystore is
        refreshed, and if a key is discovered through that process, it is returned. May return
        ``None`` if even after refreshing keys from the remote JWKS source, a key for the given
        ``kid`` could not be found.
        """
        key = self.keys.get_key(kid)
        if key is None:
            await shield(self._load_new_keys())
            key = self.keys.get_key(kid)

        return key

    def export_all_keys(self) -> "str":
        """
        Return a JSON string that formats all public keys currently in the store as JWKS.
        """
        return self.keys.export(private_keys=False)

    async def _load_new_keys(self):
        """Fetch every configured JWKS URL and add keys whose ``kid`` is new.

        A failure for one URL is logged and does not stop the remaining URLs
        from being processed. Keys whose ``kid`` is already known are ignored,
        as are keys without a ``kid``.
        """
        from jwcrypto.jwk import JWK

        # ``jwks_urls`` may be a single URL string or None.  Iterating a str
        # directly would yield its individual characters.
        urls = self.jwks_urls
        if urls is None:
            urls = ()
        elif isinstance(urls, str):
            urls = (urls,)

        for url in urls:
            try:
                if self.session is None:
                    self.session = ClientSession()

                async with self.session.get(url,
                                            allow_redirects=False) as response:
                    jwks_json = await response.json()

                # ``JWKSet.import_keyset`` suffers from a few critical flaws that make it
                # unusable for us:
                #   1) ``import_keyset`` internally adds keys to a set, which is semantically
                #      correct. However, because JWK has no __eq__ or __hash__ implementation,
                #      EVERY key is repeatedly appended to the set rather than duplicates
                #      getting filtered out.
                #   2) The previous point necessitates that we pre-process the data to filter out
                #      keys that we do not wish to add, thereby requiring us to parse the JSON
                #      and read the payload. ``import_keyset`` expects its argument as a serialized
                #      JSON string, which it promptly parses back into a data structure.
                #
                # We process JWKS endpoints ourselves and selectively add keys directly to the
                # implementation. We are NOT reimplementing ``import_keyset``'s functionality of
                # carrying additional non-``keys`` fields into the ``JWKSet`` object. We are
                # generating JWKS data ourselves, and always generate data that contains only the
                # single top-level property of ``keys`` so this has no impact on us.
                jwks_keys = jwks_json.get("keys")
                if jwks_keys is None:
                    LOG.warning(
                        'The JWKS endpoint did not return a "keys" property, so no new '
                        "keys were added. This will be retried")
                    continue

                existing_kids = {k.key_id for k in self.keys}
                for jwk_dict in jwks_keys:
                    kid = jwk_dict.get("kid")
                    if kid is None:
                        LOG.warning(
                            'The JWKS endpoint contained a key without a "kid" field. '
                            "It will be dropped.")
                        continue
                    if kid in existing_kids:
                        LOG.debug(
                            "We already know about kid %s, so the new value will be "
                            "ignored.",
                            kid,
                        )
                        continue

                    try:
                        jwk = JWK(**jwk_dict)
                    except Exception:  # noqa
                        LOG.exception(
                            "The JWK identified by %s could not be parsed.", kid)
                        continue

                    try:
                        self.keys.add(jwk)
                    except Exception:  # noqa
                        LOG.exception(
                            "The JWK identified by %s could not be added.", kid)
                        continue

                    # Also skip a kid that is repeated later in this response.
                    existing_kids.add(kid)

            except Exception as ex:  # noqa
                # Do NOT log these with full stack traces because they're actually fairly common,
                # particularly at startup when user-service has yet to start. Merely logging the
                # text of the exception without a scary stack trace is sufficient.
                LOG.warning("Error when checking url %r for new keys: %s", url,
                            ex)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.session is not None:
            await self.session.close()
            # Drop the closed session so a later use opens a fresh one.
            self.session = None