class ThirdPartyService(object):
    """HTTP client base that keeps one session per thread for a given base URL.

    No authentication header is set by this class; sessions only carry a
    content type and a user agent.
    """

    # Timeout handed to every session this class builds.
    TIMEOUT = (9.5, 30)

    # Retry policy shared by all instances. The backoff factor is drawn once,
    # when the class body is executed, not per request.
    RETRY_CONFIG = Retry(
        total=10,
        read=2,
        backoff_factor=random.uniform(1, 3),
        method_whitelist=frozenset(
            ("HEAD", "TRACE", "GET", "POST", "PUT", "OPTIONS", "DELETE")
        ),
        status_forcelist=[429, 500, 502, 503, 504],
    )

    # Adapter factory wrapped in ThreadLocalWrapper; the lambda defers the
    # lookup of RETRY_CONFIG until an adapter is actually requested.
    ADAPTER = ThreadLocalWrapper(
        lambda: HTTPAdapter(max_retries=ThirdPartyService.RETRY_CONFIG)
    )

    def __init__(self, url=""):
        """Remember the base URL and prepare the lazily built session.

        :param url: Base URL passed to every session that gets built.
        """
        self.base_url = url
        self._session = ThreadLocalWrapper(self.build_session)

    @property
    def session(self):
        """Return the session held by the wrapper for the current caller."""
        return self._session.get()

    def build_session(self):
        """Create a session for ``base_url`` with the shared adapter mounted.

        :return: A ``WrappedSession`` with the HTTPS adapter and default
            headers installed.
        """
        session = WrappedSession(self.base_url, timeout=self.TIMEOUT)
        session.mount("https://", self.ADAPTER.get())

        default_headers = {
            "Content-Type": "application/octet-stream",
            "User-Agent": "dl-python/{}".format(__version__),
        }
        session.headers.update(default_headers)
        return session
def __init__(self, url, token=None, auth=None, retries=None, session_class=None):
    """Configure auth, base URL, retry adapter and session class.

    :param url: Base URL stored on ``self.base_url``.
    :param token: Deprecated. When given, a ``DeprecationWarning`` is issued
        and the value is written to the auth object's ``_token``.
    :param auth: Auth object to use; a default ``Auth()`` is created when
        omitted.
    :param retries: Passed as ``max_retries`` to a dedicated ``HTTPAdapter``;
        when omitted the shared ``Service.ADAPTER`` is used instead.
    :param session_class: Class used to build sessions; falls back to
        ``WrappedSession``.
    """
    resolved_auth = Auth() if auth is None else auth

    if token is not None:
        warn(
            "setting token at service level will be removed in future",
            DeprecationWarning,
        )
        resolved_auth._token = token

    self.auth = resolved_auth
    self.base_url = url

    if retries is not None:
        # A custom retry policy needs its own adapter.
        def _make_adapter():
            return HTTPAdapter(max_retries=retries)

        self._adapter = ThreadLocalWrapper(_make_adapter)
    else:
        self._adapter = Service.ADAPTER

    self._session_class = (
        WrappedSession if session_class is None else session_class
    )

    # Sessions can't be shared across threads or processes because the
    # underlying SSL connection pool can't be shared. They are created
    # thread-local to avoid intractable exceptions when users naively share
    # clients, e.g. when using multiprocessing.
    self._session = ThreadLocalWrapper(self.build_session)
class ThirdPartyService(object):
    """HTTP client base that keeps one session per thread for a given base URL.

    Sessions built here carry a content type and user agent only; no
    authorization header is added by this class.
    """

    CONNECT_TIMEOUT = 9.5
    READ_TIMEOUT = 30
    # Combined (connect, read) timeout handed to every session.
    TIMEOUT = (CONNECT_TIMEOUT, READ_TIMEOUT)

    # Retry policy shared by all instances. The backoff factor is drawn once,
    # when the class body is executed.
    RETRY_CONFIG = Retry(
        total=10,
        read=2,
        backoff_factor=random.uniform(1, 3),
        method_whitelist=frozenset(
            (
                HttpRequestMethod.HEAD,
                HttpRequestMethod.TRACE,
                HttpRequestMethod.GET,
                HttpRequestMethod.POST,
                HttpRequestMethod.PUT,
                HttpRequestMethod.OPTIONS,
                HttpRequestMethod.DELETE,
            )
        ),
        status_forcelist=[
            HttpStatusCode.TooManyRequests,
            HttpStatusCode.InternalServerError,
            HttpStatusCode.BadGateway,
            HttpStatusCode.ServiceUnavailable,
            HttpStatusCode.GatewayTimeout,
        ],
    )

    # Adapter factory wrapped in ThreadLocalWrapper; RETRY_CONFIG is looked
    # up only when an adapter is actually requested.
    ADAPTER = ThreadLocalWrapper(
        lambda: HTTPAdapter(max_retries=ThirdPartyService.RETRY_CONFIG)
    )

    def __init__(self, url=""):
        """Remember the base URL and prepare the lazily built session.

        :param url: Base URL passed to every session that gets built.
        """
        self.base_url = url
        self._session = ThreadLocalWrapper(self.build_session)

    @property
    def session(self):
        """Return the session held by the wrapper for the current caller."""
        return self._session.get()

    def build_session(self):
        """Create a session for ``base_url`` with the shared adapter mounted.

        :return: A ``WrappedSession`` with the HTTPS adapter and default
            headers installed.
        """
        session = WrappedSession(self.base_url, timeout=self.TIMEOUT)
        session.mount(HttpMountProtocol.HTTPS, self.ADAPTER.get())

        user_agent = "{}/{}".format(HttpHeaderValues.DlPython, __version__)
        default_headers = {
            HttpHeaderKeys.ContentType: HttpHeaderValues.ApplicationOctetStream,
            HttpHeaderKeys.UserAgent: user_agent,
        }
        session.headers.update(default_headers)
        return session
def __init__(self, url="", session_class=None):
    """Store the base URL and, optionally, a per-instance session class.

    :param url: Base URL stored on ``self.base_url``.
    :param session_class: Optional class used to build sessions. When
        omitted, ``self._session_class`` is not assigned here.
    :raises TypeError: If ``session_class`` is given but is not a subclass
        of ``Session``.
    """
    self.base_url = url

    custom_class_given = session_class is not None
    if custom_class_given and not issubclass(session_class, Session):
        raise TypeError(
            "The session class must be a subclass of {}.".format(Session)
        )
    if custom_class_given:
        self._session_class = session_class

    self._session = ThreadLocalWrapper(self._build_session)
def __init__(self, url, token=None, auth=None):
    """Configure the auth object and base URL, and prepare the session.

    :param url: Base URL stored on ``self.base_url``.
    :param token: Deprecated. When given, a ``DeprecationWarning`` is issued
        and the value is written to the auth object's ``_token``.
    :param auth: Auth object to use; a default ``Auth()`` is created when
        omitted.
    """
    resolved_auth = Auth() if auth is None else auth

    if token is not None:
        warn(
            "setting token at service level will be removed in future",
            DeprecationWarning,
        )
        resolved_auth._token = token

    self.auth = resolved_auth
    self.base_url = url

    # Sessions can't be shared across threads or processes because the
    # underlying SSL connection pool can't be shared. They are created
    # thread-local to avoid intractable exceptions when users naively share
    # clients, e.g. when using multiprocessing.
    self._session = ThreadLocalWrapper(self.build_session)
class ThreadLocalWrapperTest(unittest.TestCase):
    """Checks that ThreadLocalWrapper yields distinct values per thread and process.

    The wrapped factory returns ``(pid, thread ident)``, so two values differ
    exactly when they were produced in different threads or processes.
    """

    def setUp(self):
        self.wrapper = ThreadLocalWrapper(
            lambda: (os.getpid(), threading.current_thread().ident)
        )

    def _store_id(self):
        # Thread target: record the wrapper's value as seen from that thread.
        self.thread_id = self.wrapper.get()

    def _send_id(self, queue):
        # Process target: ship the wrapper's value back to the parent.
        queue.put(self.wrapper.get())

    def _id_from_child_process(self):
        # Run _send_id in a child process and return the value it reports.
        results = multiprocessing.Queue()
        child = multiprocessing.Process(target=self._send_id, args=(results,))
        child.start()
        child_id = results.get()
        child.join()
        return child_id

    def test_wrapper(self):
        main_id = self.wrapper.get()
        # Repeated calls from the same thread return the same value.
        self.assertEqual(main_id, self.wrapper.get())

        worker = threading.Thread(target=self._store_id)
        worker.start()
        worker.join()
        self.assertNotEqual(main_id, self.thread_id)

        child_id = self._id_from_child_process()
        self.assertNotEqual(main_id, child_id)
        self.assertNotEqual(self.thread_id, child_id)

    def test_wrapper_unused_in_main_process(self):
        # The child is started before the parent ever touches the wrapper.
        child_id = self._id_from_child_process()
        self.assertNotEqual(child_id, self.wrapper.get())

    def test_fork_from_fork(self):
        # A gross edge case discovered by Clark: if a process is forked from
        # a forked process things will go awry if we hadn't initialized the
        # internal threading.local's pid.
        def fork_another(results):
            results.put(self.wrapper.get())
            grandchild = multiprocessing.Process(
                target=self._send_id, args=(results,)
            )
            grandchild.start()
            grandchild.join()

        parent_id = self.wrapper.get()

        results = multiprocessing.Queue()
        child = multiprocessing.Process(target=fork_another, args=(results,))
        child.start()
        child_id = results.get()
        grandchild_id = results.get()
        child.join()

        self.assertNotEqual(parent_id, child_id)
        self.assertNotEqual(child_id, grandchild_id)
        self.assertNotEqual(parent_id, grandchild_id)
class Service(object):
    """Authenticated HTTP client base that keeps one session per thread.

    Every session carries the current auth token in its ``Authorization``
    header plus a set of informational client headers.
    """

    # Timeout handed to every session this class builds.
    TIMEOUT = (9.5, 30)

    # Retry policy shared by all clients. The backoff factor is drawn once,
    # when the class body is executed.
    RETRY_CONFIG = Retry(total=3,
                         connect=2,
                         read=2,
                         status=2,
                         backoff_factor=random.uniform(1, 3),
                         method_whitelist=frozenset([
                             'HEAD', 'TRACE', 'GET', 'POST', 'PUT', 'OPTIONS', 'DELETE'
                         ]),
                         status_forcelist=[500, 502, 503, 504])

    # We share an adapter (one per thread/process) among all clients to take advantage
    # of the single underlying connection pool.
    ADAPTER = ThreadLocalWrapper(
        lambda: HTTPAdapter(max_retries=Service.RETRY_CONFIG))

    def __init__(self, url, token=None, auth=None):
        """Configure the auth object and base URL, and prepare the session.

        :param url: Base URL stored on ``self.base_url``.
        :param token: Deprecated. When given, a ``DeprecationWarning`` is
            issued and the value is written to ``auth._token``.
        :param auth: Auth object to use; a default ``Auth()`` is created
            when omitted.
        """
        if auth is None:
            auth = Auth()

        if token is not None:
            warn("setting token at service level will be removed in future",
                 DeprecationWarning)
            auth._token = token

        self.auth = auth
        self.base_url = url

        # Sessions can't be shared across threads or processes because the underlying
        # SSL connection pool can't be shared. We create them thread-local to avoid
        # intractable exceptions when users naively share clients e.g. when using
        # multiprocessing.
        self._session = ThreadLocalWrapper(self.build_session)

    @property
    def token(self):
        """Return the token exposed by the auth object."""
        return self.auth.token

    @token.setter
    def token(self, token):
        self.auth._token = token

    @property
    def session(self):
        """Return the current session with its ``Authorization`` header up to date.

        The token is read exactly once per access, so the value compared
        against the header is the same value that gets stored in it.
        """
        session = self._session.get()
        token = self.token
        if session.headers.get('Authorization') != token:
            session.headers['Authorization'] = token
        return session

    def build_session(self):
        """Create a session for ``base_url`` with adapter and default headers.

        :return: A ``WrappedSession`` with the HTTPS adapter mounted.
        """
        s = WrappedSession(self.base_url, timeout=self.TIMEOUT)
        s.mount('https://', self.ADAPTER.get())

        s.headers.update({
            "Content-Type": "application/json",
            "User-Agent": "dl-python/{}".format(__version__),
        })

        # The environment headers are best effort: any failure while
        # collecting them leaves the session without them rather than
        # preventing it from being built.
        try:
            s.headers.update({
                # https://github.com/easybuilders/easybuild/wiki/OS_flavor_name_version
                "X-Platform": platform.platform(),
                "X-Python": platform.python_version(),
                # https://stackoverflow.com/questions/47608532/how-to-detect-from-within-python-whether-packages-are-managed-with-conda
                "X-Conda": str(
                    os.path.exists(
                        os.path.join(sys.prefix, 'conda-meta', 'history'))),
                # https://stackoverflow.com/questions/15411967/how-can-i-check-if-code-is-executed-in-the-ipython-notebook
                "X-Notebook": str('ipykernel' in sys.modules),
            })
        except Exception:
            pass

        return s
def __init__(self, url=''):
    """Remember the base URL and prepare the lazily built session.

    :param url: Base URL stored on ``self.base_url``.
    """
    self.base_url = url

    # The wrapper receives the bound factory; the session itself is built
    # by ``build_session`` when it is first requested.
    session_factory = self.build_session
    self._session = ThreadLocalWrapper(session_factory)
class Auth:
    """Obtains, caches and refreshes the JWT used to talk to the services.

    Credentials are resolved from explicit arguments first, then environment
    variables, then the JSON file at ``token_info_path``.
    """

    # Retry policy for token requests. The backoff factor is drawn once,
    # when the class body is executed.
    RETRY_CONFIG = Retry(total=5,
                         backoff_factor=random.uniform(1, 10),
                         method_whitelist=frozenset(['GET', 'POST']),
                         status_forcelist=[429, 500, 502, 503, 504])

    ADAPTER = ThreadLocalWrapper(lambda: HTTPAdapter(max_retries=Auth.RETRY_CONFIG))

    def __init__(self, domain="https://accounts.descarteslabs.com",
                 scope=None, leeway=500, token_info_path=DEFAULT_TOKEN_INFO_PATH,
                 client_id=None, client_secret=None, jwt_token=None, refresh_token=None):
        """
        Helps retrieve JWT from a client id and refresh token for cli usage.

        :param domain: endpoint for auth0
        :param scope: the JWT fields to be included
        :param leeway: JWT expiration leeway
        :param token_info_path: path to a JSON file optionally holding auth information
        :param client_id: JWT client id
        :param client_secret: JWT client secret
        :param jwt_token: the JWT token, if we already have one
        :param refresh_token: the refresh token
        """
        self.token_info_path = token_info_path

        # A missing or unparsable file is treated as "no stored information".
        token_info = {}
        if self.token_info_path:
            try:
                with open(self.token_info_path) as fp:
                    token_info = json.load(fp)
            except (IOError, ValueError):
                pass

        # Each setting takes the first candidate that is not None.
        self.client_id = next(
            (x for x in (
                client_id,
                os.environ.get('DESCARTESLABS_CLIENT_ID'),
                os.environ.get('CLIENT_ID'),
                token_info.get('client_id')
            ) if x is not None), None)

        self.client_secret = next(
            (x for x in (
                client_secret,
                os.environ.get('DESCARTESLABS_CLIENT_SECRET'),
                os.environ.get('CLIENT_SECRET'),
                token_info.get('client_secret')
            ) if x is not None), None)

        self.refresh_token = next(
            (x for x in (
                refresh_token,
                os.environ.get('DESCARTESLABS_REFRESH_TOKEN'),
                token_info.get('refresh_token')
            ) if x is not None), None)

        self._token = next(
            (x for x in (
                jwt_token,
                os.environ.get('DESCARTESLABS_TOKEN'),
                token_info.get('JWT_TOKEN'),
                token_info.get('jwt_token')
            ) if x is not None), None)

        self.scope = next(
            (x for x in (
                scope,
                token_info.get('scope')
            ) if x is not None), None)

        if token_info:
            # If the token was read from a path but environment variables were set, we may need
            # to reset the token.
            client_id_changed = token_info.get('client_id', None) != self.client_id
            client_secret_changed = token_info.get('client_secret', None) != self.client_secret
            refresh_token_changed = token_info.get('refresh_token', None) != self.refresh_token

            if client_id_changed or client_secret_changed or refresh_token_changed:
                self._token = None

        self._namespace = None
        self._session = ThreadLocalWrapper(self.build_session)
        self.domain = domain
        self.leeway = leeway

    @classmethod
    def from_environment_or_token_json(cls, **kwargs):
        """
        Creates an Auth object from environment variables CLIENT_ID, CLIENT_SECRET,
        JWT_TOKEN if they are set, or else from a JSON file at the given path.

        The object is built through ``cls`` so that calling this on a subclass
        returns an instance of that subclass.

        :param domain: endpoint for auth0
        :param scope: the JWT fields to be included
        :param leeway: JWT expiration leeway
        :param token_info_path: path to a JSON file optionally holding auth information
        """
        return cls(**kwargs)

    @property
    def token(self):
        """Return the JWT, fetching or refreshing it when needed.

        A refresh is attempted once the token is within ``leeway`` seconds of
        its ``exp`` claim. A failed refresh is only re-raised when the token
        has actually expired.
        """
        if self._token is None:
            self._get_token()
        else:  # might have token but could be close to expiration
            exp = self.payload.get('exp')

            if exp is not None:
                now = (datetime.datetime.utcnow() - datetime.datetime(1970, 1, 1)).total_seconds()
                if now + self.leeway > exp:
                    try:
                        self._get_token()
                    except AuthError:
                        # Unable to refresh, raise if now > exp
                        if now > exp:
                            raise

        return self._token

    @property
    def payload(self):
        """Return the decoded claims (second dot-separated segment) of the JWT."""
        if self._token is None:
            self._get_token()

        if isinstance(self._token, six.text_type):
            token = self._token.encode('utf-8')
        else:
            token = self._token

        claims = token.split(b'.')[1]
        return json.loads(base64url_decode(claims).decode('utf-8'))

    @property
    def session(self):
        """Return the session held by the wrapper for the current caller."""
        return self._session.get()

    def build_session(self):
        """Create a ``requests.Session`` with the shared adapter mounted for HTTPS."""
        session = requests.Session()
        session.mount('https://', self.ADAPTER.get())
        return session

    def _get_token(self, timeout=100):
        """Request a new token from ``domain`` and persist it to the token file.

        :param timeout: Timeout passed to the token POST request.
        :raises AuthError: If the client id, or both the client secret and
            refresh token, are missing.
        :raises OauthError: If the server does not answer with status 200 or
            the response contains no token.
        """
        if self.client_id is None:
            raise AuthError("Could not find client_id")

        if self.client_secret is None and self.refresh_token is None:
            raise AuthError("Could not find client_secret or refresh token")

        if self.client_id in ["ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c"]:  # TODO(justin) remove legacy handling
            # TODO (justin) insert deprecation warning
            if self.scope is None:
                scope = ['openid', 'name', 'groups', 'org', 'email']
            else:
                scope = self.scope
            params = {
                "scope": " ".join(scope),
                "client_id": self.client_id,
                "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
                "target": self.client_id,
                "api_type": "app",
                "refresh_token": self.refresh_token if self.refresh_token is not None else self.client_secret
            }
        else:
            params = {
                "client_id": self.client_id,
                "grant_type": "refresh_token",
                "refresh_token": self.refresh_token if self.refresh_token is not None else self.client_secret
            }

            if self.scope is not None:
                params["scope"] = " ".join(self.scope)

        r = self.session.post(self.domain + "/token", json=params, timeout=timeout)

        if r.status_code != 200:
            raise OauthError("%s: %s" % (r.status_code, r.text))

        data = r.json()
        access_token = data.get('access_token')
        id_token = data.get('id_token')  # TODO(justin) remove legacy id_token usage

        if access_token is not None:
            self._token = access_token
        elif id_token is not None:
            self._token = id_token
        else:
            raise OauthError("could not retrieve token")

        # Re-read the file so that keys other than the token are preserved.
        token_info = {}
        if self.token_info_path:
            try:
                with open(self.token_info_path) as fp:
                    token_info = json.load(fp)
            except (IOError, ValueError):
                pass

        token_info['jwt_token'] = self._token

        if self.token_info_path:
            token_info_directory = os.path.dirname(self.token_info_path)
            makedirs_if_not_exists(token_info_directory)

            owner_only = stat.S_IRUSR | stat.S_IWUSR
            try:
                # Create the file with owner-only permissions from the start
                # so the token is never written to a file others can read.
                fd = os.open(self.token_info_path,
                             os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
                             owner_only)
                with os.fdopen(fd, 'w') as fp:
                    json.dump(token_info, fp)

                # Also tighten the permissions of a file that already existed.
                os.chmod(self.token_info_path, owner_only)
            except (IOError, OSError) as e:
                # Saving is best effort; the token stays usable in memory.
                warnings.warn('failed to save token: {}'.format(e))

    @property
    def namespace(self):
        """Return the SHA-1 hex digest of the token's ``sub`` claim (cached)."""
        if self._namespace is None:
            self._namespace = sha1(self.payload['sub'].encode('utf-8')).hexdigest()
        return self._namespace
def _init_session(self):
    """Install the wrapper that builds this client's sessions on demand."""
    # Sessions can't be shared across threads or processes because the
    # underlying SSL connection pool can't be shared. They are created
    # thread-local to avoid intractable exceptions when users naively share
    # clients, e.g. when using multiprocessing.
    session_factory = self._build_session
    self._session = ThreadLocalWrapper(session_factory)
class Auth:
    """
    Authentication client used to authenticate with all Descartes Labs service APIs.
    """

    # Retry policy for token requests. The backoff factor is drawn once,
    # when the class body is executed.
    RETRY_CONFIG = Retry(
        total=5,
        backoff_factor=random.uniform(1, 10),
        method_whitelist=frozenset(["GET", "POST"]),
        status_forcelist=[429, 500, 502, 503, 504],
    )

    ADAPTER = ThreadLocalWrapper(
        lambda: HTTPAdapter(max_retries=Auth.RETRY_CONFIG))

    def __init__(
            self,
            domain="https://accounts.descarteslabs.com",
            scope=None,
            leeway=500,
            token_info_path=DEFAULT_TOKEN_INFO_PATH,
            client_id=None,
            client_secret=None,
            jwt_token=None,
            refresh_token=None,
    ):
        """
        Helps retrieve JWT from a client id and refresh token for cli usage.

        :param str domain: The endpoint for auth0
        :type scope: list(str) or None
        :param scope: The JWT fields to be included
        :param int leeway: JWT expiration leeway
        :type token_info_path: str or None
        :param token_info_path: Path to a JSON file optionally holding auth information
        :type client_id: str or None
        :param client_id: JWT client id
        :type client_secret: str or None
        :param client_secret: JWT client secret
        :type jwt_token: str or None
        :param jwt_token: The JWT token, if we already have one
        :type refresh_token: str or None
        :param refresh_token: The refresh token
        """
        self.token_info_path = token_info_path

        # A missing or unparsable file is treated as "no stored information".
        token_info = {}
        if self.token_info_path:
            try:
                with open(self.token_info_path) as fp:
                    token_info = json.load(fp)
            except (IOError, ValueError):
                pass

        # Each setting takes the first candidate that is not None.
        self.client_id = next(
            (x for x in (
                client_id,
                os.environ.get("DESCARTESLABS_CLIENT_ID"),
                os.environ.get("CLIENT_ID"),
                token_info.get("client_id"),
            ) if x is not None),
            None,
        )

        self.client_secret = next(
            (x for x in (
                client_secret,
                os.environ.get("DESCARTESLABS_CLIENT_SECRET"),
                os.environ.get("CLIENT_SECRET"),
                token_info.get("client_secret"),
            ) if x is not None),
            None,
        )

        self.refresh_token = next(
            (x for x in (
                refresh_token,
                os.environ.get("DESCARTESLABS_REFRESH_TOKEN"),
                token_info.get("refresh_token"),
            ) if x is not None),
            None,
        )

        # The two values are kept identical: the refresh token wins when both
        # are set, otherwise whichever one is present is copied to the other.
        if self.client_secret != self.refresh_token:
            if self.client_secret is not None and self.refresh_token is not None:
                warnings.warn(
                    "Authentication token mismatch: "
                    "client_secret and refresh_token values must match for authentication to work correctly. "
                )

            if self.refresh_token is not None:
                self.client_secret = self.refresh_token
            elif self.client_secret is not None:
                self.refresh_token = self.client_secret

        self._token = next(
            (x for x in (
                jwt_token,
                os.environ.get("DESCARTESLABS_TOKEN"),
                token_info.get("JWT_TOKEN"),
                token_info.get("jwt_token"),
            ) if x is not None),
            None,
        )

        self.scope = next(
            (x for x in (scope, token_info.get("scope")) if x is not None),
            None)

        if token_info:
            # If the token was read from a path but environment variables were set, we may need
            # to reset the token.
            client_id_changed = token_info.get("client_id", None) != self.client_id
            client_secret_changed = (token_info.get("client_secret", None) != self.client_secret)
            refresh_token_changed = (token_info.get("refresh_token", None) != self.refresh_token)

            if client_id_changed or client_secret_changed or refresh_token_changed:
                self._token = None

        self._namespace = None
        self._session = ThreadLocalWrapper(self.build_session)
        self.domain = domain
        self.leeway = leeway

    @classmethod
    def from_environment_or_token_json(cls, **kwargs):
        """
        Creates an Auth object from environment variables CLIENT_ID, CLIENT_SECRET,
        JWT_TOKEN if they are set, or else from a JSON file at the given path.

        The object is built through ``cls`` so that calling this on a subclass
        returns an instance of that subclass.

        :param str domain: The endpoint for auth0
        :type scope: list(str) or None
        :param scope: The JWT fields to be included
        :param int leeway: JWT expiration leeway
        :type token_info_path: str or None
        :param token_info_path: Path to a JSON file optionally holding auth information
        """
        return cls(**kwargs)

    @property
    def token(self):
        """
        Gets the token.

        A refresh is attempted once the token is within ``leeway`` seconds of
        its ``exp`` claim. A failed refresh is only re-raised when the token
        has actually expired.

        :rtype: str
        :return: The JWT token string.

        :raises ~descarteslabs.client.exceptions.AuthError: Raised when
            incomplete information has been provided.
        :raises ~descarteslabs.client.exceptions.OauthError: Raised when
            a token cannot be obtained or refreshed.
        """
        if self._token is None:
            self._get_token()
        else:  # might have token but could be close to expiration
            exp = self.payload.get("exp")

            if exp is not None:
                now = (datetime.datetime.utcnow() -
                       datetime.datetime(1970, 1, 1)).total_seconds()
                if now + self.leeway > exp:
                    try:
                        self._get_token()
                    except AuthError:
                        # Unable to refresh, raise if now > exp
                        if now > exp:
                            raise

        return self._token

    @property
    def payload(self):
        """
        Gets the token payload.

        :rtype: dict
        :return: Dictionary containing the fields specified by scope, which may include:

            .. highlight:: none

            ::

                name:           The name of the user.
                groups:         Groups to which the user belongs.
                org:            The organization to which the user belongs.
                email:          The email address of the user.
                email_verified: True if the user's email has been verified.
                sub:            The user identifier.
                exp:            The expiration time of the token, in seconds since
                                the start of the unix epoch.

        :raises ~descarteslabs.client.exceptions.AuthError: Raised when
            incomplete information has been provided.
        :raises ~descarteslabs.client.exceptions.OauthError: Raised when
            a token cannot be obtained or refreshed.
        """
        if self._token is None:
            self._get_token()

        if isinstance(self._token, six.text_type):
            token = self._token.encode("utf-8")
        else:
            token = self._token

        # The claims are the second dot-separated segment of the token.
        claims = token.split(b".")[1]
        return json.loads(base64url_decode(claims).decode("utf-8"))

    @property
    def session(self):
        """
        Gets the request session used to communicate with the OAuth server.

        :rtype: requests.Session
        :return: Session object
        """
        return self._session.get()

    def build_session(self):
        """Create a ``requests.Session`` with the shared adapter mounted for HTTPS."""
        session = requests.Session()
        session.mount("https://", self.ADAPTER.get())
        return session

    def _get_token(self, timeout=100):
        """Request a new token from ``domain`` and persist it to the token file.

        :param timeout: Timeout passed to the token POST request.
        :raises AuthError: If the client id, or both the client secret and
            refresh token, are missing.
        :raises OauthError: If the server does not answer with status 200 or
            the response contains no token.
        """
        if self.client_id is None:
            raise AuthError("Could not find client_id")

        if self.client_secret is None and self.refresh_token is None:
            raise AuthError("Could not find client_secret or refresh token")

        if self.client_id in ["ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c"]:  # TODO(justin) remove legacy handling
            # TODO (justin) insert deprecation warning
            if self.scope is None:
                scope = ["openid", "name", "groups", "org", "email"]
            else:
                scope = self.scope
            params = {
                "scope": " ".join(scope),
                "client_id": self.client_id,
                "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
                "target": self.client_id,
                "api_type": "app",
                "refresh_token": self.refresh_token,
            }
        else:
            params = {
                "client_id": self.client_id,
                "grant_type": "refresh_token",
                "refresh_token": self.refresh_token,
            }

            if self.scope is not None:
                params["scope"] = " ".join(self.scope)

        r = self.session.post(self.domain + "/token",
                              json=params,
                              timeout=timeout)

        if r.status_code != 200:
            raise OauthError("%s: %s" % (r.status_code, r.text))

        data = r.json()
        access_token = data.get("access_token")
        id_token = data.get("id_token")  # TODO(justin) remove legacy id_token usage

        if access_token is not None:
            self._token = access_token
        elif id_token is not None:
            self._token = id_token
        else:
            raise OauthError("could not retrieve token")

        # Re-read the file so that keys other than the token are preserved.
        token_info = {}
        if self.token_info_path:
            try:
                with open(self.token_info_path) as fp:
                    token_info = json.load(fp)
            except (IOError, ValueError):
                pass

        token_info["jwt_token"] = self._token

        if self.token_info_path:
            token_info_directory = os.path.dirname(self.token_info_path)
            makedirs_if_not_exists(token_info_directory)

            owner_only = stat.S_IRUSR | stat.S_IWUSR
            try:
                # Create the file with owner-only permissions from the start
                # so the token is never written to a file others can read.
                fd = os.open(
                    self.token_info_path,
                    os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
                    owner_only,
                )
                with os.fdopen(fd, "w") as fp:
                    json.dump(token_info, fp)

                # Also tighten the permissions of a file that already existed.
                os.chmod(self.token_info_path, owner_only)
            except (IOError, OSError) as e:
                # Saving is best effort; the token stays usable in memory.
                warnings.warn("failed to save token: {}".format(e))

    @property
    def namespace(self):
        """
        Gets the user namespace.

        The value is the SHA-1 hex digest of the token's ``sub`` claim and is
        cached after the first access.

        :rtype: str
        :return: The user namespace

        :raises ~descarteslabs.client.exceptions.AuthError: Raised when
            incomplete information has been provided.
        :raises ~descarteslabs.client.exceptions.OauthError: Raised when
            a token cannot be obtained or refreshed.
        """
        if self._namespace is None:
            self._namespace = sha1(
                self.payload["sub"].encode("utf-8")).hexdigest()
        return self._namespace
class Service(object):
    """Authenticated HTTP client base that keeps one session per thread.

    Each instance can override the retry policy and the session class; by
    default the class-wide adapter and ``WrappedSession`` are used.
    """

    # https://requests.readthedocs.io/en/master/user/advanced/#timeouts
    CONNECT_TIMEOUT = 9.5
    READ_TIMEOUT = 30
    TIMEOUT = (CONNECT_TIMEOUT, READ_TIMEOUT)

    # Default retry policy. The backoff factor is drawn once, when the class
    # body is executed.
    RETRY_CONFIG = Retry(
        total=3,
        connect=2,
        read=2,
        status=2,
        backoff_factor=random.uniform(1, 3),
        method_whitelist=frozenset(
            (
                HttpRequestMethod.HEAD,
                HttpRequestMethod.TRACE,
                HttpRequestMethod.GET,
                HttpRequestMethod.POST,
                HttpRequestMethod.PUT,
                HttpRequestMethod.PATCH,
                HttpRequestMethod.OPTIONS,
                HttpRequestMethod.DELETE,
            )
        ),
        status_forcelist=[
            HttpStatusCode.InternalServerError,
            HttpStatusCode.BadGateway,
            HttpStatusCode.ServiceUnavailable,
            HttpStatusCode.GatewayTimeout,
        ],
    )

    # We share an adapter (one per thread/process) among all clients to take
    # advantage of the single underlying connection pool.
    ADAPTER = ThreadLocalWrapper(
        lambda: HTTPAdapter(max_retries=Service.RETRY_CONFIG)
    )

    def __init__(self, url, token=None, auth=None, retries=None, session_class=None):
        """Configure auth, base URL, retry adapter and session class.

        :param url: Base URL stored on ``self.base_url``.
        :param token: Deprecated. When given, a ``DeprecationWarning`` is
            issued and the value is written to the auth object's ``_token``.
        :param auth: Auth object to use; a default ``Auth()`` is created
            when omitted.
        :param retries: Passed as ``max_retries`` to a dedicated
            ``HTTPAdapter``; when omitted the shared ``Service.ADAPTER`` is
            used instead.
        :param session_class: Class used to build sessions; falls back to
            ``WrappedSession``.
        """
        resolved_auth = Auth() if auth is None else auth

        if token is not None:
            warn(
                "setting token at service level will be removed in future",
                DeprecationWarning,
            )
            resolved_auth._token = token

        self.auth = resolved_auth
        self.base_url = url

        if retries is not None:
            # A custom retry policy needs its own adapter.
            def _make_adapter():
                return HTTPAdapter(max_retries=retries)

            self._adapter = ThreadLocalWrapper(_make_adapter)
        else:
            self._adapter = Service.ADAPTER

        self._session_class = (
            WrappedSession if session_class is None else session_class
        )

        # Sessions can't be shared across threads or processes because the
        # underlying SSL connection pool can't be shared. They are created
        # thread-local to avoid intractable exceptions when users naively
        # share clients, e.g. when using multiprocessing.
        self._session = ThreadLocalWrapper(self.build_session)

    @property
    def token(self):
        """Return the token exposed by the auth object."""
        return self.auth.token

    @token.setter
    def token(self, token):
        self.auth._token = token

    @property
    def session(self):
        """Return the current session with its authorization header up to date."""
        current = self._session.get()
        bearer = add_bearer(self.token)
        header_name = HttpHeaderKeys.Authorization
        if current.headers.get(header_name) != bearer:
            current.headers[header_name] = bearer
        return current

    def build_session(self):
        """Create a session for ``base_url`` with adapter and default headers.

        :return: An instance of the configured session class with the same
            adapter mounted for both HTTPS and HTTP.
        """
        new_session = self._session_class(self.base_url, timeout=self.TIMEOUT)

        shared_adapter = self._adapter.get()
        for protocol in (HttpMountProtocol.HTTPS, HttpMountProtocol.HTTP):
            new_session.mount(protocol, shared_adapter)

        user_agent = "{}/{}".format(HttpHeaderValues.DlPython, __version__)
        new_session.headers.update(
            {
                HttpHeaderKeys.ContentType: HttpHeaderValues.ApplicationJson,
                HttpHeaderKeys.UserAgent: user_agent,
            }
        )

        # The environment headers are best effort: if collecting any of them
        # fails, none are added and the session is still returned.
        try:
            # https://stackoverflow.com/questions/47608532/how-to-detect-from-within-python-whether-packages-are-managed-with-conda
            conda_history = os.path.join(sys.prefix, "conda-meta", "history")
            environment_headers = {
                # https://github.com/easybuilders/easybuild/wiki/OS_flavor_name_version
                HttpHeaderKeys.Platform: platform.platform(),
                HttpHeaderKeys.Python: platform.python_version(),
                HttpHeaderKeys.Conda: str(os.path.exists(conda_history)),
                # https://stackoverflow.com/questions/15411967/how-can-i-check-if-code-is-executed-in-the-ipython-notebook
                HttpHeaderKeys.Notebook: str("ipykernel" in sys.modules),
                HttpHeaderKeys.ClientSession: uuid.uuid4().hex,
            }
            new_session.headers.update(environment_headers)
        except Exception:
            pass

        return new_session
class ThreadLocalWrapperTest(unittest.TestCase):
    """Checks that ThreadLocalWrapper yields distinct values per thread and process.

    The wrapped factory returns ``(pid, thread ident)``, so two values differ
    exactly when they were produced in different threads or processes.
    """

    def setUp(self):
        self.wrapper = ThreadLocalWrapper(
            lambda: (os.getpid(), threading.current_thread().ident)
        )

    def _store_id(self):
        # Thread target: record the wrapper's value as seen from that thread.
        self.thread_id = self.wrapper.get()

    def _send_id(self, queue):
        # Process target: ship the wrapper's value back to the parent.
        queue.put(self.wrapper.get())

    def _record_id_from_thread(self):
        # Run _store_id in a separate thread and wait for it to finish.
        worker = threading.Thread(target=self._store_id)
        worker.start()
        worker.join()

    def _id_from_child_process(self):
        # Run _send_id in a child process and return the value it reports.
        results = multiprocessing.Queue()
        child = multiprocessing.Process(target=self._send_id, args=(results,))
        child.start()
        child_id = results.get()
        child.join()
        return child_id

    def test_thread_thread(self):
        main_id = self.wrapper.get()
        # Repeated calls from the same thread return the same value.
        assert main_id == self.wrapper.get()

        self._record_id_from_thread()
        assert main_id != self.thread_id

    # Note on Windows: fork is not available so multiprocessing pickles the
    # multiprocessing function and arguments. ThreadLocalWrapper isn't
    # picklable, so the following tests can't work on Windows. But the
    # problem it solves for multiprocessing also doesn't exist there.

    @unittest.skipIf(sys.platform.startswith("win"), "forking not a concern on Windows")
    def test_wrapper_process(self):
        main_id = self.wrapper.get()

        self._record_id_from_thread()
        assert main_id != self.thread_id

        child_id = self._id_from_child_process()
        assert main_id != child_id
        assert self.thread_id != child_id

    @unittest.skipIf(sys.platform.startswith("win"), "forking not a concern on Windows")
    def test_wrapper_unused_in_main_process(self):
        # The child is started before the parent ever touches the wrapper.
        child_id = self._id_from_child_process()
        assert child_id != self.wrapper.get()

    @unittest.skipIf(sys.platform.startswith("win"), "forking not a concern on Windows")
    def test_fork_from_fork(self):
        # A gross edge case discovered by Clark: if a process is forked from
        # a forked process things will go awry if we hadn't initialized the
        # internal threading.local's pid.
        def fork_another(results):
            results.put(self.wrapper.get())
            grandchild = multiprocessing.Process(
                target=self._send_id, args=(results,)
            )
            grandchild.start()
            grandchild.join()

        parent_id = self.wrapper.get()

        results = multiprocessing.Queue()
        child = multiprocessing.Process(target=fork_another, args=(results,))
        child.start()
        child_id = results.get()
        grandchild_id = results.get()
        child.join()

        assert parent_id != child_id
        assert child_id != grandchild_id
        assert parent_id != grandchild_id
def setUp(self):
    """Create a wrapper whose factory identifies the calling process and thread."""
    def current_identity():
        # (pid, thread ident) of whoever triggers the factory.
        return (os.getpid(), threading.current_thread().ident)

    self.wrapper = ThreadLocalWrapper(current_identity)
def __init__(self, domain="https://accounts.descarteslabs.com",
             scope=None, leeway=500, token_info_path=DEFAULT_TOKEN_INFO_PATH,
             client_id=None, client_secret=None, jwt_token=None, refresh_token=None):
    """
    Helps retrieve JWT from a client id and refresh token for cli usage.

    Every credential is taken from the first source that provides a value
    other than None: the explicit argument, then environment variables, then
    the JSON file at ``token_info_path``.

    :param domain: endpoint for auth0
    :param scope: the JWT fields to be included
    :param leeway: JWT expiration leeway
    :param token_info_path: path to a JSON file optionally holding auth information
    :param client_id: JWT client id
    :param client_secret: JWT client secret
    :param jwt_token: the JWT token, if we already have one
    :param refresh_token: the refresh token
    """
    def first_set(*candidates):
        # First candidate that is not None, or None when there is none.
        for candidate in candidates:
            if candidate is not None:
                return candidate
        return None

    self.token_info_path = token_info_path

    # A missing or unparsable file is treated as "no stored information".
    stored = {}
    if self.token_info_path:
        try:
            with open(self.token_info_path) as fp:
                stored = json.load(fp)
        except (IOError, ValueError):
            pass

    env = os.environ

    self.client_id = first_set(
        client_id,
        env.get('DESCARTESLABS_CLIENT_ID'),
        env.get('CLIENT_ID'),
        stored.get('client_id'),
    )
    self.client_secret = first_set(
        client_secret,
        env.get('DESCARTESLABS_CLIENT_SECRET'),
        env.get('CLIENT_SECRET'),
        stored.get('client_secret'),
    )
    self.refresh_token = first_set(
        refresh_token,
        env.get('DESCARTESLABS_REFRESH_TOKEN'),
        stored.get('refresh_token'),
    )
    self._token = first_set(
        jwt_token,
        env.get('DESCARTESLABS_TOKEN'),
        stored.get('JWT_TOKEN'),
        stored.get('jwt_token'),
    )
    self.scope = first_set(scope, stored.get('scope'))

    if stored:
        # If the token was read from a path but environment variables were
        # set, we may need to reset the token.
        credentials_changed = (
            stored.get('client_id', None) != self.client_id
            or stored.get('client_secret', None) != self.client_secret
            or stored.get('refresh_token', None) != self.refresh_token
        )
        if credentials_changed:
            self._token = None

    self._namespace = None
    self._session = ThreadLocalWrapper(self.build_session)
    self.domain = domain
    self.leeway = leeway
class Service(object):
    """The default Descartes Labs HTTP Service used to communicate with its servers.

    This service has a default timeout and retry policy that retries HTTP requests
    depending on the timeout and HTTP status code that was returned.  This is based
    on the `requests timeouts
    <https://requests.readthedocs.io/en/master/user/advanced/#timeouts>`_
    and the `urllib3 retry object
    <https://urllib3.readthedocs.io/en/latest/reference/urllib3.util.html#urllib3.util.retry.Retry>`_.

    The default timeouts are set to 9.5 seconds for establishing a connection (slightly
    larger than a multiple of 3, which is the TCP default packet retransmission window),
    and 30 seconds for reading a response.

    The default retry logic retries up to 3 times total, a maximum of 2 for establishing
    a connection, 2 for reading a response, and 2 for unexpected HTTP status codes.
    The backoff_factor is a random number between 1 and 3, but will never be more
    than 2 minutes.  The unexpected HTTP status codes that will be retried are ``500``,
    ``502``, ``503``, and ``504`` for any of the HTTP requests.

    Parameters
    ----------
    url: str
        The URL prefix to use for communication with the Descartes Labs server.
    token: str, optional
        Deprecated.
    auth: Auth, optional
        A Descartes Labs :py:class:`~descarteslabs.client.auth.Auth` instance.  If not
        provided, a default one will be instantiated.
    retries: int or urllib3.util.retry.Retry
        If a number, it's the number of retries that will be attempted.  If a
        :py:class:`urllib3.util.retry.Retry` instance, it will determine the retry
        behavior.  If not provided, the default retry policy as described above will
        be used.
    session_class: class
        The session class to use when instantiating the session.  This must be a derived
        class from :py:class:`Session`.  If not provided, the default session class
        is used.  You can register a default session class with
        :py:meth:`Service.set_default_session_class`.

    Raises
    ------
    TypeError
        If you try to use a session class that is not derived from :py:class:`Session`.
    """

    # https://requests.readthedocs.io/en/master/user/advanced/#timeouts
    CONNECT_TIMEOUT = 9.5
    READ_TIMEOUT = 30

    TIMEOUT = (CONNECT_TIMEOUT, READ_TIMEOUT)

    RETRY_CONFIG = Retry(
        total=3,
        connect=2,
        read=2,
        status=2,
        # Drawn once, when the class body is executed; every instance that
        # uses the default policy therefore shares the same factor.
        backoff_factor=random.uniform(1, 3),
        method_whitelist=frozenset([
            HttpRequestMethod.HEAD,
            HttpRequestMethod.TRACE,
            HttpRequestMethod.GET,
            HttpRequestMethod.POST,
            HttpRequestMethod.PUT,
            HttpRequestMethod.PATCH,
            HttpRequestMethod.OPTIONS,
            HttpRequestMethod.DELETE,
        ]),
        status_forcelist=[
            HttpStatusCode.InternalServerError,
            HttpStatusCode.BadGateway,
            HttpStatusCode.ServiceUnavailable,
            HttpStatusCode.GatewayTimeout,
        ],
    )

    # We share an adapter (one per thread/process) among all clients to take advantage
    # of the single underlying connection pool.
    ADAPTER = ThreadLocalWrapper(
        lambda: HTTPAdapter(max_retries=Service.RETRY_CONFIG))

    _session_class = Session

    # List of attributes that will be included in state for pickling.
    # Subclasses can extend this attribute list.
    __attrs__ = ["auth", "base_url", "_session_class", "RETRY_CONFIG"]

    @staticmethod
    def _check_session_class(session_class):
        """Raise ``TypeError`` unless `session_class` is a :py:class:`Session` class.

        Non-class values are rejected with the same message as unrelated
        classes instead of leaking the interpreter's generic
        ``issubclass() arg 1 must be a class`` error.
        """
        if not (isinstance(session_class, type)
                and issubclass(session_class, Session)):
            raise TypeError(
                "The session class must be a subclass of {}.".format(Session))

    @classmethod
    def set_default_session_class(cls, session_class):
        """Set the default session class for :py:class:`Service`.

        The default session is used for any :py:class:`Service` that is instantiated
        without specifying the session class.

        Parameters
        ----------
        session_class: class
            The session class to use when instantiating the session.  This must be the
            class :py:class:`Session` itself or a derived class from
            :py:class:`Session`.

        Raises
        ------
        TypeError
            If `session_class` is not :py:class:`Session` or a class derived from it.
        """
        cls._check_session_class(session_class)
        cls._session_class = session_class

    @classmethod
    def get_default_session_class(cls):
        """Get the default session class for :py:class:`Service`.

        Returns
        -------
        Session
            The default session class, which is :py:class:`Session` itself or a derived
            class from :py:class:`Session`.
        """
        return cls._session_class

    def __init__(self, url, token=None, auth=None, retries=None, session_class=None):
        if auth is None:
            auth = Auth()

        if token is not None:
            warn(
                "setting token at service level will be removed in future",
                FutureWarning,
            )
            auth._token = token

        self.auth = auth
        self.base_url = url

        if retries is None:
            # Default policy: reuse the class-level adapter.
            self._adapter = self.ADAPTER
        else:
            # Custom policy: shadow the class attribute on this instance and
            # give the instance an adapter of its own.
            self.RETRY_CONFIG = retries
            self._init_adapter()

        if session_class is not None:
            # Overwrite the default session class
            self._check_session_class(session_class)
            self._session_class = session_class

        self._init_session()

    def _init_adapter(self):
        """Create the per-instance adapter wrapper that applies ``self.RETRY_CONFIG``."""
        self._adapter = ThreadLocalWrapper(
            lambda: HTTPAdapter(max_retries=self.RETRY_CONFIG))

    def _init_session(self):
        """Create the wrapper through which sessions are built by :py:meth:`_build_session`."""
        # Sessions can't be shared across threads or processes because the underlying
        # SSL connection pool can't be shared. We create them thread-local to avoid
        # intractable exceptions when users naively share clients e.g. when using
        # multiprocessing.
        self._session = ThreadLocalWrapper(self._build_session)

    @property
    def token(self):
        """str: The bearer token used in the requests."""
        return self.auth.token

    @token.setter
    def token(self, token):
        """str: Deprecated"""
        self.auth._token = token

    @property
    def session(self):
        """Session: The session instance used by this service."""
        session = self._session.get()
        auth = add_bearer(self.token)
        # Only touch the header when the token differs from the one the
        # session currently carries.
        if session.headers.get(HttpHeaderKeys.Authorization) != auth:
            session.headers[HttpHeaderKeys.Authorization] = auth

        return session

    def _build_session(self):
        """Create and configure a new session for this service's base URL."""
        session = self._session_class(self.base_url, timeout=self.TIMEOUT)
        session.initialize()

        adapter = self._adapter.get()
        session.mount(HttpMountProtocol.HTTPS, adapter)
        session.mount(HttpMountProtocol.HTTP, adapter)

        session.headers.update({
            HttpHeaderKeys.ContentType: HttpHeaderValues.ApplicationJson,
            HttpHeaderKeys.UserAgent: "{}/{}".format(HttpHeaderValues.DlPython, __version__),
        })

        try:
            session.headers.update({
                # https://github.com/easybuilders/easybuild/wiki/OS_flavor_name_version
                HttpHeaderKeys.Platform: platform.platform(),
                HttpHeaderKeys.Python: platform.python_version(),
                # https://stackoverflow.com/questions/47608532/how-to-detect-from-within-python-whether-packages-are-managed-with-conda
                HttpHeaderKeys.Conda: str(
                    os.path.exists(
                        os.path.join(sys.prefix, "conda-meta", "history"))),
                # https://stackoverflow.com/questions/15411967/how-can-i-check-if-code-is-executed-in-the-ipython-notebook
                HttpHeaderKeys.Notebook: str("ipykernel" in sys.modules),
                HttpHeaderKeys.ClientSession: uuid.uuid4().hex,
            })
        except Exception:
            # These headers are informational only; never let gathering them
            # prevent the session from being created.
            pass

        return session

    def __getstate__(self):
        """Return the picklable state: only the attributes named in ``__attrs__``."""
        return {attr: getattr(self, attr) for attr in self.__attrs__}

    def __setstate__(self, state):
        """Restore the pickled attributes, then recreate adapter and session wrappers."""
        for name, value in state.items():
            setattr(self, name, value)

        # The adapter and session wrappers are not part of the state.
        self._init_adapter()
        self._init_session()
def _init_adapter(self):
    """Install an adapter wrapper that applies this instance's retry policy.

    The factory reads ``self.RETRY_CONFIG`` when it is called, not when the
    wrapper is created here.
    """
    def make_adapter():
        return HTTPAdapter(max_retries=self.RETRY_CONFIG)

    self._adapter = ThreadLocalWrapper(make_adapter)
def __init__(
    self,
    domain="https://accounts.descarteslabs.com",
    scope=None,
    leeway=500,
    token_info_path=DEFAULT_TOKEN_INFO_PATH,
    client_id=None,
    client_secret=None,
    jwt_token=None,
    refresh_token=None,
):
    """
    Helps retrieve JWT from a client id and refresh token for cli usage.

    Every credential is taken from the first source that supplies a value
    other than ``None``, in this order: the explicit argument, the
    environment (``DESCARTESLABS_CLIENT_ID``/``CLIENT_ID``,
    ``DESCARTESLABS_CLIENT_SECRET``/``CLIENT_SECRET``,
    ``DESCARTESLABS_REFRESH_TOKEN``, ``DESCARTESLABS_TOKEN``), and finally
    the JSON file at ``token_info_path``.  ``scope`` is only looked up in
    the argument and the file.

    After resolution ``client_secret`` and ``refresh_token`` are made equal:
    the refresh token wins when both are set (a warning is issued if they
    differ), otherwise whichever one is set is copied to the other.

    :param str domain: The endpoint for auth0
    :type scope: list(str) or None
    :param scope: The JWT fields to be included
    :param int leeway: JWT expiration leeway
    :type token_info_path: str or None
    :param token_info_path: Path to a JSON file optionally holding auth
        information; a missing, unreadable or malformed file is ignored
    :type client_id: str or None
    :param client_id: JWT client id
    :type client_secret: str or None
    :param client_secret: JWT client secret
    :type jwt_token: str or None
    :param jwt_token: The JWT token, if we already have one
    :type refresh_token: str or None
    :param refresh_token: The refresh token
    """
    def first_set(*candidates):
        # First candidate that is not None, or None when there is none.
        return next((x for x in candidates if x is not None), None)

    self.token_info_path = token_info_path

    token_info = {}
    if self.token_info_path:
        try:
            with open(self.token_info_path) as fp:
                token_info = json.load(fp)
        except (IOError, ValueError):
            # Best effort: fall back to arguments and environment only.
            pass

        # Valid JSON is not necessarily an object (e.g. ``null`` or a list);
        # anything but a dict would break the ``.get`` lookups below.
        if not isinstance(token_info, dict):
            token_info = {}

    self.client_id = first_set(
        client_id,
        os.environ.get("DESCARTESLABS_CLIENT_ID"),
        os.environ.get("CLIENT_ID"),
        token_info.get("client_id"),
    )

    self.client_secret = first_set(
        client_secret,
        os.environ.get("DESCARTESLABS_CLIENT_SECRET"),
        os.environ.get("CLIENT_SECRET"),
        token_info.get("client_secret"),
    )

    self.refresh_token = first_set(
        refresh_token,
        os.environ.get("DESCARTESLABS_REFRESH_TOKEN"),
        token_info.get("refresh_token"),
    )

    if self.client_secret != self.refresh_token:
        if self.client_secret is not None and self.refresh_token is not None:
            warnings.warn(
                "Authentication token mismatch: "
                "client_secret and refresh_token values must match for authentication to work correctly. "
            )

        # Keep both attributes in sync; the refresh token takes precedence.
        if self.refresh_token is not None:
            self.client_secret = self.refresh_token
        elif self.client_secret is not None:
            self.refresh_token = self.client_secret

    self._token = first_set(
        jwt_token,
        os.environ.get("DESCARTESLABS_TOKEN"),
        token_info.get("JWT_TOKEN"),
        token_info.get("jwt_token"),
    )

    self.scope = first_set(scope, token_info.get("scope"))

    if token_info:
        # If the token was read from a path but environment variables were
        # set, we may need to reset the token.
        # NOTE(review): because of the sync above, a file that holds only one
        # of client_secret / refresh_token always counts as "changed" here,
        # so the JWT resolved above is discarded -- confirm this is intended.
        client_id_changed = token_info.get("client_id", None) != self.client_id
        client_secret_changed = (
            token_info.get("client_secret", None) != self.client_secret)
        refresh_token_changed = (
            token_info.get("refresh_token", None) != self.refresh_token)

        if client_id_changed or client_secret_changed or refresh_token_changed:
            self._token = None

    self._namespace = None
    # The session is built through a ThreadLocalWrapper around
    # self.build_session rather than being created here directly.
    self._session = ThreadLocalWrapper(self.build_session)
    self.domain = domain
    self.leeway = leeway
class ThirdPartyService(object):
    """The default Descartes Labs HTTP Service used for 3rd party servers.

    This service has a default timeout and retry policy that retries HTTP requests
    depending on the timeout and HTTP status code that was returned.  This is based
    on the `requests timeouts
    <https://requests.readthedocs.io/en/master/user/advanced/#timeouts>`_
    and the `urllib3 retry object
    <https://urllib3.readthedocs.io/en/latest/reference/urllib3.util.html#urllib3.util.retry.Retry>`_.

    The default timeouts are set to 9.5 seconds for establishing a connection (slightly
    larger than a multiple of 3, which is the TCP default packet retransmission window),
    and 30 seconds for reading a response.

    The default retry logic retries up to 10 times total, a maximum of 2 for reading
    a response.  The backoff_factor is a random number between 1 and 3, but will
    never be more than 2 minutes.  The unexpected HTTP status codes that will be
    retried are ``429``, ``500``, ``502``, ``503``, and ``504`` for any of the HTTP
    requests.

    Parameters
    ----------
    url: str
        The URL prefix to use for communication with the 3rd party server.
    session_class: class
        The session class to use when instantiating the session.  This must be a derived
        class from :py:class:`Session`.  If not provided, the default session class
        is used.  You can register a default session class with
        :py:meth:`ThirdPartyService.set_default_session_class`.

    Raises
    ------
    TypeError
        If you try to use a session class that is not derived from :py:class:`Session`.
    """

    CONNECT_TIMEOUT = 9.5
    READ_TIMEOUT = 30

    TIMEOUT = (CONNECT_TIMEOUT, READ_TIMEOUT)

    RETRY_CONFIG = Retry(
        total=10,
        read=2,
        # Drawn once, when the class body is executed; every instance
        # therefore shares the same factor.
        backoff_factor=random.uniform(1, 3),
        method_whitelist=frozenset([
            HttpRequestMethod.HEAD,
            HttpRequestMethod.TRACE,
            HttpRequestMethod.GET,
            HttpRequestMethod.POST,
            HttpRequestMethod.PUT,
            HttpRequestMethod.OPTIONS,
            HttpRequestMethod.DELETE,
        ]),
        status_forcelist=[
            HttpStatusCode.TooManyRequests,
            HttpStatusCode.InternalServerError,
            HttpStatusCode.BadGateway,
            HttpStatusCode.ServiceUnavailable,
            HttpStatusCode.GatewayTimeout,
        ],
    )

    # One adapter wrapper shared by every instance of this class.
    ADAPTER = ThreadLocalWrapper(
        lambda: HTTPAdapter(max_retries=ThirdPartyService.RETRY_CONFIG))

    _session_class = Session

    @staticmethod
    def _check_session_class(session_class):
        """Raise ``TypeError`` unless `session_class` is a :py:class:`Session` class.

        Non-class values (including ``None``) are rejected with the same
        message as unrelated classes instead of leaking the interpreter's
        generic ``issubclass() arg 1 must be a class`` error.
        """
        if not (isinstance(session_class, type)
                and issubclass(session_class, Session)):
            raise TypeError(
                "The session class must be a subclass of {}.".format(Session))

    @classmethod
    def set_default_session_class(cls, session_class=None):
        """Set the default session class for :py:class:`ThirdPartyService`.

        The default session is used for any :py:class:`ThirdPartyService` that is
        instantiated without specifying the session class.

        Parameters
        ----------
        session_class: class
            The session class to use when instantiating the session.  This must be the
            class :py:class:`Session` itself or a derived class from
            :py:class:`Session`.  The ``None`` default exists only in the signature;
            a class must be supplied.

        Raises
        ------
        TypeError
            If `session_class` is omitted or is not :py:class:`Session` or a class
            derived from it.
        """
        cls._check_session_class(session_class)
        cls._session_class = session_class

    @classmethod
    def get_default_session_class(cls):
        """Get the default session class for the :py:class:`ThirdPartyService`.

        Returns
        -------
        Session
            The default session class, which is :py:class:`Session` itself or a derived
            class from :py:class:`Session`.
        """
        return cls._session_class

    def __init__(self, url="", session_class=None):
        self.base_url = url

        if session_class is not None:
            # Overwrite the default session class for this instance only.
            self._check_session_class(session_class)
            self._session_class = session_class

        # Sessions are created on demand through the wrapper, not here.
        self._session = ThreadLocalWrapper(self._build_session)

    @property
    def session(self):
        """Session: The session instance used by this service."""
        return self._session.get()

    def _build_session(self):
        """Create and configure a new session for this service's base URL."""
        session = self._session_class(self.base_url, timeout=self.TIMEOUT)
        session.initialize()

        # Only the HTTPS prefix is mounted on the retrying adapter.
        session.mount(HttpMountProtocol.HTTPS, self.ADAPTER.get())
        session.headers.update({
            HttpHeaderKeys.ContentType: HttpHeaderValues.ApplicationOctetStream,
            HttpHeaderKeys.UserAgent: "{}/{}".format(HttpHeaderValues.DlPython, __version__),
        })

        return session
class Service(object):
    """HTTP service client with a default timeout and retry policy.

    Requests use a (connect, read) timeout of ``TIMEOUT`` and the retry
    policy in ``RETRY_CONFIG``, which retries on the status codes 500, 502,
    503 and 504.

    :param url: URL prefix handed to the session
    :param token: deprecated; if given it is stored on ``auth``
    :param auth: an ``Auth`` instance; a default one is created if omitted
    """

    TIMEOUT = (9.5, 30)

    RETRY_CONFIG = Retry(
        total=5,
        read=2,
        backoff_factor=random.uniform(1, 3),
        method_whitelist=frozenset([
            'HEAD', 'TRACE', 'GET', 'POST', 'PUT', 'OPTIONS', 'DELETE'
        ]),
        status_forcelist=[500, 502, 503, 504],
    )

    # We share an adapter (one per thread/process) among all clients to take advantage
    # of the single underlying connection pool.
    ADAPTER = ThreadLocalWrapper(
        lambda: HTTPAdapter(max_retries=Service.RETRY_CONFIG)
    )

    def __init__(self, url, token=None, auth=None):
        auth = Auth() if auth is None else auth

        if token is not None:
            warn("setting token at service level will be removed in future",
                 DeprecationWarning)
            auth._token = token

        self.auth = auth
        self.base_url = url

        # Sessions can't be shared across threads or processes because the underlying
        # SSL connection pool can't be shared. We create them thread-local to avoid
        # intractable exceptions when users naively share clients e.g. when using
        # multiprocessing.
        self._session = ThreadLocalWrapper(self.build_session)

    @property
    def token(self):
        """The token held by ``self.auth``."""
        return self.auth.token

    @token.setter
    def token(self, token):
        self.auth._token = token

    @property
    def session(self):
        """The session for this service, with its Authorization header kept current."""
        current = self._session.get()

        # Nothing to do when the header already carries the current token.
        if current.headers.get('Authorization') == self.token:
            return current

        current.headers['Authorization'] = self.token
        return current

    def build_session(self):
        """Create a session for ``base_url`` with the shared adapter and default headers."""
        http_session = WrappedSession(self.base_url, timeout=self.TIMEOUT)
        http_session.mount('https://', self.ADAPTER.get())

        default_headers = {
            "Content-Type": "application/json",
            "User-Agent": "dl-python/{}".format(__version__),
        }
        http_session.headers.update(default_headers)

        return http_session