def _get_token(self, timeout=100):
    """Exchange the client credentials for a JWT and cache it.

    Posts the client id and secret to ``<domain>/auth/delegation`` through a
    session that retries GET/POST on 429/500/502/503/504 responses, stores
    the returned ``id_token`` on ``self._token`` and, when
    ``self.token_info_path`` is set, merges it into that JSON file under the
    ``jwt_token`` key.

    :param timeout: Request timeout in seconds passed to ``requests``.
    :raises AuthError: If ``client_id`` or ``client_secret`` is missing.
    :raises OauthError: If the server does not answer 200 or the response
        contains no ``id_token``.
    """
    if self.client_id is None:
        raise AuthError("Could not find CLIENT_ID")
    if self.client_secret is None:
        raise AuthError("Could not find CLIENT_SECRET")

    retry_methods = frozenset(['GET', 'POST'])
    retry_kwargs = dict(
        total=5,
        # A randomized backoff factor spreads out retries from concurrent clients.
        backoff_factor=random.uniform(1, 10),
        status_forcelist=[429, 500, 502, 503, 504],
    )
    try:
        # urllib3 >= 1.26 spelling; ``method_whitelist`` was removed in urllib3 2.0.
        retries = Retry(allowed_methods=retry_methods, **retry_kwargs)
    except TypeError:
        # Older urllib3 only understands the legacy keyword.
        retries = Retry(method_whitelist=retry_methods, **retry_kwargs)

    headers = {"content-type": "application/json"}
    params = {
        "scope": " ".join(self.scope),
        "client_id": self.client_id,
        "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
        "target": self.client_id,
        "api_type": "app",
        "refresh_token": self.client_secret,
    }

    # Close the session (and its pooled connections) once the request is done;
    # the non-streamed response body is already fully read at that point.
    with requests.Session() as session:
        session.mount('https://', HTTPAdapter(max_retries=retries))
        r = session.post(self.domain + "/auth/delegation", headers=headers,
                         data=json.dumps(params), timeout=timeout)

    if r.status_code != 200:
        raise OauthError("%s: %s" % (r.status_code, r.text))

    data = r.json()
    token = data.get('id_token')
    if token is None:
        raise OauthError("could not retrieve token")
    self._token = token

    if not self.token_info_path:
        return

    token_info = {}
    try:
        with open(self.token_info_path) as fp:
            token_info = json.load(fp)
    except (IOError, ValueError):
        # Missing or unreadable/corrupt file: start from an empty dict.
        pass
    if not isinstance(token_info, dict):
        # A valid-JSON but non-object file cannot hold our key; replace it.
        token_info = {}
    token_info['jwt_token'] = self._token

    makedirs_if_not_exists(os.path.dirname(self.token_info_path))
    owner_rw = stat.S_IRUSR | stat.S_IWUSR
    try:
        # Create the file owner-only from the start so the token is never
        # readable by other users, even briefly, under a permissive umask.
        fd = os.open(self.token_info_path,
                     os.O_WRONLY | os.O_CREAT | os.O_TRUNC, owner_rw)
        with os.fdopen(fd, 'w') as fp:
            json.dump(token_info, fp)
        # os.open's mode only applies on creation; tighten pre-existing files too.
        os.chmod(self.token_info_path, owner_rw)
    except (IOError, OSError) as e:
        # On Python 2, os.open/os.chmod raise OSError, which is not an IOError.
        warnings.warn('failed to save token: {}'.format(e))
def payload(self):
    """
    Gets the token payload.

    :rtype: dict
    :return: Dictionary containing the fields specified by scope, which may include:

        .. highlight:: none

        ::

            name:           The name of the user.
            groups:         Groups to which the user belongs.
            org:            The organization to which the user belongs.
            email:          The email address of the user.
            email_verified: True if the user's email has been verified.
            sub:            The user identifier.
            exp:            The expiration time of the token, in seconds since
                            the start of the unix epoch.

    :raises ~descarteslabs.client.exceptions.AuthError: Raised when incomplete
        information has been provided, or when the token cannot be decoded.
    :raises ~descarteslabs.client.exceptions.OauthError: Raised when a token
        cannot be obtained or refreshed.
    """
    # Fetch a token first if none is cached yet.
    if self._token is None:
        self._get_token()

    token = self._token
    if isinstance(token, six.text_type):
        token = token.encode("utf-8")

    try:
        # Only the middle (claims) segment is decoded; the signature is not
        # verified here.
        claims = token.split(b".")[1]
        return json.loads(base64url_decode(claims).decode("utf-8"))
    except (IndexError, UnicodeDecodeError, binascii.Error, ValueError) as e:
        # ValueError covers json.JSONDecodeError on Python 3 and JSON errors on
        # Python 2, where json.JSONDecodeError does not exist (naming it in the
        # except tuple would raise AttributeError while matching).
        raise AuthError("Unable to read token: {}".format(e))
def _get_token(self, timeout=100):
    """Obtain a fresh token from ``<domain>/token`` and cache it.

    The hard-coded legacy client id is sent as a jwt-bearer delegation request
    (with a default OpenID scope when ``self.scope`` is None); every other
    client uses a ``refresh_token`` grant. The returned ``access_token`` (or,
    for legacy responses, ``id_token``) is stored on ``self._token`` and, when
    ``self.token_info_path`` is set, merged into that JSON file under the
    ``jwt_token`` key.

    :param timeout: Request timeout in seconds passed to ``self.session.post``.
    :raises AuthError: If ``client_id`` is missing, or both ``client_secret``
        and ``refresh_token`` are missing.
    :raises OauthError: If the server does not answer 200 or the response
        contains neither ``access_token`` nor ``id_token``.
    """
    if self.client_id is None:
        raise AuthError("Could not find client_id")
    if self.client_secret is None and self.refresh_token is None:
        raise AuthError("Could not find client_secret or refresh token")

    # The guard above accepts either credential, so fall back to
    # client_secret rather than posting a null refresh token.
    if self.refresh_token is not None:
        refresh_token = self.refresh_token
    else:
        refresh_token = self.client_secret

    if self.client_id in ("ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c",):
        # TODO(justin) remove legacy handling
        # TODO (justin) insert deprecation warning
        if self.scope is None:
            scope = ["openid", "name", "groups", "org", "email"]
        else:
            scope = self.scope
        params = {
            "scope": " ".join(scope),
            "client_id": self.client_id,
            "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
            "target": self.client_id,
            "api_type": "app",
            "refresh_token": refresh_token,
        }
    else:
        params = {
            "client_id": self.client_id,
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
        }
        if self.scope is not None:
            params["scope"] = " ".join(self.scope)

    r = self.session.post(self.domain + "/token", json=params, timeout=timeout)
    if r.status_code != 200:
        raise OauthError("%s: %s" % (r.status_code, r.text))

    data = r.json()
    token = data.get("access_token")
    if token is None:
        # TODO(justin) remove legacy id_token usage
        token = data.get("id_token")
    if token is None:
        raise OauthError("could not retrieve token")
    self._token = token

    if not self.token_info_path:
        return

    token_info = {}
    try:
        with open(self.token_info_path) as fp:
            token_info = json.load(fp)
    except (IOError, ValueError):
        # Missing or unreadable/corrupt file: start from an empty dict.
        pass
    if not isinstance(token_info, dict):
        # A valid-JSON but non-object file cannot hold our key; replace it.
        token_info = {}
    token_info["jwt_token"] = self._token

    makedirs_if_not_exists(os.path.dirname(self.token_info_path))
    owner_rw = stat.S_IRUSR | stat.S_IWUSR
    try:
        # Create the file owner-only from the start so the token is never
        # readable by other users, even briefly, under a permissive umask.
        fd = os.open(self.token_info_path,
                     os.O_WRONLY | os.O_CREAT | os.O_TRUNC, owner_rw)
        with os.fdopen(fd, "w") as fp:
            json.dump(token_info, fp)
        # os.open's mode only applies on creation; tighten pre-existing files too.
        os.chmod(self.token_info_path, owner_rw)
    except (IOError, OSError) as e:
        # On Python 2, os.open/os.chmod raise OSError, which is not an IOError.
        warnings.warn("failed to save token: {}".format(e))
class TestAuth(unittest.TestCase):
    """Tests for ``Auth`` construction, token retrieval and token expiry."""

    # Client id these tests use to exercise Auth's legacy token flow.
    _LEGACY_CLIENT_ID = "ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c"

    def tearDown(self):
        # Some tests alter the global warning filters; undo that.
        warnings.resetwarnings()

    @staticmethod
    def _make_token(exp):
        """Return an unsigned ``header.payload.sig`` token with payload ``{"exp": exp}``."""
        return b".".join(
            base64.b64encode(to_bytes(part))
            for part in ["header", json.dumps(dict(exp=exp)), "sig"]
        )

    def test_auth_client_refresh_match(self):
        with warnings.catch_warnings(record=True) as w:
            # Record every warning regardless of pre-existing filters or the
            # once-per-location registry, so the count below is deterministic.
            warnings.simplefilter("always")
            auth = Auth(
                client_id="client_id",
                client_secret="secret",
                refresh_token="mismatched_refresh_token",
            )
        assert 1 == len(w)
        assert "mismatched_refresh_token" == auth.refresh_token
        assert "mismatched_refresh_token" == auth.client_secret

    @responses.activate
    def test_get_token(self):
        responses.add(
            responses.POST,
            "https://accounts.descarteslabs.com/token",
            json=dict(access_token="access_token"),
            status=200,
        )
        auth = Auth(token_info_path=None, client_secret="client_secret", client_id="client_id")
        auth._get_token()
        assert "access_token" == auth._token

    @responses.activate
    def test_get_token_legacy(self):
        responses.add(
            responses.POST,
            "https://accounts.descarteslabs.com/token",
            json=dict(id_token="id_token"),
            status=200,
        )
        auth = Auth(token_info_path=None, client_secret="client_secret", client_id="client_id")
        auth._get_token()
        assert "id_token" == auth._token

    @patch("descarteslabs.client.auth.Auth.payload", new=dict(sub="asdf"))
    def test_get_namespace(self):
        auth = Auth(token_info_path=None, client_secret="client_secret", client_id="client_id")
        assert auth.namespace == "3da541559918a808c2402bba5012f6c60b27661c"

    def test_init_token_no_path(self):
        auth = Auth(jwt_token="token", token_info_path=None, client_id="foo")
        assert "token" == auth._token

    @responses.activate
    def test_get_token_schema_internal_only(self):
        responses.add_callback(
            responses.POST,
            "https://accounts.descarteslabs.com/token",
            callback=token_response_callback,
        )
        auth = Auth(token_info_path=None, refresh_token="refresh_token", client_id="client_id")
        auth._get_token()
        assert "access_token" == auth._token

        auth = Auth(token_info_path=None, client_secret="refresh_token", client_id="client_id")
        auth._get_token()
        assert "access_token" == auth._token

    @responses.activate
    def test_get_token_schema_legacy_internal_only(self):
        responses.add_callback(
            responses.POST,
            "https://accounts.descarteslabs.com/token",
            callback=token_response_callback,
        )
        auth = Auth(
            token_info_path=None,
            client_secret="client_secret",
            client_id=self._LEGACY_CLIENT_ID,
        )
        auth._get_token()
        assert "id_token" == auth._token

    @patch("descarteslabs.client.auth.Auth._get_token")
    def test_token(self, _get_token):
        auth = Auth(
            token_info_path=None,
            client_secret="client_secret",
            client_id=self._LEGACY_CLIENT_ID,
        )
        token = self._make_token(9999999999)
        auth._token = token
        assert auth.token == token
        _get_token.assert_not_called()

    @patch("descarteslabs.client.auth.Auth._get_token")
    def test_token_expired(self, _get_token):
        auth = Auth(
            token_info_path=None,
            client_secret="client_secret",
            client_id=self._LEGACY_CLIENT_ID,
        )
        token = self._make_token(0)
        auth._token = token
        assert auth.token == token
        _get_token.assert_called_once()

    @patch("descarteslabs.client.auth.Auth._get_token", side_effect=AuthError("error"))
    def test_token_expired_autherror(self, _get_token):
        auth = Auth(
            token_info_path=None,
            client_secret="client_secret",
            client_id=self._LEGACY_CLIENT_ID,
        )
        token = self._make_token(0)
        auth._token = token
        with pytest.raises(AuthError):
            auth.token
        _get_token.assert_called_once()

    @patch("descarteslabs.client.auth.Auth._get_token", side_effect=AuthError("error"))
    def test_token_in_leeway_autherror(self, _get_token):
        auth = Auth(
            token_info_path=None,
            client_secret="client_secret",
            client_id=self._LEGACY_CLIENT_ID,
        )
        # Expiry falls inside the leeway window: a refresh is attempted, but
        # its failure is tolerated and the still-valid token is returned.
        exp = (datetime.datetime.utcnow() - datetime.datetime(1970, 1, 1)).total_seconds() + auth.leeway / 2
        token = self._make_token(exp)
        auth._token = token
        assert auth.token == token
        _get_token.assert_called_once()

    def test_auth_init_env_vars(self):
        import os

        warnings.simplefilter("ignore")
        environ = dict(
            CLIENT_SECRET="secret_bar",
            CLIENT_ID="id_bar",
            DESCARTESLABS_CLIENT_SECRET="secret_foo",
            DESCARTESLABS_CLIENT_ID="id_foo",
            DESCARTESLABS_REFRESH_TOKEN="refresh_foo",
        )

        # should work with direct var
        with patch.dict("descarteslabs.client.auth.auth.os.environ", environ):
            auth = Auth(
                client_id="client_id",
                client_secret="client_secret",
                refresh_token="client_secret",
                jwt_token="jwt_token",
            )
            assert auth.client_secret == "client_secret"
            assert auth.client_id == "client_id"

        # should work with namespaced env vars
        with patch.dict("descarteslabs.client.auth.auth.os.environ", environ):
            auth = Auth()
            # when refresh_token and client_secret do not match,
            # the Auth implementation sets both to the value of
            # refresh_token
            assert auth.client_secret == environ.get("DESCARTESLABS_REFRESH_TOKEN")
            assert auth.client_id == environ.get("DESCARTESLABS_CLIENT_ID")

        # remove the namespaced ones, except the refresh token because
        # Auth does not recognize a REFRESH_TOKEN environment variable
        # and removing it from the dictionary would result in non-deterministic
        # results based on the token_info.json file on the test runner disk
        environ.pop("DESCARTESLABS_CLIENT_SECRET")
        environ.pop("DESCARTESLABS_CLIENT_ID")

        # should fallback to legacy env vars
        with patch.dict("descarteslabs.client.auth.auth.os.environ", environ):
            # patch.dict only adds keys, so namespaced values present in the
            # runner's real environment would still shadow the legacy ones.
            # Drop them here; patch.dict restores them on exit.
            for name in ("DESCARTESLABS_CLIENT_SECRET", "DESCARTESLABS_CLIENT_ID"):
                os.environ.pop(name, None)
            auth = Auth()
            assert auth.client_secret == environ.get("DESCARTESLABS_REFRESH_TOKEN")
            assert auth.client_id == environ.get("CLIENT_ID")
class TestAuth(unittest.TestCase):
    """Tests for ``Auth`` construction, token retrieval and token expiry."""

    # Client id these tests use to exercise Auth's legacy token flow.
    _LEGACY_CLIENT_ID = "ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c"

    @staticmethod
    def _make_token(exp):
        """Return an unsigned ``header.payload.sig`` token with payload ``{"exp": exp}``."""
        return b".".join(
            base64.b64encode(to_bytes(part))
            for part in ["header", json.dumps(dict(exp=exp)), "sig"]
        )

    @responses.activate
    def test_get_token(self):
        responses.add(
            responses.POST,
            "https://accounts.descarteslabs.com/token",
            json=dict(access_token="access_token"),
            status=200,
        )
        auth = Auth(token_info_path=None, client_secret="client_secret", client_id="client_id")
        auth._get_token()
        self.assertEqual("access_token", auth._token)

    @responses.activate
    def test_get_token_legacy(self):
        responses.add(
            responses.POST,
            "https://accounts.descarteslabs.com/token",
            json=dict(id_token="id_token"),
            status=200,
        )
        auth = Auth(token_info_path=None, client_secret="client_secret", client_id="client_id")
        auth._get_token()
        self.assertEqual("id_token", auth._token)

    @patch("descarteslabs.client.auth.Auth.payload", new=dict(sub="asdf"))
    def test_get_namespace(self):
        auth = Auth(token_info_path=None, client_secret="client_secret", client_id="client_id")
        self.assertEqual(auth.namespace, "3da541559918a808c2402bba5012f6c60b27661c")

    def test_init_token_no_path(self):
        auth = Auth(jwt_token="token", token_info_path=None, client_id="foo")
        self.assertEqual("token", auth._token)

    @responses.activate
    def test_get_token_schema_internal_only(self):
        responses.add_callback(
            responses.POST,
            "https://accounts.descarteslabs.com/token",
            callback=token_response_callback,
        )
        auth = Auth(token_info_path=None, refresh_token="refresh_token", client_id="client_id")
        auth._get_token()
        self.assertEqual("access_token", auth._token)

        auth = Auth(token_info_path=None, client_secret="refresh_token", client_id="client_id")
        auth._get_token()
        self.assertEqual("access_token", auth._token)

    @responses.activate
    def test_get_token_schema_legacy_internal_only(self):
        responses.add_callback(
            responses.POST,
            "https://accounts.descarteslabs.com/token",
            callback=token_response_callback,
        )
        auth = Auth(
            token_info_path=None,
            client_secret="client_secret",
            client_id=self._LEGACY_CLIENT_ID,
        )
        auth._get_token()
        self.assertEqual("id_token", auth._token)

    @patch("descarteslabs.client.auth.Auth._get_token")
    def test_token(self, _get_token):
        auth = Auth(
            token_info_path=None,
            client_secret="client_secret",
            client_id=self._LEGACY_CLIENT_ID,
        )
        token = self._make_token(9999999999)
        auth._token = token
        self.assertEqual(auth.token, token)
        _get_token.assert_not_called()

    @patch("descarteslabs.client.auth.Auth._get_token")
    def test_token_expired(self, _get_token):
        auth = Auth(
            token_info_path=None,
            client_secret="client_secret",
            client_id=self._LEGACY_CLIENT_ID,
        )
        token = self._make_token(0)
        auth._token = token
        self.assertEqual(auth.token, token)
        _get_token.assert_called_once()

    @patch("descarteslabs.client.auth.Auth._get_token", side_effect=AuthError("error"))
    def test_token_expired_autherror(self, _get_token):
        auth = Auth(
            token_info_path=None,
            client_secret="client_secret",
            client_id=self._LEGACY_CLIENT_ID,
        )
        token = self._make_token(0)
        auth._token = token
        with self.assertRaises(AuthError):
            auth.token
        _get_token.assert_called_once()

    @patch("descarteslabs.client.auth.Auth._get_token", side_effect=AuthError("error"))
    def test_token_in_leeway_autherror(self, _get_token):
        auth = Auth(
            token_info_path=None,
            client_secret="client_secret",
            client_id=self._LEGACY_CLIENT_ID,
        )
        # Expiry falls inside the leeway window: a refresh is attempted, but
        # its failure is tolerated and the still-valid token is returned.
        exp = (datetime.datetime.utcnow() - datetime.datetime(1970, 1, 1)).total_seconds() + auth.leeway / 2
        token = self._make_token(exp)
        auth._token = token
        self.assertEqual(auth.token, token)
        _get_token.assert_called_once()

    def test_auth_init_env_vars(self):
        import os

        environ = dict(
            CLIENT_SECRET="secret_bar",
            CLIENT_ID="id_bar",
            DESCARTESLABS_CLIENT_SECRET="secret_foo",
            DESCARTESLABS_CLIENT_ID="id_foo",
        )

        # should work with direct var
        with patch.dict("descarteslabs.client.auth.auth.os.environ", environ):
            auth = Auth(
                client_id="client_id",
                client_secret="client_secret",
                jwt_token="jwt_token",
            )
            self.assertEqual(auth.client_secret, "client_secret")
            self.assertEqual(auth.client_id, "client_id")

        # should work with namespaced env vars
        with patch.dict("descarteslabs.client.auth.auth.os.environ", environ):
            auth = Auth()
            self.assertEqual(auth.client_secret, environ.get("DESCARTESLABS_CLIENT_SECRET"))
            self.assertEqual(auth.client_id, environ.get("DESCARTESLABS_CLIENT_ID"))

        # remove the namespaced ones
        environ.pop("DESCARTESLABS_CLIENT_SECRET")
        environ.pop("DESCARTESLABS_CLIENT_ID")

        # should fallback to legacy env vars
        with patch.dict("descarteslabs.client.auth.auth.os.environ", environ):
            # patch.dict only adds keys, so namespaced values present in the
            # runner's real environment would still shadow the legacy ones.
            # Drop them here; patch.dict restores them on exit.
            for name in ("DESCARTESLABS_CLIENT_SECRET", "DESCARTESLABS_CLIENT_ID"):
                os.environ.pop(name, None)
            auth = Auth()
            self.assertEqual(auth.client_secret, environ.get("CLIENT_SECRET"))
            self.assertEqual(auth.client_id, environ.get("CLIENT_ID"))