class ThirdPartyService(object):
    """Base HTTP client for third-party endpoints.

    Sessions are built per thread by ``build_session`` and, unlike an
    authenticated service, carry no ``Authorization`` header.
    """

    # (connect timeout, read timeout) in seconds, passed to WrappedSession.
    TIMEOUT = (9.5, 30)

    # urllib3 1.26 renamed ``Retry(method_whitelist=...)`` to ``allowed_methods``
    # and urllib3 2.0 removed the old keyword entirely. ``DEFAULT_ALLOWED_METHODS``
    # was introduced together with the new name, so its presence tells us which
    # keyword the installed urllib3 accepts.
    _RETRY_METHODS_KWARG = (
        "allowed_methods"
        if hasattr(Retry, "DEFAULT_ALLOWED_METHODS")
        else "method_whitelist"
    )

    # Note: the backoff factor is drawn once, when this class body executes,
    # and is then shared by every adapter built from RETRY_CONFIG.
    RETRY_CONFIG = Retry(total=10,
                         read=2,
                         backoff_factor=random.uniform(1, 3),
                         status_forcelist=[429, 500, 502, 503, 504],
                         **{
                             _RETRY_METHODS_KWARG: frozenset([
                                 'HEAD', 'TRACE', 'GET', 'POST', 'PUT',
                                 'OPTIONS', 'DELETE'
                             ])
                         })

    # One retrying adapter per thread/process, shared by all instances.
    ADAPTER = ThreadLocalWrapper(
        lambda: HTTPAdapter(max_retries=ThirdPartyService.RETRY_CONFIG))

    def __init__(self, url=''):
        """:param str url: Base URL the per-thread sessions are created for."""
        self.base_url = url

        self._session = ThreadLocalWrapper(self.build_session)

    @property
    def session(self):
        """The session belonging to the calling thread."""
        return self._session.get()

    def build_session(self):
        """Create a session for ``base_url`` with retries on ``https://`` URLs.

        :return: a ``WrappedSession`` sending binary payloads and the
            ``dl-python`` User-Agent.
        """
        s = WrappedSession(self.base_url, timeout=self.TIMEOUT)
        # Only https:// URLs are routed through the retrying adapter.
        s.mount('https://', self.ADAPTER.get())

        s.headers.update({
            "Content-Type": "application/octet-stream",
            "User-Agent": "dl-python/{}".format(__version__)
        })

        return s
Esempio n. 2
0
    def __init__(self, url, token=None, auth=None, retries=None, session_class=None):
        """Configure the client for ``url``.

        :param str url: Base URL that sessions are created against.
        :param token: Deprecated; when given it is written into ``auth._token``.
        :param auth: ``Auth`` instance; a fresh ``Auth()`` is built when omitted.
        :param retries: Optional retry policy for a dedicated per-thread
            ``HTTPAdapter``; when omitted the shared ``Service.ADAPTER`` is used.
        :param session_class: Session type used to build sessions; defaults to
            ``WrappedSession``.
        """
        auth = Auth() if auth is None else auth

        if token is not None:
            warn(
                "setting token at service level will be removed in future",
                DeprecationWarning,
            )
            auth._token = token

        self.auth = auth
        self.base_url = url

        self._adapter = (
            Service.ADAPTER
            if retries is None
            else ThreadLocalWrapper(lambda: HTTPAdapter(max_retries=retries))
        )

        self._session_class = WrappedSession if session_class is None else session_class

        # The SSL connection pool beneath a session cannot be shared across
        # threads or processes, so sessions are kept thread-local; this avoids
        # obscure failures when a client is naively shared (e.g. with
        # multiprocessing).
        self._session = ThreadLocalWrapper(self.build_session)
Esempio n. 3
0
class ThirdPartyService(object):
    """Base HTTP client for third-party endpoints.

    Sessions are built per thread by ``build_session`` and carry no
    ``Authorization`` header.
    """

    CONNECT_TIMEOUT = 9.5
    READ_TIMEOUT = 30
    # (connect, read) timeout tuple handed to WrappedSession.
    TIMEOUT = (CONNECT_TIMEOUT, READ_TIMEOUT)

    # urllib3 1.26 renamed ``Retry(method_whitelist=...)`` to ``allowed_methods``
    # and urllib3 2.0 removed the old keyword entirely. ``DEFAULT_ALLOWED_METHODS``
    # was introduced together with the new name, so its presence tells us which
    # keyword the installed urllib3 accepts.
    _RETRY_METHODS_KWARG = (
        "allowed_methods"
        if hasattr(Retry, "DEFAULT_ALLOWED_METHODS")
        else "method_whitelist"
    )

    # Note: the backoff factor is drawn once, when this class body executes.
    RETRY_CONFIG = Retry(
        total=10,
        read=2,
        backoff_factor=random.uniform(1, 3),
        status_forcelist=[
            HttpStatusCode.TooManyRequests,
            HttpStatusCode.InternalServerError,
            HttpStatusCode.BadGateway,
            HttpStatusCode.ServiceUnavailable,
            HttpStatusCode.GatewayTimeout,
        ],
        **{
            _RETRY_METHODS_KWARG: frozenset(
                [
                    HttpRequestMethod.HEAD,
                    HttpRequestMethod.TRACE,
                    HttpRequestMethod.GET,
                    HttpRequestMethod.POST,
                    HttpRequestMethod.PUT,
                    HttpRequestMethod.OPTIONS,
                    HttpRequestMethod.DELETE,
                ]
            )
        }
    )

    # One retrying adapter per thread/process, shared by all instances.
    ADAPTER = ThreadLocalWrapper(
        lambda: HTTPAdapter(max_retries=ThirdPartyService.RETRY_CONFIG)
    )

    def __init__(self, url=""):
        """:param str url: Base URL the per-thread sessions are created for."""
        self.base_url = url
        self._session = ThreadLocalWrapper(self.build_session)

    @property
    def session(self):
        """The session belonging to the calling thread."""
        return self._session.get()

    def build_session(self):
        """Create a session for ``base_url`` with retries on HTTPS URLs.

        :return: a ``WrappedSession`` sending binary payloads and the
            client User-Agent.
        """
        s = WrappedSession(self.base_url, timeout=self.TIMEOUT)
        s.mount(HttpMountProtocol.HTTPS, self.ADAPTER.get())

        s.headers.update(
            {
                HttpHeaderKeys.ContentType: HttpHeaderValues.ApplicationOctetStream,
                HttpHeaderKeys.UserAgent: "{}/{}".format(
                    HttpHeaderValues.DlPython, __version__
                ),
            }
        )

        return s
Esempio n. 4
0
    def __init__(self, url="", session_class=None):
        """Set the base URL and, optionally, a custom session class.

        :param str url: Base URL for the per-thread sessions.
        :param session_class: Optional ``Session`` subclass to build sessions
            with. When omitted, ``self._session_class`` is not assigned here
            (presumably a class-level default exists — not visible in this view).
        :raises TypeError: if ``session_class`` is given but is not a subclass
            of ``Session`` (including when it is not a class at all).
        """
        self.base_url = url

        if session_class is not None:
            # issubclass() raises its own, less helpful TypeError when handed a
            # non-class (e.g. a Session *instance*), so check for a class first.
            is_session_subclass = (isinstance(session_class, type)
                                   and issubclass(session_class, Session))
            if not is_session_subclass:
                raise TypeError(
                    "The session class must be a subclass of {}.".format(
                        Session))

            self._session_class = session_class

        self._session = ThreadLocalWrapper(self._build_session)
    def __init__(self, url, token=None, auth=None):
        """Prepare the client for ``url``.

        :param str url: Base URL that sessions are built for.
        :param token: Deprecated; when given it is stored as ``auth._token``.
        :param auth: ``Auth`` instance; a new ``Auth()`` is created when omitted.
        """
        auth = Auth() if auth is None else auth

        if token is not None:
            warn("setting token at service level will be removed in future",
                 DeprecationWarning)
            auth._token = token

        self.auth, self.base_url = auth, url

        # The SSL connection pool beneath a session cannot be shared across
        # threads or processes, so sessions are kept thread-local; this avoids
        # obscure failures when a client is naively shared (e.g. with
        # multiprocessing).
        self._session = ThreadLocalWrapper(self.build_session)
Esempio n. 6
0
class ThreadLocalWrapperTest(unittest.TestCase):
    """Checks that ThreadLocalWrapper hands each thread and each process its own value."""

    # Seconds to wait for a child process to report back. Without a bound, a
    # child that dies before calling ``queue.put`` would hang the whole run;
    # with it, ``queue.get`` raises Empty and the test fails instead.
    QUEUE_TIMEOUT = 30

    def setUp(self):
        # The factory records the (pid, thread ident) it ran in, so equal
        # results mean the wrapper reused a value across contexts.
        self.wrapper = ThreadLocalWrapper(
            lambda: (os.getpid(), threading.current_thread().ident))

    def _store_id(self):
        self.thread_id = self.wrapper.get()

    def _send_id(self, queue):
        queue.put(self.wrapper.get())

    def test_wrapper(self):
        main_thread_id = self.wrapper.get()
        self.assertEqual(main_thread_id, self.wrapper.get())

        thread = threading.Thread(target=self._store_id)
        thread.start()
        thread.join()
        self.assertNotEqual(main_thread_id, self.thread_id)

        # Without fork (Windows), multiprocessing must pickle the target, and
        # ThreadLocalWrapper can't be pickled. The fork problem the wrapper
        # solves doesn't exist there either.
        # NOTE(review): macOS also defaults to spawn on Python 3.8+ — confirm.
        if os.name == "nt":
            self.skipTest("forking not a concern on Windows")

        queue = multiprocessing.Queue()
        process = multiprocessing.Process(target=self._send_id, args=(queue, ))
        process.start()
        process_id = queue.get(timeout=self.QUEUE_TIMEOUT)
        process.join()
        self.assertNotEqual(main_thread_id, process_id)
        self.assertNotEqual(self.thread_id, process_id)

    @unittest.skipIf(os.name == "nt", "forking not a concern on Windows")
    def test_wrapper_unused_in_main_process(self):
        queue = multiprocessing.Queue()
        process = multiprocessing.Process(target=self._send_id, args=(queue, ))
        process.start()
        process_id = queue.get(timeout=self.QUEUE_TIMEOUT)
        process.join()
        self.assertNotEqual(process_id, self.wrapper.get())

    @unittest.skipIf(os.name == "nt", "forking not a concern on Windows")
    def test_fork_from_fork(self):
        # A gross edge case discovered by Clark: if a process is forked from a forked process
        # things will go awry if we hadn't initialized the internal threading.local's pid.
        def fork_another(queue):
            queue.put(self.wrapper.get())
            process3 = multiprocessing.Process(target=self._send_id,
                                               args=(queue, ))
            process3.start()
            process3.join()

        process1_id = self.wrapper.get()
        queue = multiprocessing.Queue()
        process = multiprocessing.Process(target=fork_another, args=(queue, ))
        process.start()
        process2_id = queue.get(timeout=self.QUEUE_TIMEOUT)
        process3_id = queue.get(timeout=self.QUEUE_TIMEOUT)
        process.join()
        self.assertNotEqual(process1_id, process2_id)
        self.assertNotEqual(process2_id, process3_id)
        self.assertNotEqual(process1_id, process3_id)
class Service(object):
    """Base client for authenticated JSON HTTP services.

    Every session returned by ``session`` carries the current token from
    ``auth`` in its ``Authorization`` header.
    """

    # (connect timeout, read timeout) in seconds, passed to WrappedSession.
    TIMEOUT = (9.5, 30)

    # urllib3 1.26 renamed ``Retry(method_whitelist=...)`` to ``allowed_methods``
    # and urllib3 2.0 removed the old keyword entirely. ``DEFAULT_ALLOWED_METHODS``
    # was introduced together with the new name, so its presence tells us which
    # keyword the installed urllib3 accepts.
    _RETRY_METHODS_KWARG = (
        "allowed_methods"
        if hasattr(Retry, "DEFAULT_ALLOWED_METHODS")
        else "method_whitelist"
    )

    # Note: the backoff factor is drawn once, when this class body executes.
    RETRY_CONFIG = Retry(total=3,
                         connect=2,
                         read=2,
                         status=2,
                         backoff_factor=random.uniform(1, 3),
                         status_forcelist=[500, 502, 503, 504],
                         **{
                             _RETRY_METHODS_KWARG: frozenset([
                                 'HEAD', 'TRACE', 'GET', 'POST', 'PUT',
                                 'OPTIONS', 'DELETE'
                             ])
                         })

    # We share an adapter (one per thread/process) among all clients to take advantage
    # of the single underlying connection pool.
    ADAPTER = ThreadLocalWrapper(
        lambda: HTTPAdapter(max_retries=Service.RETRY_CONFIG))

    def __init__(self, url, token=None, auth=None):
        """Prepare the client for ``url``.

        :param str url: Base URL that sessions are built for.
        :param token: Deprecated; when given it is stored as ``auth._token``.
        :param auth: ``Auth`` instance; a new ``Auth()`` is created when omitted.
        """
        if auth is None:
            auth = Auth()

        if token is not None:
            warn("setting token at service level will be removed in future",
                 DeprecationWarning)
            auth._token = token

        self.auth = auth

        self.base_url = url

        # Sessions can't be shared across threads or processes because the underlying
        # SSL connection pool can't be shared. We create them thread-local to avoid
        # intractable exceptions when users naively share clients e.g. when using
        # multiprocessing.
        self._session = ThreadLocalWrapper(self.build_session)

    @property
    def token(self):
        """The current token, as returned by ``self.auth.token``."""
        return self.auth.token

    @token.setter
    def token(self, token):
        self.auth._token = token

    @property
    def session(self):
        """The calling thread's session, with ``Authorization`` set to the current token."""
        session = self._session.get()
        # Read the token once: every ``self.token`` access goes through
        # ``auth.token``, which may refresh it, so comparing one read and
        # storing another could disagree.
        token = self.token
        if session.headers.get('Authorization') != token:
            session.headers['Authorization'] = token

        return session

    def build_session(self):
        """Create a JSON session for ``base_url`` with retries on ``https://`` URLs."""
        s = WrappedSession(self.base_url, timeout=self.TIMEOUT)
        s.mount('https://', self.ADAPTER.get())

        s.headers.update({
            "Content-Type": "application/json",
            "User-Agent": "dl-python/{}".format(__version__),
        })

        # Diagnostic headers are best-effort: a failure while introspecting
        # the platform must never prevent building a session.
        try:
            s.headers.update({
                # https://github.com/easybuilders/easybuild/wiki/OS_flavor_name_version
                "X-Platform":
                platform.platform(),
                "X-Python":
                platform.python_version(),
                # https://stackoverflow.com/questions/47608532/how-to-detect-from-within-python-whether-packages-are-managed-with-conda
                "X-Conda":
                str(
                    os.path.exists(
                        os.path.join(sys.prefix, 'conda-meta', 'history'))),
                # https://stackoverflow.com/questions/15411967/how-can-i-check-if-code-is-executed-in-the-ipython-notebook
                "X-Notebook":
                str('ipykernel' in sys.modules),
            })
        except Exception:
            pass

        return s
    def __init__(self, url=''):
        """Remember ``url`` as the base URL and set up per-thread session storage.

        :param str url: Base URL passed to each session built by ``build_session``.
        """
        self.base_url = url
        # Each thread gets its own session, produced by ``build_session``.
        self._session = ThreadLocalWrapper(self.build_session)
Esempio n. 9
0
class Auth:
    """Obtains and caches a JWT for API access.

    Credentials are resolved from constructor arguments first, then
    environment variables, then the JSON file at ``token_info_path``; newly
    obtained tokens are written back to that file.
    """

    # urllib3 1.26 renamed ``Retry(method_whitelist=...)`` to ``allowed_methods``
    # and urllib3 2.0 removed the old keyword entirely. ``DEFAULT_ALLOWED_METHODS``
    # was introduced together with the new name, so its presence tells us which
    # keyword the installed urllib3 accepts.
    _RETRY_METHODS_KWARG = (
        "allowed_methods"
        if hasattr(Retry, "DEFAULT_ALLOWED_METHODS")
        else "method_whitelist"
    )

    # Note: the backoff factor is drawn once, when this class body executes.
    RETRY_CONFIG = Retry(total=5,
                         backoff_factor=random.uniform(1, 10),
                         status_forcelist=[429, 500, 502, 503, 504],
                         **{_RETRY_METHODS_KWARG: frozenset(['GET', 'POST'])})

    # One retrying adapter per thread/process, shared by all Auth instances.
    ADAPTER = ThreadLocalWrapper(lambda: HTTPAdapter(max_retries=Auth.RETRY_CONFIG))

    def __init__(self, domain="https://accounts.descarteslabs.com",
                 scope=None, leeway=500, token_info_path=DEFAULT_TOKEN_INFO_PATH,
                 client_id=None, client_secret=None, jwt_token=None, refresh_token=None):
        """
        Helps retrieve JWT from a client id and refresh token for cli usage.
        :param domain: endpoint for auth0
        :param scope: the JWT fields to be included
        :param leeway: JWT expiration leeway
        :param token_info_path: path to a JSON file optionally holding auth information
        :param client_id: JWT client id
        :param client_secret: JWT client secret
        :param jwt_token: the JWT token, if we already have one
        :param refresh_token: the refresh token
        """
        self.token_info_path = token_info_path

        # A missing or unparsable token file is treated as empty.
        token_info = {}
        if self.token_info_path:
            try:
                with open(self.token_info_path) as fp:
                    token_info = json.load(fp)
            except (IOError, ValueError):
                pass

        # Each setting takes the first non-None value: argument, then
        # environment variable(s), then the token file.
        self.client_id = next(
            (x for x in (
                client_id,
                os.environ.get('DESCARTESLABS_CLIENT_ID'),
                os.environ.get('CLIENT_ID'),
                token_info.get('client_id')
            ) if x is not None), None)

        self.client_secret = next(
            (x for x in (
                client_secret,
                os.environ.get('DESCARTESLABS_CLIENT_SECRET'),
                os.environ.get('CLIENT_SECRET'),
                token_info.get('client_secret')
            ) if x is not None), None)

        self.refresh_token = next(
            (x for x in (
                refresh_token,
                os.environ.get('DESCARTESLABS_REFRESH_TOKEN'),
                token_info.get('refresh_token')
            ) if x is not None), None)

        self._token = next(
            (x for x in (
                jwt_token,
                os.environ.get('DESCARTESLABS_TOKEN'),
                token_info.get('JWT_TOKEN'),
                token_info.get('jwt_token')
            ) if x is not None), None)

        self.scope = next(
            (x for x in (
                scope,
                token_info.get('scope')
            ) if x is not None), None)

        if token_info:
            # If the token was read from a path but environment variables were set, we may need
            # to reset the token.
            client_id_changed = token_info.get('client_id', None) != self.client_id
            client_secret_changed = token_info.get('client_secret', None) != self.client_secret
            refresh_token_changed = token_info.get('refresh_token', None) != self.refresh_token

            if client_id_changed or client_secret_changed or refresh_token_changed:
                self._token = None

        self._namespace = None
        self._session = ThreadLocalWrapper(self.build_session)
        self.domain = domain
        self.leeway = leeway

    @classmethod
    def from_environment_or_token_json(cls, **kwargs):
        """
        Creates an Auth object from environment variables CLIENT_ID, CLIENT_SECRET,
        JWT_TOKEN if they are set, or else from a JSON file at the given path.
        :param domain: endpoint for auth0
        :param scope: the JWT fields to be included
        :param leeway: JWT expiration leeway
        :param token_info_path: path to a JSON file optionally holding auth information
        """
        return Auth(**kwargs)

    @property
    def token(self):
        """The JWT, fetched on first use and refreshed near expiration.

        A refresh is attempted once the token is within ``leeway`` seconds of
        its ``exp`` claim; a failed refresh is only fatal once the token has
        actually expired.
        """
        if self._token is None:
            self._get_token()
        else:  # might have token but could be close to expiration
            exp = self.payload.get('exp')

            if exp is not None:
                now = (datetime.datetime.utcnow() - datetime.datetime(1970, 1, 1)).total_seconds()
                if now + self.leeway > exp:
                    try:
                        self._get_token()
                    except AuthError:
                        # Unable to refresh, raise if now > exp
                        if now > exp:
                            raise  # bare raise keeps the original traceback

        return self._token

    @property
    def payload(self):
        """The decoded (unverified) claims segment of the JWT, as a dict."""
        if self._token is None:
            self._get_token()

        if isinstance(self._token, six.text_type):
            token = self._token.encode('utf-8')
        else:
            token = self._token

        # The claims are the second '.'-separated segment of the JWT.
        claims = token.split(b'.')[1]
        return json.loads(base64url_decode(claims).decode('utf-8'))

    @property
    def session(self):
        """The calling thread's session for talking to the auth server."""
        return self._session.get()

    def build_session(self):
        session = requests.Session()
        session.mount('https://', self.ADAPTER.get())
        return session

    def _get_token(self, timeout=100):
        """Fetch a new token from ``domain + "/token"`` and persist it.

        :raises AuthError: if the client id, or both secret and refresh
            token, are missing.
        :raises OauthError: on a non-200 response or when the response holds
            neither ``access_token`` nor ``id_token``.
        """
        if self.client_id is None:
            raise AuthError("Could not find client_id")

        if self.client_secret is None and self.refresh_token is None:
            raise AuthError("Could not find client_secret or refresh token")

        if self.client_id in ["ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c"]:  # TODO(justin) remove legacy handling
            # TODO (justin) insert deprecation warning
            if self.scope is None:
                scope = ['openid', 'name', 'groups', 'org', 'email']
            else:
                scope = self.scope
            params = {
                "scope": " ".join(scope),
                "client_id": self.client_id,
                "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
                "target": self.client_id,
                "api_type": "app",
                "refresh_token": self.refresh_token if self.refresh_token is not None else self.client_secret
            }
        else:
            params = {
                "client_id": self.client_id,
                "grant_type": "refresh_token",
                "refresh_token": self.refresh_token if self.refresh_token is not None else self.client_secret
            }

            if self.scope is not None:
                params["scope"] = " ".join(self.scope)

        r = self.session.post(self.domain + "/token", json=params, timeout=timeout)

        if r.status_code != 200:
            raise OauthError("%s: %s" % (r.status_code, r.text))

        data = r.json()
        access_token = data.get('access_token')
        id_token = data.get('id_token')  # TODO(justin) remove legacy id_token usage

        if access_token is not None:
            self._token = access_token
        elif id_token is not None:
            self._token = id_token
        else:
            raise OauthError("could not retrieve token")

        # Merge the new token into the existing file contents (if readable).
        token_info = {}

        if self.token_info_path:
            try:
                with open(self.token_info_path) as fp:
                    token_info = json.load(fp)
            except (IOError, ValueError):
                pass

        token_info['jwt_token'] = self._token

        if self.token_info_path:
            token_info_directory = os.path.dirname(self.token_info_path)
            makedirs_if_not_exists(token_info_directory)

            owner_rw = stat.S_IRUSR | stat.S_IWUSR
            try:
                # Create the file owner-only and tighten an existing file's
                # mode *before* writing, so the token is never on disk with
                # broader permissions. OSError is caught too because os.*
                # calls raise it (distinct from IOError on Python 2).
                fd = os.open(self.token_info_path,
                             os.O_WRONLY | os.O_CREAT | os.O_TRUNC, owner_rw)
                with os.fdopen(fd, 'w') as fp:
                    os.chmod(self.token_info_path, owner_rw)
                    json.dump(token_info, fp)
            except (IOError, OSError) as e:
                warnings.warn('failed to save token: {}'.format(e))

    @property
    def namespace(self):
        """SHA-1 hex digest of the token's ``sub`` claim, computed once and cached."""
        if self._namespace is None:
            self._namespace = sha1(self.payload['sub'].encode('utf-8')).hexdigest()
        return self._namespace
Esempio n. 10
0
 def _init_session(self):
     """Install per-thread session storage backed by ``self._build_session``.

     The SSL connection pool under a session cannot be shared across threads
     or processes, so each thread gets its own session. This avoids obscure
     errors when users naively share a client (e.g. via multiprocessing).
     """
     builder = self._build_session
     self._session = ThreadLocalWrapper(builder)
Esempio n. 11
0
class Auth:
    """
    Authentication client used to authenticate with all Descartes Labs service APIs.
    """

    # urllib3 1.26 renamed ``Retry(method_whitelist=...)`` to ``allowed_methods``
    # and urllib3 2.0 removed the old keyword entirely. ``DEFAULT_ALLOWED_METHODS``
    # was introduced together with the new name, so its presence tells us which
    # keyword the installed urllib3 accepts.
    _RETRY_METHODS_KWARG = (
        "allowed_methods"
        if hasattr(Retry, "DEFAULT_ALLOWED_METHODS")
        else "method_whitelist"
    )

    # Note: the backoff factor is drawn once, when this class body executes.
    RETRY_CONFIG = Retry(
        total=5,
        backoff_factor=random.uniform(1, 10),
        status_forcelist=[429, 500, 502, 503, 504],
        **{_RETRY_METHODS_KWARG: frozenset(["GET", "POST"])}
    )

    # One retrying adapter per thread/process, shared by all Auth instances.
    ADAPTER = ThreadLocalWrapper(
        lambda: HTTPAdapter(max_retries=Auth.RETRY_CONFIG))

    def __init__(
        self,
        domain="https://accounts.descarteslabs.com",
        scope=None,
        leeway=500,
        token_info_path=DEFAULT_TOKEN_INFO_PATH,
        client_id=None,
        client_secret=None,
        jwt_token=None,
        refresh_token=None,
    ):
        """
        Helps retrieve JWT from a client id and refresh token for cli usage.

        :param str domain: The endpoint for auth0
        :type scope: list(str) or None
        :param scope: The JWT fields to be included
        :param int leeway: JWT expiration leeway
        :type token_info_path: str or None
        :param token_info_path: Path to a JSON file optionally holding auth information
        :type client_id: str or None
        :param client_id: JWT client id
        :type client_secret: str or None
        :param client_secret: JWT client secret
        :type jwt_token: str or None
        :param jwt_token: The JWT token, if we already have one
        :type refresh_token: str or None
        :param refresh_token: The refresh token
        """
        self.token_info_path = token_info_path

        # A missing or unparsable token file is treated as empty.
        token_info = {}
        if self.token_info_path:
            try:
                with open(self.token_info_path) as fp:
                    token_info = json.load(fp)
            except (IOError, ValueError):
                pass

        # Each setting takes the first non-None value: argument, then
        # environment variable(s), then the token file.
        self.client_id = next(
            (x for x in (
                client_id,
                os.environ.get("DESCARTESLABS_CLIENT_ID"),
                os.environ.get("CLIENT_ID"),
                token_info.get("client_id"),
            ) if x is not None),
            None,
        )

        self.client_secret = next(
            (x for x in (
                client_secret,
                os.environ.get("DESCARTESLABS_CLIENT_SECRET"),
                os.environ.get("CLIENT_SECRET"),
                token_info.get("client_secret"),
            ) if x is not None),
            None,
        )

        self.refresh_token = next(
            (x for x in (
                refresh_token,
                os.environ.get("DESCARTESLABS_REFRESH_TOKEN"),
                token_info.get("refresh_token"),
            ) if x is not None),
            None,
        )

        # client_secret and refresh_token are used interchangeably; make them
        # agree, with a non-None refresh_token taking precedence.
        if self.client_secret != self.refresh_token:
            if self.client_secret is not None and self.refresh_token is not None:
                warnings.warn(
                    "Authentication token mismatch: "
                    "client_secret and refresh_token values must match for authentication to work correctly. "
                )

            if self.refresh_token is not None:
                self.client_secret = self.refresh_token
            elif self.client_secret is not None:
                self.refresh_token = self.client_secret

        self._token = next(
            (x for x in (
                jwt_token,
                os.environ.get("DESCARTESLABS_TOKEN"),
                token_info.get("JWT_TOKEN"),
                token_info.get("jwt_token"),
            ) if x is not None),
            None,
        )

        self.scope = next(
            (x for x in (scope, token_info.get("scope")) if x is not None),
            None)

        if token_info:
            # If the token was read from a path but arguments or environment
            # variables supplied different credentials, the cached token must
            # be discarded. Compare *effective* secrets: apply to the file's
            # values the same precedence used above (refresh_token over
            # client_secret). Comparing the raw fields against the unified
            # values would flag a file storing only one of the two as
            # "changed" and throw away the cached token on every construction.
            file_secret = token_info.get("refresh_token")
            if file_secret is None:
                file_secret = token_info.get("client_secret")

            client_id_changed = token_info.get("client_id") != self.client_id
            secret_changed = file_secret != self.refresh_token

            if client_id_changed or secret_changed:
                self._token = None

        self._namespace = None
        self._session = ThreadLocalWrapper(self.build_session)
        self.domain = domain
        self.leeway = leeway

    @classmethod
    def from_environment_or_token_json(cls, **kwargs):
        """
        Creates an Auth object from environment variables CLIENT_ID,
        CLIENT_SECRET, JWT_TOKEN if they are set, or else from a JSON
        file at the given path.

        :param str domain: The endpoint for auth0
        :type scope: list(str) or None
        :param scope: The JWT fields to be included
        :param int leeway: JWT expiration leeway
        :type token_info_path: str or None
        :param token_info_path: Path to a JSON file optionally holding auth information
        """
        return Auth(**kwargs)

    @property
    def token(self):
        """
        Gets the token.

        A refresh is attempted once the token is within ``leeway`` seconds of
        its ``exp`` claim; a failed refresh is only fatal once the token has
        actually expired.

        :rtype: str
        :return: The JWT token string.

        :raises ~descarteslabs.client.exceptions.AuthError: Raised when
            incomplete information has been provided.
        :raises ~descarteslabs.client.exceptions.OauthError: Raised when
            a token cannot be obtained or refreshed.
        """
        if self._token is None:
            self._get_token()
        else:  # might have token but could be close to expiration
            exp = self.payload.get("exp")

            if exp is not None:
                now = (datetime.datetime.utcnow() -
                       datetime.datetime(1970, 1, 1)).total_seconds()
                if now + self.leeway > exp:
                    try:
                        self._get_token()
                    except AuthError:
                        # Unable to refresh, raise if now > exp
                        if now > exp:
                            raise  # bare raise keeps the original traceback

        return self._token

    @property
    def payload(self):
        """
        Gets the token payload.

        :rtype: dict
        :return: Dictionary containing the fields specified by scope, which may include:

            .. highlight:: none

            ::

                name:           The name of the user.
                groups:         Groups to which the user belongs.
                org:            The organization to which the user belongs.
                email:          The email address of the user.
                email_verified: True if the user's email has been verified.
                sub:            The user identifier.
                exp:            The expiration time of the token, in seconds since
                                the start of the unix epoch.

        :raises ~descarteslabs.client.exceptions.AuthError: Raised when
            incomplete information has been provided.
        :raises ~descarteslabs.client.exceptions.OauthError: Raised when
            a token cannot be obtained or refreshed.
        """
        if self._token is None:
            self._get_token()

        if isinstance(self._token, six.text_type):
            token = self._token.encode("utf-8")
        else:
            token = self._token

        # The claims are the second '.'-separated segment of the JWT.
        claims = token.split(b".")[1]
        return json.loads(base64url_decode(claims).decode("utf-8"))

    @property
    def session(self):
        """
        Gets the request session used to communicate with the OAuth server.

        :rtype: requests.Session
        :return: Session object
        """
        return self._session.get()

    def build_session(self):
        session = requests.Session()
        session.mount("https://", self.ADAPTER.get())
        return session

    def _get_token(self, timeout=100):
        """Fetch a new token from ``domain + "/token"`` and persist it.

        :raises AuthError: if the client id, or both secret and refresh
            token, are missing.
        :raises OauthError: on a non-200 response or when the response holds
            neither ``access_token`` nor ``id_token``.
        """
        if self.client_id is None:
            raise AuthError("Could not find client_id")

        if self.client_secret is None and self.refresh_token is None:
            raise AuthError("Could not find client_secret or refresh token")

        # TODO(justin) remove legacy handling
        if self.client_id in ["ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c"]:
            # TODO (justin) insert deprecation warning
            if self.scope is None:
                scope = ["openid", "name", "groups", "org", "email"]
            else:
                scope = self.scope
            params = {
                "scope": " ".join(scope),
                "client_id": self.client_id,
                "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
                "target": self.client_id,
                "api_type": "app",
                "refresh_token": self.refresh_token,
            }
        else:
            params = {
                "client_id": self.client_id,
                "grant_type": "refresh_token",
                "refresh_token": self.refresh_token,
            }

            if self.scope is not None:
                params["scope"] = " ".join(self.scope)

        r = self.session.post(self.domain + "/token",
                              json=params,
                              timeout=timeout)

        if r.status_code != 200:
            raise OauthError("%s: %s" % (r.status_code, r.text))

        data = r.json()
        access_token = data.get("access_token")
        # TODO(justin) remove legacy id_token usage
        id_token = data.get("id_token")

        if access_token is not None:
            self._token = access_token
        elif id_token is not None:
            self._token = id_token
        else:
            raise OauthError("could not retrieve token")

        # Merge the new token into the existing file contents (if readable).
        token_info = {}

        if self.token_info_path:
            try:
                with open(self.token_info_path) as fp:
                    token_info = json.load(fp)
            except (IOError, ValueError):
                pass

        token_info["jwt_token"] = self._token

        if self.token_info_path:
            token_info_directory = os.path.dirname(self.token_info_path)
            makedirs_if_not_exists(token_info_directory)

            owner_rw = stat.S_IRUSR | stat.S_IWUSR
            try:
                # Create the file owner-only and tighten an existing file's
                # mode *before* writing, so the token is never on disk with
                # broader permissions. OSError is caught too because os.*
                # calls raise it (distinct from IOError on Python 2).
                fd = os.open(self.token_info_path,
                             os.O_WRONLY | os.O_CREAT | os.O_TRUNC, owner_rw)
                with os.fdopen(fd, "w") as fp:
                    os.chmod(self.token_info_path, owner_rw)
                    json.dump(token_info, fp)
            except (IOError, OSError) as e:
                warnings.warn("failed to save token: {}".format(e))

    @property
    def namespace(self):
        """
        Gets the user namespace.

        :rtype: str
        :return: The user namespace

        :raises ~descarteslabs.client.exceptions.AuthError: Raised when
            incomplete information has been provided.
        :raises ~descarteslabs.client.exceptions.OauthError: Raised when
            a token cannot be obtained or refreshed.
        """
        if self._namespace is None:
            self._namespace = sha1(
                self.payload["sub"].encode("utf-8")).hexdigest()
        return self._namespace
Esempio n. 12
0
class Service(object):
    """Base client for authenticated JSON HTTP services.

    Every session returned by ``session`` carries the current bearer token
    from ``auth`` in its ``Authorization`` header.
    """

    # https://requests.readthedocs.io/en/master/user/advanced/#timeouts
    CONNECT_TIMEOUT = 9.5
    READ_TIMEOUT = 30

    TIMEOUT = (CONNECT_TIMEOUT, READ_TIMEOUT)

    # urllib3 1.26 renamed ``Retry(method_whitelist=...)`` to ``allowed_methods``
    # and urllib3 2.0 removed the old keyword entirely. ``DEFAULT_ALLOWED_METHODS``
    # was introduced together with the new name, so its presence tells us which
    # keyword the installed urllib3 accepts.
    _RETRY_METHODS_KWARG = (
        "allowed_methods"
        if hasattr(Retry, "DEFAULT_ALLOWED_METHODS")
        else "method_whitelist"
    )

    # Note: the backoff factor is drawn once, when this class body executes.
    RETRY_CONFIG = Retry(
        total=3,
        connect=2,
        read=2,
        status=2,
        backoff_factor=random.uniform(1, 3),
        status_forcelist=[
            HttpStatusCode.InternalServerError,
            HttpStatusCode.BadGateway,
            HttpStatusCode.ServiceUnavailable,
            HttpStatusCode.GatewayTimeout,
        ],
        **{
            _RETRY_METHODS_KWARG: frozenset(
                [
                    HttpRequestMethod.HEAD,
                    HttpRequestMethod.TRACE,
                    HttpRequestMethod.GET,
                    HttpRequestMethod.POST,
                    HttpRequestMethod.PUT,
                    HttpRequestMethod.PATCH,
                    HttpRequestMethod.OPTIONS,
                    HttpRequestMethod.DELETE,
                ]
            )
        }
    )

    # We share an adapter (one per thread/process) among all clients to take advantage
    # of the single underlying connection pool.
    ADAPTER = ThreadLocalWrapper(lambda: HTTPAdapter(max_retries=Service.RETRY_CONFIG))

    def __init__(self, url, token=None, auth=None, retries=None, session_class=None):
        """Configure the client for ``url``.

        :param str url: Base URL that sessions are created against.
        :param token: Deprecated; when given it is written into ``auth._token``.
        :param auth: ``Auth`` instance; a fresh ``Auth()`` is built when omitted.
        :param retries: Optional retry policy for a dedicated per-thread
            ``HTTPAdapter``; when omitted the shared ``Service.ADAPTER`` is used.
        :param session_class: Session type instantiated by ``build_session``;
            defaults to ``WrappedSession``.
        """
        if auth is None:
            auth = Auth()

        if token is not None:
            warn(
                "setting token at service level will be removed in future",
                DeprecationWarning,
            )
            auth._token = token

        self.auth = auth
        self.base_url = url

        if retries is None:
            self._adapter = Service.ADAPTER
        else:
            self._adapter = ThreadLocalWrapper(lambda: HTTPAdapter(max_retries=retries))

        if session_class is None:
            self._session_class = WrappedSession
        else:
            self._session_class = session_class

        # Sessions can't be shared across threads or processes because the underlying
        # SSL connection pool can't be shared. We create them thread-local to avoid
        # intractable exceptions when users naively share clients e.g. when using
        # multiprocessing.
        self._session = ThreadLocalWrapper(self.build_session)

    @property
    def token(self):
        """The current token, as returned by ``self.auth.token``."""
        return self.auth.token

    @token.setter
    def token(self, token):
        self.auth._token = token

    @property
    def session(self):
        """The calling thread's session, with ``Authorization`` set from the current token."""
        session = self._session.get()
        auth = add_bearer(self.token)
        if session.headers.get(HttpHeaderKeys.Authorization) != auth:
            session.headers[HttpHeaderKeys.Authorization] = auth

        return session

    def build_session(self):
        """Create a JSON session for ``base_url`` using the configured adapter.

        The same adapter handles both HTTPS and HTTP URLs.
        """
        s = self._session_class(self.base_url, timeout=self.TIMEOUT)
        adapter = self._adapter.get()
        s.mount(HttpMountProtocol.HTTPS, adapter)
        s.mount(HttpMountProtocol.HTTP, adapter)

        s.headers.update(
            {
                HttpHeaderKeys.ContentType: HttpHeaderValues.ApplicationJson,
                HttpHeaderKeys.UserAgent: "{}/{}".format(
                    HttpHeaderValues.DlPython, __version__
                ),
            }
        )

        # Diagnostic headers are best-effort: a failure while introspecting
        # the platform must never prevent building a session.
        try:
            s.headers.update(
                {
                    # https://github.com/easybuilders/easybuild/wiki/OS_flavor_name_version
                    HttpHeaderKeys.Platform: platform.platform(),
                    HttpHeaderKeys.Python: platform.python_version(),
                    # https://stackoverflow.com/questions/47608532/how-to-detect-from-within-python-whether-packages-are-managed-with-conda
                    HttpHeaderKeys.Conda: str(
                        os.path.exists(
                            os.path.join(sys.prefix, "conda-meta", "history")
                        )
                    ),
                    # https://stackoverflow.com/questions/15411967/how-can-i-check-if-code-is-executed-in-the-ipython-notebook
                    HttpHeaderKeys.Notebook: str("ipykernel" in sys.modules),
                    HttpHeaderKeys.ClientSession: uuid.uuid4().hex,
                }
            )
        except Exception:
            pass

        return s
Esempio n. 13
0
class ThreadLocalWrapperTest(unittest.TestCase):
    """Checks that ThreadLocalWrapper yields distinct values per thread and per process."""

    # Upper bound (seconds) on waiting for a child process to report back, so a child
    # that dies before writing to the queue fails the test instead of hanging it.
    QUEUE_TIMEOUT = 30

    @staticmethod
    def _fork_context():
        """Return a multiprocessing context whose child processes are forked.

        The tests below pass ``self`` (which holds a non-picklable ThreadLocalWrapper)
        and a local function to child processes, which only works with ``fork``.
        ``fork`` is no longer the default start method on macOS (Python 3.8+) or
        Linux (Python 3.14+), so request it explicitly.  Python 2 has no
        ``get_context`` but always forks on POSIX.
        """
        get_context = getattr(multiprocessing, "get_context", None)
        if get_context is None:
            return multiprocessing
        return get_context("fork")

    def setUp(self):
        # Each call of the factory records which process and thread produced it.
        self.wrapper = ThreadLocalWrapper(
            lambda: (os.getpid(), threading.current_thread().ident))

    def _store_id(self):
        self.thread_id = self.wrapper.get()

    def _send_id(self, queue):
        queue.put(self.wrapper.get())

    def test_thread_thread(self):
        main_thread_id = self.wrapper.get()
        assert main_thread_id == self.wrapper.get()

        thread = threading.Thread(target=self._store_id)
        thread.start()
        thread.join()
        assert main_thread_id != self.thread_id

    # Note on Windows: fork is not available so multiprocessing pickles the multiprocessing
    # function and arguments. ThreadLocalWrapper isn't picklable, so the following tests
    # can't work on Windows. But the problem it solves for multiprocessing also doesn't
    # exist there.

    @unittest.skipIf(sys.platform.startswith("win"),
                     "forking not a concern on Windows")
    def test_wrapper_process(self):
        ctx = self._fork_context()
        main_thread_id = self.wrapper.get()
        thread = threading.Thread(target=self._store_id)
        thread.start()
        thread.join()
        assert main_thread_id != self.thread_id

        queue = ctx.Queue()
        process = ctx.Process(target=self._send_id, args=(queue, ))
        process.start()
        # Read before join so the child is never blocked on a full pipe.
        process_id = queue.get(timeout=self.QUEUE_TIMEOUT)
        process.join()
        assert main_thread_id != process_id
        assert self.thread_id != process_id

    @unittest.skipIf(sys.platform.startswith("win"),
                     "forking not a concern on Windows")
    def test_wrapper_unused_in_main_process(self):
        ctx = self._fork_context()
        queue = ctx.Queue()
        process = ctx.Process(target=self._send_id, args=(queue, ))
        process.start()
        process_id = queue.get(timeout=self.QUEUE_TIMEOUT)
        process.join()
        assert process_id != self.wrapper.get()

    @unittest.skipIf(sys.platform.startswith("win"),
                     "forking not a concern on Windows")
    def test_fork_from_fork(self):
        # A gross edge case discovered by Clark: if a process is forked from a forked process
        # things will go awry if we hadn't initialized the internal threading.local's pid.
        ctx = self._fork_context()

        def fork_another(queue):
            queue.put(self.wrapper.get())
            process3 = ctx.Process(target=self._send_id, args=(queue, ))
            process3.start()
            process3.join()

        process1_id = self.wrapper.get()
        queue = ctx.Queue()
        process = ctx.Process(target=fork_another, args=(queue, ))
        process.start()
        process2_id = queue.get(timeout=self.QUEUE_TIMEOUT)
        process3_id = queue.get(timeout=self.QUEUE_TIMEOUT)
        process.join()
        assert process1_id != process2_id
        assert process2_id != process3_id
        assert process1_id != process3_id
Esempio n. 14
0
 def setUp(self):
     def current_identity():
         # Factory handed to ThreadLocalWrapper: records the pid and thread ident
         # of whatever context invokes it.
         return (os.getpid(), threading.current_thread().ident)

     self.wrapper = ThreadLocalWrapper(current_identity)
Esempio n. 15
0
    def __init__(self, domain="https://accounts.descarteslabs.com",
                 scope=None, leeway=500, token_info_path=DEFAULT_TOKEN_INFO_PATH,
                 client_id=None, client_secret=None, jwt_token=None, refresh_token=None):
        """
        Helps retrieve JWT from a client id and refresh token for cli usage.

        Each credential is taken from the first of these sources that is not None:
        the explicit argument, then the environment variables noted below, then the
        JSON file at ``token_info_path``.

        :param domain: endpoint for auth0
        :param scope: the JWT fields to be included (falls back to the file's ``scope``)
        :param leeway: JWT expiration leeway
        :param token_info_path: path to a JSON file optionally holding auth information;
            a missing or unreadable file, invalid JSON, or JSON that is not an object
            is treated as empty
        :param client_id: JWT client id (falls back to ``DESCARTESLABS_CLIENT_ID``,
            then ``CLIENT_ID``)
        :param client_secret: JWT client secret (falls back to
            ``DESCARTESLABS_CLIENT_SECRET``, then ``CLIENT_SECRET``)
        :param jwt_token: the JWT token, if we already have one (falls back to
            ``DESCARTESLABS_TOKEN``)
        :param refresh_token: the refresh token (falls back to
            ``DESCARTESLABS_REFRESH_TOKEN``)
        """
        self.token_info_path = token_info_path

        token_info = {}
        if self.token_info_path:
            try:
                with open(self.token_info_path) as fp:
                    token_info = json.load(fp)
            except (IOError, ValueError):
                pass

        # Valid JSON that is not an object (``null``, a list, a number...) has no
        # ``.get``; ignore it like a malformed file instead of crashing below.
        if not isinstance(token_info, dict):
            token_info = {}

        def first_not_none(*candidates):
            # Every candidate is evaluated; the first one that is set wins.
            return next((x for x in candidates if x is not None), None)

        self.client_id = first_not_none(
            client_id,
            os.environ.get('DESCARTESLABS_CLIENT_ID'),
            os.environ.get('CLIENT_ID'),
            token_info.get('client_id'),
        )

        self.client_secret = first_not_none(
            client_secret,
            os.environ.get('DESCARTESLABS_CLIENT_SECRET'),
            os.environ.get('CLIENT_SECRET'),
            token_info.get('client_secret'),
        )

        self.refresh_token = first_not_none(
            refresh_token,
            os.environ.get('DESCARTESLABS_REFRESH_TOKEN'),
            token_info.get('refresh_token'),
        )

        self._token = first_not_none(
            jwt_token,
            os.environ.get('DESCARTESLABS_TOKEN'),
            token_info.get('JWT_TOKEN'),
            token_info.get('jwt_token'),
        )

        self.scope = first_not_none(scope, token_info.get('scope'))

        if token_info:
            # If the token was read from a path but environment variables were set, we may need
            # to reset the token.
            client_id_changed = token_info.get('client_id', None) != self.client_id
            client_secret_changed = token_info.get('client_secret', None) != self.client_secret
            refresh_token_changed = token_info.get('refresh_token', None) != self.refresh_token

            if client_id_changed or client_secret_changed or refresh_token_changed:
                self._token = None

        self._namespace = None
        self._session = ThreadLocalWrapper(self.build_session)
        self.domain = domain
        self.leeway = leeway
Esempio n. 16
0
class Service(object):
    """The default Descartes Labs HTTP Service used to communicate with its servers.

    This service has a default timeout and retry policy that retries HTTP requests
    depending on the timeout and HTTP status code that was returned.  This is based
    on the `requests timeouts
    <https://requests.readthedocs.io/en/master/user/advanced/#timeouts>`_
    and the `urllib3 retry object
    <https://urllib3.readthedocs.io/en/latest/reference/urllib3.util.html#urllib3.util.retry.Retry>`_.

    The default timeouts are set to 9.5 seconds for establishing a connection (slightly
    larger than a multiple of 3, which is the TCP default packet retransmission window),
    and 30 seconds for reading a response.

    The default retry logic retries up to 3 times total, a maximum of 2 for establishing
    a connection, 2 for reading a response, and 2 for unexpected HTTP status codes.
    The backoff_factor is a random number between 1 and 3, but will never be more
    than 2 minutes.  The unexpected HTTP status codes that will be retried are ``500``,
    ``502``, ``503``, and ``504`` for any of the HTTP requests.

    Parameters
    ----------
    url: str
        The URL prefix to use for communication with the Descartes Labs server.
    token: str, optional
        Deprecated.
    auth: Auth, optional
        A Descartes Labs :py:class:`~descarteslabs.client.auth.Auth` instance.  If not
        provided, a default one will be instantiated.
    retries: int or urllib3.util.retry.Retry
        If a number, it's the number of retries that will be attempted.  If a
        :py:class:`urllib3.util.retry.Retry` instance, it will determine the retry
        behavior.  If not provided, the default retry policy as described above will
        be used.
    session_class: class
        The session class to use when instantiating the session.  This must be a derived
        class from :py:class:`Session`.  If not provided, the default session class
        is used.  You can register a default session class with
        :py:meth:`Service.set_default_session_class`.

    Raises
    ------
    TypeError
        If you try to use a session class that is not derived from :py:class:`Session`.
    """

    # https://requests.readthedocs.io/en/master/user/advanced/#timeouts
    CONNECT_TIMEOUT = 9.5
    READ_TIMEOUT = 30

    TIMEOUT = (CONNECT_TIMEOUT, READ_TIMEOUT)

    # NOTE(review): ``method_whitelist`` was renamed ``allowed_methods`` in urllib3 1.26
    # and removed in urllib3 2.0 -- confirm the pinned urllib3 version.
    RETRY_CONFIG = Retry(
        total=3,
        connect=2,
        read=2,
        status=2,
        backoff_factor=random.uniform(1, 3),
        method_whitelist=frozenset([
            HttpRequestMethod.HEAD,
            HttpRequestMethod.TRACE,
            HttpRequestMethod.GET,
            HttpRequestMethod.POST,
            HttpRequestMethod.PUT,
            HttpRequestMethod.PATCH,
            HttpRequestMethod.OPTIONS,
            HttpRequestMethod.DELETE,
        ]),
        status_forcelist=[
            HttpStatusCode.InternalServerError,
            HttpStatusCode.BadGateway,
            HttpStatusCode.ServiceUnavailable,
            HttpStatusCode.GatewayTimeout,
        ],
    )

    # We share an adapter (one per thread/process) among all clients to take advantage
    # of the single underlying connection pool.
    ADAPTER = ThreadLocalWrapper(
        lambda: HTTPAdapter(max_retries=Service.RETRY_CONFIG))

    _session_class = Session

    # List of attributes that will be included in state for pickling.
    # Subclasses can extend this attribute list.
    __attrs__ = ["auth", "base_url", "_session_class", "RETRY_CONFIG"]

    @classmethod
    def set_default_session_class(cls, session_class):
        """Set the default session class for :py:class:`Service`.

        The default session is used for any :py:class:`Service` that is instantiated
        without specifying the session class.

        Parameters
        ----------
        session_class: class
            The session class to use when instantiating the session.  This must be the
            class :py:class:`Session` itself or a derived class from
            :py:class:`Session`.
        """

        if not issubclass(session_class, Session):
            raise TypeError(
                "The session class must be a subclass of {}.".format(Session))

        cls._session_class = session_class

    @classmethod
    def get_default_session_class(cls):
        """Get the default session class for :py:class:`Service`.

        Returns
        -------
        Session
            The default session class, which is :py:class:`Session` itself or a derived
            class from :py:class:`Session`.
        """

        return cls._session_class

    def __init__(self,
                 url,
                 token=None,
                 auth=None,
                 retries=None,
                 session_class=None):
        if auth is None:
            auth = Auth()

        if token is not None:
            warn(
                "setting token at service level will be removed in future",
                FutureWarning,
            )
            auth._token = token

        self.auth = auth
        self.base_url = url

        if retries is None:
            self._adapter = self.ADAPTER
        else:
            # Custom retries live on the instance and get a dedicated adapter.
            self.RETRY_CONFIG = retries
            self._init_adapter()

        if session_class is not None:
            # Overwrite the default session class
            if not issubclass(session_class, Session):
                raise TypeError(
                    "The session class must be a subclass of {}.".format(
                        Session))

            self._session_class = session_class

        self._init_session()

    def _init_adapter(self):
        # The lambda reads self.RETRY_CONFIG lazily, once per thread/process.
        self._adapter = ThreadLocalWrapper(
            lambda: HTTPAdapter(max_retries=self.RETRY_CONFIG))

    def _init_session(self):
        # Sessions can't be shared across threads or processes because the underlying
        # SSL connection pool can't be shared. We create them thread-local to avoid
        # intractable exceptions when users naively share clients e.g. when using
        # multiprocessing.
        self._session = ThreadLocalWrapper(self._build_session)

    @property
    def token(self):
        """str: The bearer token used in the requests."""
        return self.auth.token

    @token.setter
    def token(self, token):
        """str: Deprecated"""
        self.auth._token = token

    @property
    def session(self):
        """Session: The session instance used by this service."""
        session = self._session.get()
        auth = add_bearer(self.token)
        if session.headers.get(HttpHeaderKeys.Authorization) != auth:
            session.headers[HttpHeaderKeys.Authorization] = auth

        return session

    def _build_session(self):
        session = self._session_class(self.base_url, timeout=self.TIMEOUT)
        session.initialize()

        adapter = self._adapter.get()
        session.mount(HttpMountProtocol.HTTPS, adapter)
        session.mount(HttpMountProtocol.HTTP, adapter)

        session.headers.update({
            HttpHeaderKeys.ContentType:
            HttpHeaderValues.ApplicationJson,
            HttpHeaderKeys.UserAgent:
            "{}/{}".format(HttpHeaderValues.DlPython, __version__),
        })

        try:
            session.headers.update({
                # https://github.com/easybuilders/easybuild/wiki/OS_flavor_name_version
                HttpHeaderKeys.Platform:
                platform.platform(),
                HttpHeaderKeys.Python:
                platform.python_version(),
                # https://stackoverflow.com/questions/47608532/how-to-detect-from-within-python-whether-packages-are-managed-with-conda
                HttpHeaderKeys.Conda:
                str(
                    os.path.exists(
                        os.path.join(sys.prefix, "conda-meta", "history"))),
                # https://stackoverflow.com/questions/15411967/how-can-i-check-if-code-is-executed-in-the-ipython-notebook
                HttpHeaderKeys.Notebook:
                str("ipykernel" in sys.modules),
                HttpHeaderKeys.ClientSession:
                uuid.uuid4().hex,
            })
        except Exception:
            # Diagnostic headers are best-effort only.
            pass

        return session

    def __getstate__(self):
        # A RETRY_CONFIG that exists only on the class is the default policy: leave it
        # out of the state so the unpickled instance reuses the shared ADAPTER, exactly
        # like a freshly constructed instance would, instead of building a private one.
        return dict(
            (attr, getattr(self, attr))
            for attr in self.__attrs__
            if attr != "RETRY_CONFIG" or attr in self.__dict__
        )

    def __setstate__(self, state):
        for name, value in state.items():
            setattr(self, name, value)

        if "RETRY_CONFIG" in self.__dict__:
            # Custom retries (or state pickled by an older version, which always
            # stored RETRY_CONFIG): build a dedicated adapter for them.
            self._init_adapter()
        else:
            self._adapter = self.ADAPTER

        self._init_session()
Esempio n. 17
0
 def _init_adapter(self):
     """Give this instance its own thread-local HTTPAdapter factory.

     ``self.RETRY_CONFIG`` is read each time the factory runs, not when it is created.
     """
     def make_adapter():
         return HTTPAdapter(max_retries=self.RETRY_CONFIG)

     self._adapter = ThreadLocalWrapper(make_adapter)
Esempio n. 18
0
    def __init__(
        self,
        domain="https://accounts.descarteslabs.com",
        scope=None,
        leeway=500,
        token_info_path=DEFAULT_TOKEN_INFO_PATH,
        client_id=None,
        client_secret=None,
        jwt_token=None,
        refresh_token=None,
    ):
        """
        Helps retrieve JWT from a client id and refresh token for cli usage.

        Each credential is taken from the first of these sources that is not None:
        the explicit argument, then environment variables, then the JSON file at
        ``token_info_path``.  ``client_secret`` and ``refresh_token`` are forced to the
        same value (``refresh_token`` wins if both are set and differ).

        :param str domain: The endpoint for auth0
        :type scope: list(str) or None
        :param scope: The JWT fields to be included
        :param int leeway: JWT expiration leeway
        :type token_info_path: str or None
        :param token_info_path: Path to a JSON file optionally holding auth information;
            a missing or unreadable file, invalid JSON, or JSON that is not an object
            is treated as empty
        :type client_id: str or None
        :param client_id: JWT client id
        :type client_secret: str or None
        :param client_secret: JWT client secret
        :type jwt_token: str or None
        :param jwt_token: The JWT token, if we already have one
        :type refresh_token: str or None
        :param refresh_token: The refresh token
        """
        self.token_info_path = token_info_path

        token_info = {}
        if self.token_info_path:
            try:
                with open(self.token_info_path) as fp:
                    token_info = json.load(fp)
            except (IOError, ValueError):
                pass

        # Valid JSON that is not an object (``null``, a list, a number...) has no
        # ``.get``; ignore it like a malformed file instead of crashing below.
        if not isinstance(token_info, dict):
            token_info = {}

        def first_not_none(*candidates):
            # Every candidate is evaluated; the first one that is set wins.
            return next((x for x in candidates if x is not None), None)

        self.client_id = first_not_none(
            client_id,
            os.environ.get("DESCARTESLABS_CLIENT_ID"),
            os.environ.get("CLIENT_ID"),
            token_info.get("client_id"),
        )

        self.client_secret = first_not_none(
            client_secret,
            os.environ.get("DESCARTESLABS_CLIENT_SECRET"),
            os.environ.get("CLIENT_SECRET"),
            token_info.get("client_secret"),
        )

        self.refresh_token = first_not_none(
            refresh_token,
            os.environ.get("DESCARTESLABS_REFRESH_TOKEN"),
            token_info.get("refresh_token"),
        )

        if self.client_secret != self.refresh_token:
            if self.client_secret is not None and self.refresh_token is not None:
                warnings.warn(
                    "Authentication token mismatch: "
                    "client_secret and refresh_token values must match for authentication to work correctly. "
                )

            # Copy whichever one is set onto the other; refresh_token takes priority.
            if self.refresh_token is not None:
                self.client_secret = self.refresh_token
            elif self.client_secret is not None:
                self.refresh_token = self.client_secret

        self._token = first_not_none(
            jwt_token,
            os.environ.get("DESCARTESLABS_TOKEN"),
            token_info.get("JWT_TOKEN"),
            token_info.get("jwt_token"),
        )

        self.scope = first_not_none(scope, token_info.get("scope"))

        if token_info:
            # If the token was read from a path but environment variables were set, we may need
            # to reset the token.
            # NOTE(review): if the file stores only one of client_secret/refresh_token,
            # the harmonization above makes the other differ from the file, so a JWT
            # cached in the file is always discarded -- confirm this is intended.
            client_id_changed = token_info.get("client_id",
                                               None) != self.client_id
            client_secret_changed = (token_info.get("client_secret", None) !=
                                     self.client_secret)
            refresh_token_changed = (token_info.get("refresh_token", None) !=
                                     self.refresh_token)

            if client_id_changed or client_secret_changed or refresh_token_changed:
                self._token = None

        self._namespace = None
        self._session = ThreadLocalWrapper(self.build_session)
        self.domain = domain
        self.leeway = leeway
Esempio n. 19
0
class ThirdPartyService(object):
    """The default Descartes Labs HTTP Service used for 3rd party servers.

    This service has a default timeout and retry policy that retries HTTP requests
    depending on the timeout and HTTP status code that was returned.  This is based
    on the `requests timeouts
    <https://requests.readthedocs.io/en/master/user/advanced/#timeouts>`_
    and the `urllib3 retry object
    <https://urllib3.readthedocs.io/en/latest/reference/urllib3.util.html#urllib3.util.retry.Retry>`_.

    The default timeouts are set to 9.5 seconds for establishing a connection (slightly
    larger than a multiple of 3, which is the TCP default packet retransmission window),
    and 30 seconds for reading a response.

    The default retry logic retries up to 10 times total, with a maximum of 2 retries
    for errors while reading a response.  The backoff_factor is a random number between
    1 and 3 (chosen once, when the class is defined), and the sleep between retries
    will never be more than 2 minutes.  The unexpected HTTP status codes that will be
    retried are ``429``, ``500``, ``502``, ``503``, and ``504`` for ``HEAD``, ``TRACE``,
    ``GET``, ``POST``, ``PUT``, ``OPTIONS``, and ``DELETE`` requests.  The policy is
    applied to both ``https://`` and ``http://`` URLs.

    Parameters
    ----------
    url: str
        The URL prefix to use for communication with the 3rd party server.
    session_class: class
        The session class to use when instantiating the session.  This must be a derived
        class from :py:class:`Session`.  If not provided, the default session class
        is used.  You can register a default session class with
        :py:meth:`ThirdPartyService.set_default_session_class`.

    Raises
    ------
    TypeError
        If you try to use a session class that is not derived from :py:class:`Session`.
    """

    CONNECT_TIMEOUT = 9.5
    READ_TIMEOUT = 30
    TIMEOUT = (CONNECT_TIMEOUT, READ_TIMEOUT)

    # NOTE(review): ``method_whitelist`` was renamed ``allowed_methods`` in urllib3 1.26
    # and removed in urllib3 2.0 -- confirm the pinned urllib3 version.
    RETRY_CONFIG = Retry(
        total=10,
        read=2,
        backoff_factor=random.uniform(1, 3),
        method_whitelist=frozenset([
            HttpRequestMethod.HEAD,
            HttpRequestMethod.TRACE,
            HttpRequestMethod.GET,
            HttpRequestMethod.POST,
            HttpRequestMethod.PUT,
            HttpRequestMethod.OPTIONS,
            HttpRequestMethod.DELETE,
        ]),
        status_forcelist=[
            HttpStatusCode.TooManyRequests,
            HttpStatusCode.InternalServerError,
            HttpStatusCode.BadGateway,
            HttpStatusCode.ServiceUnavailable,
            HttpStatusCode.GatewayTimeout,
        ],
    )

    # One adapter per thread/process, shared by every ThirdPartyService instance.
    ADAPTER = ThreadLocalWrapper(
        lambda: HTTPAdapter(max_retries=ThirdPartyService.RETRY_CONFIG))

    _session_class = Session

    @staticmethod
    def _check_session_class(session_class):
        """Raise TypeError unless *session_class* is :py:class:`Session` or a subclass."""
        # The isinstance check gives a clear message for None or non-class values,
        # instead of issubclass()'s generic "arg 1 must be a class" TypeError.
        if not (isinstance(session_class, type)
                and issubclass(session_class, Session)):
            raise TypeError(
                "The session class must be a subclass of {}.".format(Session))

    @classmethod
    def set_default_session_class(cls, session_class=None):
        """Set the default session class for :py:class:`ThirdPartyService`.

        The default session is used for any :py:class:`ThirdPartyService` that is
        instantiated without specifying the session class.

        Parameters
        ----------
        session_class: class
            The session class to use when instantiating the session.  This must be the
            class :py:class:`Session` itself or a derived class from
            :py:class:`Session`.

        Raises
        ------
        TypeError
            If ``session_class`` is omitted or is not a subclass of :py:class:`Session`.
        """

        cls._check_session_class(session_class)
        cls._session_class = session_class

    @classmethod
    def get_default_session_class(cls):
        """Get the default session class for the :py:class:`ThirdPartyService`.

        Returns
        -------
        Session
            The default session class, which is :py:class:`Session` itself or a derived
            class from :py:class:`Session`.
        """

        return cls._session_class

    def __init__(self, url="", session_class=None):
        self.base_url = url

        if session_class is not None:
            self._check_session_class(session_class)
            self._session_class = session_class

        # Sessions are created lazily, one per thread/process.
        self._session = ThreadLocalWrapper(self._build_session)

    @property
    def session(self):
        """Session: The session instance used by this service in the current thread."""
        return self._session.get()

    def _build_session(self):
        session = self._session_class(self.base_url, timeout=self.TIMEOUT)
        session.initialize()

        # Mount the shared retrying adapter for both schemes (as ``Service`` does) so
        # plain ``http://`` URLs get the same retry policy as ``https://`` ones.
        adapter = self.ADAPTER.get()
        session.mount(HttpMountProtocol.HTTPS, adapter)
        session.mount(HttpMountProtocol.HTTP, adapter)

        session.headers.update({
            HttpHeaderKeys.ContentType:
            HttpHeaderValues.ApplicationOctetStream,
            HttpHeaderKeys.UserAgent:
            "{}/{}".format(HttpHeaderValues.DlPython, __version__),
        })

        return session
Esempio n. 20
0
class Service(object):
    """HTTP service for a Descartes Labs server with shared retries and per-thread sessions.

    Requests use ``TIMEOUT`` (9.5 s connect, 30 s read) and the ``RETRY_CONFIG`` retry
    policy.  Sessions are created lazily per thread/process and carry the token from
    ``auth`` in their ``Authorization`` header.
    """

    TIMEOUT = (9.5, 30)

    # NOTE(review): ``method_whitelist`` was renamed ``allowed_methods`` in urllib3 1.26
    # and removed in urllib3 2.0 -- confirm the pinned urllib3 version.
    RETRY_CONFIG = Retry(total=5,
                         read=2,
                         backoff_factor=random.uniform(1, 3),
                         method_whitelist=frozenset([
                             'HEAD', 'TRACE', 'GET', 'POST', 'PUT', 'OPTIONS',
                             'DELETE'
                         ]),
                         status_forcelist=[500, 502, 503, 504])

    # We share an adapter (one per thread/process) among all clients to take advantage
    # of the single underlying connection pool.
    ADAPTER = ThreadLocalWrapper(
        lambda: HTTPAdapter(max_retries=Service.RETRY_CONFIG))

    def __init__(self, url, token=None, auth=None):
        """
        :param url: URL prefix for all requests made through this service
        :param token: deprecated; if given, it is stored on ``auth`` as its token
        :param auth: an ``Auth`` instance; a default one is created when omitted
        """
        if auth is None:
            auth = Auth()

        if token is not None:
            warn("setting token at service level will be removed in future",
                 DeprecationWarning)
            auth._token = token

        self.auth = auth

        self.base_url = url

        # Sessions can't be shared across threads or processes because the underlying
        # SSL connection pool can't be shared. We create them thread-local to avoid
        # intractable exceptions when users naively share clients e.g. when using
        # multiprocessing.
        self._session = ThreadLocalWrapper(self.build_session)

    @property
    def token(self):
        return self.auth.token

    @token.setter
    def token(self, token):
        self.auth._token = token

    @property
    def session(self):
        """The thread-local session, with its Authorization header synced to the token."""
        session = self._session.get()
        # Read the token once: ``self.token`` delegates to ``self.auth.token``, and
        # reading it twice repeats that lookup and could compare one value but store
        # another.
        token = self.token
        if session.headers.get('Authorization') != token:
            session.headers['Authorization'] = token

        return session

    def build_session(self):
        """Create a session for ``base_url`` with the shared adapter mounted for HTTPS."""
        s = WrappedSession(self.base_url, timeout=self.TIMEOUT)
        s.mount('https://', self.ADAPTER.get())

        s.headers.update({
            "Content-Type": "application/json",
            "User-Agent": "dl-python/{}".format(__version__)
        })

        return s