def fetchToken(self, **kwargs):
    """Request a token from the provider's token endpoint.

    Calls ``self.fetch_access_token`` with the ``token_endpoint`` metadata value, then tags
    ``self.token`` with the client ID and the provider name before wrapping a copy of it.

    :param kwargs: keyword arguments forwarded unchanged to ``fetch_access_token``

    :return: S_OK(OAuth2Token)/S_ERROR() -- the error message is ``repr()`` of the caught exception
    """
    try:
        self.fetch_access_token(self.get_metadata("token_endpoint"), **kwargs)
    except Exception as error:
        self.log.exception(error)
        return S_ERROR(repr(error))

    # NOTE(review): presumably ``fetch_access_token`` has populated ``self.token`` -- confirm
    currentToken = self.token
    currentToken["client_id"] = self.client_id
    currentToken["provider"] = self.name
    # Hand back a copy so callers do not share the session's token dictionary
    return S_OK(OAuth2Token(dict(currentToken)))
def getTokensByUserID(self, userID):
    """Return every stored token that belongs to the given user ID

    :param str userID: user ID to match against ``Token.user_id``

    :return: S_OK(list)/S_ERROR() -- tokens as OAuth2Token objects, an empty list if none
    """
    session = self.session()
    try:
        rows = session.query(Token).filter(Token.user_id == userID).all()
    except NoResultFound:
        outcome = S_OK([])
    except Exception as error:
        outcome = S_ERROR(str(error))
    else:
        # Conversion happens outside the guarded query, so its errors propagate to the caller
        outcome = S_OK([OAuth2Token(self.__rowToDict(row)) for row in rows])
    # Every outcome goes through the common session finalizer
    return self.__result(session, outcome)
def getCredentialByRefreshToken(self, tokenID):
    """Fetch the refresh token credential with the given ID and delete it from the table

    The first ``RefreshToken`` row whose ``jti`` equals ``tokenID`` is read, then all rows
    with that ``jti`` are deleted within the same session.

    :param str tokenID: refresh token ID (``jti``)

    :return: S_OK(OAuth2Token or None)/S_ERROR() -- ``None`` if no row matched
    """
    session = self.session()
    try:
        matching = session.query(RefreshToken).filter(RefreshToken.jti == tokenID)
        token = matching.first()
        # The delete is issued whether or not a row was found
        matching.delete()
    except Exception as error:
        return self.__result(session, S_ERROR(str(error)))

    credential = OAuth2Token(self.__rowToDict(token)) if token else None
    return self.__result(session, S_OK(credential))
def getTokenForUserProvider(self, userID, provider):
    """Look up a stored token for a user ID and identity provider name

    Only rows whose ``rt_expires_at`` is later than the current time are considered,
    and the first match is returned.

    :param str userID: user ID
    :param str provider: provider name

    :return: S_OK(OAuth2Token or None)/S_ERROR() -- an OAuth2Token object, or ``None`` if nothing matched
    """
    session = self.session()
    try:
        query = session.query(Token)
        query = query.filter(Token.rt_expires_at > time.time())
        query = query.filter(Token.user_id == userID)
        query = query.filter(Token.provider == provider)
        token = query.first()
    except Exception as error:
        return self.__result(session, S_ERROR(str(error)))

    found = OAuth2Token(self.__rowToDict(token)) if token else None
    return self.__result(session, S_OK(found))
def researchGroup(self, payload=None, token=None):
    """Work out the DIRAC groups described by a token payload

    The payload is parsed with ``parseBasic``; if that yields no ``DIRACGroups`` the result
    of ``parseEduperson`` is merged in. When groups are found, the first one is stored
    under the ``group`` key.

    :param payload: token payload; read from the token when not given
    :param token: access token; defaults to ``self.token``

    :return: S_OK(dict) -- credential dictionary
    """
    token = token or self.token
    if not payload:
        if token:
            payload = OAuth2Token(token).get_payload()

    credDict = self.parseBasic(payload)
    if not credDict.get("DIRACGroups"):
        credDict.update(self.parseEduperson(payload))

    groups = credDict.get("DIRACGroups")
    if groups:
        self.log.debug("Found next groups:", ", ".join(groups))
        credDict["group"] = groups[0]
    return S_OK(credDict)
class OAuth2IdProvider(IdProvider, OAuth2Session):
    """Base class to describe the configuration of the OAuth2 client of the corresponding provider."""

    # Minimum number of seconds between two downloads of the JWKs / server metadata
    JWKS_REFRESH_RATE = 24 * 3600
    METADATA_REFRESH_RATE = 24 * 3600

    def __init__(self, **kwargs):
        """Initialization

        :param kwargs: passed to both parent constructors; the optional keys ``jwks``, ``verify``,
            ``token_placement`` and ``server_metadata_url`` are also read here
        """
        IdProvider.__init__(self, **kwargs)
        OAuth2Session.__init__(self, **kwargs)
        self.metadata_fetch_last = 0
        self.issuer = self.metadata["issuer"]
        self.scope = self.scope or ""
        self.jwks = kwargs.get("jwks")
        self.verify = kwargs.get("verify", True)  # Decide if need to check CAs
        self.token_placement = kwargs.get("token_placement", "header")
        self.code_challenge_method = "S256"
        # self.token_endpoint_auth_method = kwargs.get('token_endpoint_auth_method') #, 'client_secret_post')
        self.server_metadata_url = kwargs.get("server_metadata_url", get_well_known_url(self.metadata["issuer"], True))
        # Backdate both timestamps so that the first fetch/update call really downloads
        self.jwks_fetch_last = time.time() - self.JWKS_REFRESH_RATE
        self.metadata_fetch_last = time.time() - self.METADATA_REFRESH_RATE
        # Never write the client secret itself to the log, only report whether one is configured
        secretState = "<set>" if self.client_secret else "<not set>"
        self.log.debug(
            '"%s" OAuth2 IdP initialization done:' % self.name,
            "\nclient_id: %s\nclient_secret: %s\nmetadata:\n%s"
            % (self.client_id, secretState, pprint.pformat(self.metadata)),
        )

    def get_metadata(self, option=None):
        """Get a metadata value, fetching the server metadata first if the value is missing or empty

        :param str option: metadata key

        :return: option value, ``None`` if still unknown after the fetch
        """
        if not self.metadata.get(option):
            self.fetch_metadata()
        return self.metadata.get(option)

    @gMetadata
    def fetch_metadata(self):
        """Fetch metadata from ``server_metadata_url`` and merge it into ``self.metadata``

        Nothing is downloaded if the previous fetch is more recent than METADATA_REFRESH_RATE seconds.
        Errors raised by the request or by the JSON decoding are not caught here.
        """
        if self.metadata_fetch_last < (time.time() - self.METADATA_REFRESH_RATE):
            data = self.get(self.server_metadata_url, withhold_token=True).json()
            self.metadata.update(data)
            self.metadata_fetch_last = time.time()

    @gJWKs
    def updateJWKs(self):
        """Update JWKs from the ``jwks_uri`` metadata endpoint if the refresh period has elapsed

        :return: S_OK(dict)/S_ERROR() -- the new JWKs after a download, S_OK() without value
            when the stored JWKs are still considered fresh
        """
        if self.jwks_fetch_last < (time.time() - self.JWKS_REFRESH_RATE):
            try:
                self.jwks = self.get(self.get_metadata("jwks_uri"), withhold_token=True).json()
                self.jwks_fetch_last = time.time()
                return S_OK(self.jwks)
            except Exception as e:
                self.log.exception(e)
                return S_ERROR("Error %s" % repr(e))
        return S_OK()

    def verifyToken(self, accessToken=None, jwks=None):
        """Decode an access token with a set of JWKs

        :param str accessToken: access token, by default the ``access_token`` of ``self.token``
        :param dict jwks: JWKs, by default the provider JWKs (refreshed first if needed)

        :return: S_OK(dict)/S_ERROR() -- the value is the result of ``jwt.decode``
        """
        # Define an access token
        if not accessToken:
            accessToken = self.token["access_token"]
        # Renew a JWKs of an identity provider if needed
        if not jwks:
            result = self.updateJWKs()
            if not result["OK"]:
                return result
            jwks = self.jwks
        if not jwks:
            return S_ERROR("JWKs not found.")
        # Try to decode and verify an access token
        # NOTE(review): the raw access token is written to the debug log here -- confirm this is acceptable
        self.log.debug("Try to decode token %s with JWKs:\n" % accessToken, pprint.pformat(jwks))
        try:
            return S_OK(jwt.decode(accessToken, JsonWebKey.import_key_set(jwks)))
        except Exception as e:
            self.log.exception(e)
            return S_ERROR(repr(e))

    @gRefreshToken
    def refreshToken(self, refresh_token=None, group=None, **kwargs):
        """Refresh token

        :param str refresh_token: refresh token, by default the ``refresh_token`` of ``self.token``
        :param str group: DIRAC group; when given, its scopes are sent as the ``scope`` of the request
        :param kwargs: extra keyword arguments forwarded to ``self.refresh_token``

        :return: S_OK(OAuth2Token)/S_ERROR()
        """
        if group:
            # If group set add group scopes to request
            if not (groupScopes := self.getGroupScopes(group)):
                return S_ERROR(f"No scope found for {group}")
            kwargs.update(dict(scope=list_to_scope(groupScopes)))
        if not refresh_token:
            refresh_token = self.token.get("refresh_token")
        try:
            token = self.refresh_token(self.get_metadata("token_endpoint"), refresh_token=refresh_token, **kwargs)
            return S_OK(OAuth2Token(dict(token)))
        except Exception as e:
            self.log.exception(e)
            return S_ERROR(repr(e))
scope = scope or scope_to_list(self.scope) if group: if not (groupScopes := self.getGroupScopes(group)): return S_ERROR(f"No scope found for {group}") scope = list(set(scope + groupScopes)) scope = list_to_scope(scope) try: token = self.exchange_token( self.get_metadata("token_endpoint"), subject_token=self.token["access_token"], subject_token_type="urn:ietf:params:oauth:token-type:access_token", scope=scope, ) if not token: return S_ERROR("Cannot exchange token with %s scope." % scope) return S_OK(OAuth2Token(dict(token))) except Exception as e: self.log.exception(e) return S_ERROR("Cannot exchange token with %s scope: %s" % (scope, repr(e))) def researchGroup(self, payload=None, token=None): """Research group :param str payload: token payload :param str token: access token :return: S_OK(dict)/S_ERROR() """ if not token: token = self.token