class MyOAuth2Session(object):
    """
    Establish an OAuth2 "session" using raw oauthlib.

    For authorization_code and implicit flows, this will open a browser to log the user in.

    In order to test token refresh without having to wait, you can set `expire_faster` to a smaller number.
    """

    # TODO: Consider replacing all this with requests-oauthlib

    def __init__(self,
                 oauth_server=None,
                 client_id=None,
                 client_secret=None,
                 redirect_url=None,
                 grant=None,
                 scopes=None,
                 refresh_token=None,
                 expire_faster=None):
        """
        Prepare an OAuth 2 session: store the client configuration and discover the server's endpoints.

        No tokens are acquired here; call one of the ``do_*`` methods for that.

        :param oauth_server: base URL of the oauth2 server
        :param client_id: OAuth 2 client ID
        :param client_secret: OAuth 2 client secret
        :param redirect_url: OAuth 2 client's registered redirect URL (may be omitted for grants that
            don't redirect, such as client_credentials)
        :param grant: type: 'authorization_code', etc.
        :param scopes: list of requested scopes
        :param refresh_token: OAuth 2 refresh token for 'refresh_token' grant
        :param expire_faster: Seconds until access token expires. Used to override the default expires_in.
        :raises ConnectionError: if ``/.well-known/openid-configuration`` does not return 200
        """
        self.oauth_server = oauth_server
        self.client_id = client_id
        self.client_secret = client_secret
        self.redirect_url = redirect_url
        self.grant = grant
        self.scopes = scopes
        self.refresh_token = refresh_token
        self.expire_faster = expire_faster
        #: generate Authorization Basic header with client ID and secret.
        self.oauth_auth = requests.auth.HTTPBasicAuth(self.client_id, self.client_secret)
        #: oauthlib client is set in one of do_authorization_code, do_implicit, etc.
        self.oauth_client = None
        #: token state, filled in by set_tokens() after a successful grant.
        self.oauth_tokens = None
        self.access_token = None
        self.id_token = None
        self.expires_in = None
        self.expires_at = None
        #: nonce sent with OpenID Connect implicit requests (set by do_implicit).
        self.nonce = None
        # get oauth 2.0 endpoints by asking the AS for them:
        r = requests.get(oauth_server + '/.well-known/openid-configuration')
        if r.status_code == 200:
            self.oauth_endpoints = r.json()
        else:
            raise ConnectionError(
                "failed to get OAuth 2 endpoints: {} {}: {}".format(
                    r.status_code, r.reason, r.content))
        # if we are testing with non-TLS then tell oauthlib that's OK.
        # Note: os.environ is process-wide, so this affects every oauthlib user in the process.
        if redirect_url is not None and not redirect_url.startswith('https'):
            os.environ['OAUTHLIB_INSECURE_TRANSPORT'] = '1'

    class BearerAuth(requests.auth.AuthBase):
        """
        Creates a callable that sets "Authorization: Bearer <access_token>" header for requests.

        Usage example::

            response = requests.get(url, params=params, headers=headers, auth=BearerAuth(access_token))
        """

        def __init__(self, access_token=None):
            # a falsy token leaves the attribute unset; __call__ reports that clearly.
            if access_token:
                self.access_token = access_token

        def __call__(self, r):
            """
            Add the bearer Authorization header to the prepared request ``r``.

            :raises ValueError: if no access token was supplied
            """
            access_token = getattr(self, 'access_token', None)
            if not access_token:
                raise ValueError("BearerAuth has no access_token to send")
            r.headers['Authorization'] = 'Bearer ' + access_token
            return r

    def set_tokens(self):
        """
        parse out the various tokens from the OAuth 2 Token response held in ``self.oauth_tokens``

        Sets access_token, refresh_token, id_token, expires_in and expires_at. When there is no token
        response, all of them are reset to None.
        """
        tokens = self.oauth_tokens
        if tokens:
            self.access_token = tokens.get('access_token')
            # The AS may rotate the refresh token (refresh token rolling policy); if so the new one
            # replaces ours. If the response carries no refresh token, keep the current one
            # (RFC 6749 section 6) so that a later do_refresh_token() still has something to use.
            self.refresh_token = tokens.get('refresh_token', self.refresh_token)
            self.id_token = tokens.get('id_token')
            self.expires_in = tokens.get('expires_in')
            if self.expire_faster and self.expires_in and self.expire_faster < self.expires_in:
                self.expires_in = self.expire_faster
            self.expires_at = time.time() + self.expires_in if self.expires_in else None
        else:
            self.access_token = self.refresh_token = self.id_token = self.expires_in = self.expires_at = None

    def do_authorization_code(self):
        """
        login to OAuth 2 using Authorization Code grant

        Opens the authorization URL in a browser and catches the redirect with a local server. If no
        URL comes back from that server, the redirected URL is read from the clipboard instead. The
        authorization code is then exchanged for tokens.

        :raises TimeoutError: if the redirect server timed out waiting for the browser
        :raises RuntimeError: if the authorization response carries an error
        :raises PermissionError: if the token endpoint does not answer 200
        """
        self.oauth_client = WebApplicationClient(self.client_id)
        authorize_url, headers, body = self.oauth_client.prepare_authorization_request(
            self.oauth_endpoints['authorization_endpoint'],
            redirect_url=self.redirect_url,
            scope=self.scopes)
        status, code_request_url = self._redirect_server(self.redirect_url, authorize_url)
        if not code_request_url:
            input("Copy URL to clipboard and then hit enter:")
            code_request_url = pyperclip.paste()
        if status != 200 or 'error_description' in code_request_url:
            errors = parse_qs(urlparse(code_request_url).query)
            if 'error' in errors and 'error_description' in errors:
                msg = "{}: {}".format(errors['error'][0], errors['error_description'][0])
            else:
                msg = pformat(errors)
            # 408 is what _redirect_server reports when it timed out
            raise TimeoutError(msg) if status == 408 else RuntimeError(msg)
        url, headers, body = self.oauth_client.prepare_token_request(
            self.oauth_endpoints['token_endpoint'],
            authorization_response=code_request_url,
            redirect_url=self.redirect_url)
        self._request_tokens(url, headers, body)

    def do_refresh_token(self):
        """
        Use the refresh token to get new tokens. Refresh tokens are only used with the Authorization
        Code grant.

        Can be called either from the get-go with a supplied refresh_token or later on to refresh an
        existing access_token.

        Note that, depending on how the OAuth 2.0 AS is configured, the refresh token may be rotated
        after each use (or after some amount of time). This can mitigate replays of a stolen refresh
        token but will raise an exception if you try to reuse an invalid refresh token::

            PermissionError: 400: invalid_grant: unknown, invalid, or expired refresh token

        :raises ValueError: if there is no refresh token
        :raises PermissionError: if the token endpoint does not answer 200
        """
        if not self.refresh_token:
            raise ValueError("no refresh token supplied for `refresh_token` grant")
        if self.oauth_client is None:
            self.oauth_client = WebApplicationClient(self.client_id)
        token_url, headers, body = self.oauth_client.prepare_refresh_token_request(
            self.oauth_endpoints['token_endpoint'],
            refresh_token=self.refresh_token,
            scope=self.scopes)
        self._request_tokens(token_url, headers, body)

    def do_implicit(self):
        """
        Implicit grant

        Opens the authorization URL in a browser; the user then copies the resulting URL (whose
        #fragment holds the tokens) to the clipboard, and the tokens are parsed from it.
        """
        self.oauth_client = MobileApplicationClient(self.client_id)
        authorize_url, headers, body = self.oauth_client.prepare_authorization_request(
            self.oauth_endpoints['authorization_endpoint'],
            redirect_url=self.redirect_url,
            scope=self.scopes,
        )
        # For openid scope, have to add id_token response_type and supply a nonce:
        # https://www.pingidentity.com/content/developer/en/resources/openid-connect-developers-guide/implicit-client
        # -profile.html
        # "Note: To mitigate replay attacks, a nonce value must be included to associate a client session with an
        # id_token. The client must generate a random value associated with the current session and pass this
        # along with the request. This nonce value will be returned with the id_token and must be verified to be
        # the same as the value provided in the initial request."
        #
        # oauthlib `lacks client features<https://github.com/oauthlib/oauthlib/issues/615>`_ for OIDC.
        # scopes may be None, so guard the membership test.
        if self.scopes and 'openid' in self.scopes:
            self.nonce = uuid.uuid4().hex
            # TODO: fix this to something more robust or extend oauthlib
            authorize_url = authorize_url.replace(
                'response_type=token',
                'response_type=token+id_token&nonce={}'.format(self.nonce))
        # For the implicit grant type, the parameters are provided to the browser as #fragments rather than
        # ?query-parameters, so they are never available to the redirect server. The user must copy the
        # browser's URL to the clipboard so it can be read back here.
        self._redirect_server(
            self.redirect_url,
            authorize_url,
            message="Copy this URL to the clipboard and hit enter on the console. "
                    "Then you can close this window.")
        input("\n\nCopy URL to clipboard and then hit enter:")
        token_response_url = pyperclip.paste()
        self.oauth_tokens = self.oauth_client.parse_request_uri_response(token_response_url)
        self.set_tokens()

    def do_client_credentials(self):
        """
        Client credentials grant

        :raises PermissionError: if the token endpoint does not answer 200
        """
        self.oauth_client = BackendApplicationClient(self.client_id)
        token_url, headers, body = self.oauth_client.prepare_token_request(
            self.oauth_endpoints['token_endpoint'], scope=self.scopes)
        self._request_tokens(token_url, headers, body)

    def _request_tokens(self, url, headers, body):
        """
        POST a prepared token request (with client Basic auth) and store the resulting tokens.

        :param url: token endpoint URL from oauthlib
        :param headers: request headers from oauthlib
        :param body: form-encoded request body from oauthlib
        :raises PermissionError: if the token endpoint does not answer 200
        """
        token_response = requests.post(url, headers=headers, data=body, auth=self.oauth_auth)
        if token_response.status_code != 200:
            raise PermissionError(self._format_response_error(token_response))
        self.oauth_tokens = token_response.json()
        self.set_tokens()

    @staticmethod
    def _format_response_error(response):
        """
        make a pretty error response message::

            status_code: reason error: error_description

        :param response: request.request() response
        :return: string formatted with status_code, reason and json-parsed 'error' and 'error_description' (if json)
        """
        try:
            content = response.json()
            error = content['error'] + ': ' if 'error' in content else ''
            if 'error_description' in content:
                error_msg = error + content['error_description']
            else:
                error_msg = content
        except (ValueError, TypeError, KeyError):
            # not JSON (ValueError), or JSON that isn't the expected object of strings (TypeError/KeyError)
            error_msg = '{}:{}'.format(response.reason, response.content.decode(errors='replace'))
        return "{}: {}".format(response.status_code, error_msg)

    def _redirect_server(
            self,
            redirect_url,
            authorize_url,
            message='Response received. You can close this window.',
            timeout=30):
        """
        Run an http server thread to catch the OAuth 2 redirect and then exit.

        The server listens on 127.0.0.1 at the port taken from redirect_url (443/80 if none is given),
        and authorize_url is opened in a new browser window.

        :param redirect_url: redirect_url
        :param authorize_url: authorization request
        :param message: success message to display in the browser
        :param timeout: seconds to wait for the redirect before giving up
        :return: status, redirected request path containing parameters. On timeout the status is 408
            and the path carries ``error=timeout`` query parameters.
        """

        class Handler(http.server.BaseHTTPRequestHandler):
            def log_message(self, format, *args):
                """ silence the default log_message that http.server prints for each request """
                pass

            def do_GET(self):
                """ implement http GET method. """
                # the browser may ask for stupid things like /favicon.ico
                # so check to make sure it's the redirect_uri that we expect
                if self.path.startswith(server.redirect_path_prefix):
                    server.path = self.path
                    server.semaphore.set()
                    response = message.encode()
                else:
                    response = b'huh?'
                self.send_response(200, 'OK')
                self.send_header("Content-type", 'text/plain')
                self.send_header("Content-Length", len(response))
                self.end_headers()
                try:
                    self.wfile.write(response)
                finally:
                    self.wfile.flush()

        r = urlparse(redirect_url)
        if r.port is None:
            port = 443 if redirect_url.startswith("https") else 80
        else:
            port = r.port
        server_address = ('127.0.0.1', port)
        server = http.server.HTTPServer(server_address, Handler)
        server.redirect_path_prefix = r.path
        server.semaphore = threading.Event()
        server.path = None
        thread = threading.Thread(target=server.serve_forever)
        thread.start()
        try:
            webbrowser.open_new(authorize_url)
            signaled = server.semaphore.wait(timeout=timeout)
        finally:
            # Always stop the serve_forever thread and release the listening socket, even if opening
            # the browser or waiting was interrupted (e.g. KeyboardInterrupt).
            server.shutdown()
            thread.join()
            server.server_close()
        if not signaled:
            log.debug("redirect server timed out")
            status = 408
            path = '/?error=timeout&error_description=redirect+server+wait+timed+out'
        else:
            status = 200
            path = server.path
        return status, path