def test_evicts_invalid_refresh_token():
    """A refresh token that AAD rejects must be dropped from the client's cache, leaving other tokens alone."""
    bad_refresh_token = "invalid-refresh-token"

    token_cache = TokenCache()
    # two accounts: one holding the token AAD will reject, one holding an unrelated token
    for uid, utid, refresh_token in (("id1", "tid1", bad_refresh_token), ("id2", "tid2", "...")):
        token_cache.add(
            {"response": build_aad_response(uid=uid, utid=utid, access_token="*", refresh_token=refresh_token)}
        )

    def count_refresh_tokens(secret=None):
        # with no secret, count every cached refresh token; otherwise only those whose secret matches
        if secret is None:
            found = token_cache.find(TokenCache.CredentialType.REFRESH_TOKEN)
        else:
            found = token_cache.find(TokenCache.CredentialType.REFRESH_TOKEN, query={"secret": secret})
        return len(found)

    assert count_refresh_tokens() == 2
    assert count_refresh_tokens(bad_refresh_token) == 1

    def reject_refresh_token(request, **_):
        # the client must redeem exactly the token it was handed; AAD answers invalid_grant
        assert request.data["refresh_token"] == bad_refresh_token
        return mock_response(json_payload={"error": "invalid_grant"}, status_code=400)

    transport = Mock(send=Mock(wraps=reject_refresh_token))
    client = AadClient("tenant-id", "client-id", transport=transport, cache=token_cache)

    with pytest.raises(ClientAuthenticationError):
        client.obtain_token_by_refresh_token(scopes=("scope",), refresh_token=bad_refresh_token)

    assert transport.send.call_count == 1
    # only the rejected token is gone
    assert count_refresh_tokens() == 1
    assert count_refresh_tokens(bad_refresh_token) == 0
class AuthnClientBase(object):
    """Sans I/O authentication client methods"""

    def __init__(self, auth_url, **kwargs):
        # type: (str, Mapping[str, Any]) -> None
        if not auth_url:
            raise ValueError("auth_url should be the URL of an OAuth endpoint")
        super(AuthnClientBase, self).__init__()
        self._auth_url = auth_url
        self._cache = TokenCache()

    def get_cached_token(self, scopes):
        # type: (Iterable[str]) -> Optional[str]
        """Return the secret of a cached access token covering ``scopes``, or None.

        A token qualifies only if it is valid for at least five more minutes.
        """
        # materialize once: ``scopes`` is read twice below and may be a one-shot iterator,
        # which would otherwise make the ``all(...)`` check vacuously true
        scopes = list(scopes)
        tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, scopes)
        for token in tokens:
            if all(scope in token["target"] for scope in scopes):
                if int(token["expires_on"]) - 300 > int(time()):
                    return token["secret"]
        return None

    def _deserialize_and_cache_token(self, response, scopes):
        # type: (PipelineResponse, Iterable[str]) -> str
        """Extract the access token from an authentication response and add the response to the cache.

        :param response: pipeline response; if its context holds "deserialized_data" that is used,
            otherwise the HTTP body text is parsed as JSON
        :param scopes: scopes the token was requested for, recorded with the cache entry
        :returns: the access token string
        :raises AuthenticationError: if the response has no access token or can't be processed
        """
        import json

        payload = None
        try:
            if "deserialized_data" in response.context:
                payload = response.context["deserialized_data"]
            else:
                # the raw body is text; parse it so keys can be looked up
                payload = json.loads(response.http_response.text())
            token = payload["access_token"]

            # these values are strings in IMDS responses but msal.TokenCache requires they be integers
            # https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/55
            if payload.get("expires_in"):
                payload["expires_in"] = int(payload["expires_in"])
            if payload.get("ext_expires_in"):
                payload["ext_expires_in"] = int(payload["ext_expires_in"])

            self._cache.add({"response": payload, "scope": scopes})
            return token
        except KeyError:
            # redact secrets in a copy so they don't leak into the message and the payload isn't altered
            shown = payload
            if isinstance(payload, dict):
                shown = dict(payload)
                for secret_key in ("access_token", "refresh_token"):
                    if secret_key in shown:
                        shown[secret_key] = "****"
            raise AuthenticationError("Unexpected authentication response: {}".format(shown))
        except Exception as ex:
            raise AuthenticationError("Authentication failed: {}".format(str(ex)))

    def _prepare_request(self, method="POST", headers=None, form_data=None, params=None):
        # type: (Optional[str], Optional[Mapping[str, str]], Optional[Mapping[str, str]], Optional[Dict[str, str]]) -> HttpRequest
        """Build a request to the auth URL, form-encoding ``form_data`` and adding ``params`` to the query."""
        request = HttpRequest(method, self._auth_url, headers=headers)
        if form_data:
            request.headers["Content-Type"] = "application/x-www-form-urlencoded"
            request.set_formdata_body(form_data)
        if params:
            request.format_parameters(params)
        return request
def test_persistent_cache_multiple_clients(cert_path, cert_password):
    """the credential shouldn't use tokens issued to other service principals"""
    token_for_a = "token a"
    token_for_b = "not " + token_for_a

    def transport_returning(access_token):
        # validates one token request, answered with a response carrying ``access_token``
        return msal_validating_transport(
            requests=[Request()],
            responses=[mock_response(json_payload=build_aad_response(access_token=access_token))],
        )

    transport_a = transport_returning(token_for_a)
    transport_b = transport_returning(token_for_b)

    cache = TokenCache()
    with patch("azure.identity._internal.msal_credentials._load_persistent_cache") as load_persistent_cache:
        load_persistent_cache.return_value = Mock(wraps=cache)

        def make_credential(client_id, transport):
            return CertificateCredential(
                "tenant",
                client_id,
                cert_path,
                password=cert_password,
                transport=transport,
                cache_persistence_options=TokenCachePersistenceOptions(),
            )

        credential_a = make_credential("client-a", transport_a)
        assert load_persistent_cache.call_count == 1, "credential should load the persistent cache"
        credential_b = make_credential("client-b", transport_b)
        assert load_persistent_cache.call_count == 2, "credential should load the persistent cache"

    scope = "scope"

    # A acquires a token, which lands in the shared cache
    token_a = credential_a.get_token(scope)
    assert token_a.token == token_for_a
    assert transport_a.send.call_count == 3  # two MSAL discovery requests, one token request

    # B must fetch its own token rather than reuse A's for the same scope
    token_b = credential_b.get_token(scope)
    assert token_b.token == token_for_b
    assert transport_b.send.call_count == 3

    assert len(cache.find(TokenCache.CredentialType.ACCESS_TOKEN)) == 2
def test_writes_to_cache():
    """the credential should write tokens it acquires to the cache"""
    scope = "scope"
    first_access_token = "access token"
    initial_refresh_token = "first refresh token"
    rotated_refresh_token = "second refresh token"
    username, uid, utid = "******", "uid", "utid"

    cache = TokenCache()
    cache.add(get_account_event(username=username, uid=uid, utid=utid, refresh_token=initial_refresh_token))

    # the credential redeems the cached refresh token; AAD answers with an access token and a new refresh token
    redemption = Request(required_data={"refresh_token": initial_refresh_token})
    aad_reply = build_aad_response(
        uid=uid,
        utid=utid,
        access_token=first_access_token,
        refresh_token=rotated_refresh_token,
        id_token=build_id_token(aud=DEVELOPER_SIGN_ON_CLIENT_ID, object_id=uid, tenant_id=utid, username=username),
    )
    transport = validating_transport(requests=[redemption], responses=[mock_response(json_payload=aad_reply)])

    credential = SharedTokenCacheCredential(_cache=cache, transport=transport)
    token = credential.get_token(scope)
    assert token.token == first_access_token

    # a fresh instance must find the access token in the cache without any network traffic
    no_network = Mock(send=Mock(side_effect=Exception("the credential should return a cached token")))
    credential = SharedTokenCacheCredential(_cache=cache, transport=no_network)
    token = credential.get_token(scope)
    assert token.token == first_access_token

    # requesting a new scope forces a redemption, which must use the rotated refresh token
    second_access_token = "second access token"
    transport = validating_transport(
        requests=[Request(required_data={"refresh_token": rotated_refresh_token})],
        responses=[mock_response(json_payload=build_aad_response(access_token=second_access_token))],
    )
    credential = SharedTokenCacheCredential(_cache=cache, transport=transport)
    token = credential.get_token("some other " + scope)
    assert token.token == second_access_token

    # the rotated refresh token replaced the original rather than being stored beside it
    assert len(cache.find(TokenCache.CredentialType.REFRESH_TOKEN)) == 1
async def test_cache_multiple_clients():
    """the credential shouldn't use tokens issued to other service principals"""
    token_for_a = "token a"
    token_for_b = "not " + token_for_a

    def transport_returning(access_token):
        # validates one token request, answered with a response carrying ``access_token``
        return async_validating_transport(
            requests=[Request()],
            responses=[mock_response(json_payload=build_aad_response(access_token=access_token))],
        )

    transport_a = transport_returning(token_for_a)
    transport_b = transport_returning(token_for_b)

    cache = TokenCache()
    with patch(ClientSecretCredential.__module__ + "._load_persistent_cache") as load_persistent_cache:
        load_persistent_cache.return_value = Mock(wraps=cache)

        def make_credential(client_id, transport):
            return ClientSecretCredential(
                "tenant",
                client_id,
                "secret",
                transport=transport,
                cache_persistence_options=TokenCachePersistenceOptions(),
            )

        credential_a = make_credential("client-a", transport_a)
        assert load_persistent_cache.call_count == 1, "credential should load the persistent cache"
        credential_b = make_credential("client-b", transport_b)
        assert load_persistent_cache.call_count == 2, "credential should load the persistent cache"

    scope = "scope"

    # A acquires a token, which lands in the shared cache
    token_a = await credential_a.get_token(scope)
    assert token_a.token == token_for_a
    assert transport_a.send.call_count == 1

    # B must fetch its own token rather than reuse A's for the same scope
    token_b = await credential_b.get_token(scope)
    assert token_b.token == token_for_b
    assert transport_b.send.call_count == 1

    assert len(cache.find(TokenCache.CredentialType.ACCESS_TOKEN)) == 2
class AuthnClientBase(object):
    """Sans I/O authentication client methods"""

    def __init__(self, auth_url, **kwargs):
        # type: (str, Mapping[str, Any]) -> None
        if not auth_url:
            raise ValueError("auth_url should be the URL of an OAuth endpoint")
        super(AuthnClientBase, self).__init__()
        self._auth_url = auth_url
        self._cache = TokenCache()

    def get_cached_token(self, scopes):
        # type: (Iterable[str]) -> Optional[AccessToken]
        """Return a cached AccessToken covering ``scopes`` if one exists, or None.

        A token qualifies only if it is valid for at least five more minutes.
        """
        # materialize once: ``scopes`` is read twice below and may be a one-shot iterator
        scopes = list(scopes)
        tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, scopes)
        for token in tokens:
            if all(scope in token["target"] for scope in scopes):
                expires_on = int(token["expires_on"])
                if expires_on - 300 > int(time.time()):
                    return AccessToken(token["secret"], expires_on)
        return None

    def _deserialize_and_cache_token(self, response, scopes, request_time):
        # type: (PipelineResponse, Iterable[str], int) -> AccessToken
        """Build an AccessToken from a deserialized token response and add the response to the cache.

        :param response: pipeline response whose context holds the payload under ContentDecodePolicy.CONTEXT_NAME
        :param scopes: scopes the token was requested for, recorded with the cache entry
        :param request_time: epoch seconds when the request was sent; the base for relative ``expires_in``
        :raises ClientAuthenticationError: if the payload lacks an access token or any expiry field
        """
        # ContentDecodePolicy sets this, and should have raised if it couldn't deserialize the response
        payload = response.context[ContentDecodePolicy.CONTEXT_NAME]
        if not payload or "access_token" not in payload or not ("expires_in" in payload or "expires_on" in payload):
            if isinstance(payload, dict) and "access_token" in payload:
                # redact a copy: the payload object belongs to the response context
                payload = dict(payload, access_token="****")
            raise ClientAuthenticationError(message="Unexpected response '{}'".format(payload))

        token = payload["access_token"]

        # AccessToken wants expires_on as an int
        expires_on = payload.get("expires_on")
        if not expires_on:
            if "expires_in" in payload:
                expires_on = int(payload["expires_in"]) + request_time
            else:
                # expires_on is present but empty and there's no expires_in -> treat the token as single-use
                expires_on = request_time

        try:
            expires_on = int(expires_on)
        except ValueError:
            # probably an App Service MSI response, convert it to epoch seconds
            try:
                t = self._parse_app_service_expires_on(expires_on)
                expires_on = calendar.timegm(t)
            except ValueError:
                # have a token but don't know when it expires -> treat it as single-use
                expires_on = request_time

        # now we have an int expires_on, ensure the cache entry gets it
        payload["expires_on"] = expires_on
        self._cache.add({"response": payload, "scope": scopes})

        return AccessToken(token, expires_on)

    @staticmethod
    def _parse_app_service_expires_on(expires_on):
        # type: (str) -> struct_time
        """
        Parse expires_on from an App Service MSI response (e.g. "06/19/2019 23:42:01 +00:00") to struct_time.
        Expects the time is given in UTC (i.e. has offset +00:00).

        :raises ValueError: if the offset isn't " +00:00" or the remainder doesn't match "%m/%d/%Y %H:%M:%S"
        """
        if not expires_on.endswith(" +00:00"):
            raise ValueError("'{}' doesn't match expected format".format(expires_on))

        # parse the string minus the timezone offset
        return time.strptime(expires_on[: -len(" +00:00")], "%m/%d/%Y %H:%M:%S")

    # TODO: public, factor out of request_token
    def _prepare_request(self, method="POST", headers=None, form_data=None, params=None):
        # type: (Optional[str], Optional[Mapping[str, str]], Optional[Mapping[str, str]], Optional[Dict[str, str]]) -> HttpRequest
        """Build a request to the auth URL, form-encoding ``form_data`` and adding ``params`` to the query."""
        request = HttpRequest(method, self._auth_url, headers=headers)
        if form_data:
            request.headers["Content-Type"] = "application/x-www-form-urlencoded"
            request.set_formdata_body(form_data)
        if params:
            request.format_parameters(params)
        return request
class AadClientBase(ABC):
    """Sans I/O methods for AAD clients wrapping MSAL's OAuth client"""

    def __init__(self, tenant_id, client_id, **kwargs):
        # type: (str, str, **Any) -> None
        authority = kwargs.pop("authority", KnownAuthorities.AZURE_PUBLIC_CLOUD)
        # endswith() rather than authority[-1]: an empty string must not raise IndexError here
        if authority.endswith("/"):
            authority = authority[:-1]
        token_endpoint = "https://" + "/".join((authority, tenant_id, "oauth2/v2.0/token"))
        config = {"token_endpoint": token_endpoint}
        self._client = Client(server_configuration=config, client_id=client_id)
        # swap the session MSAL's client created for one supplied by the subclass
        self._client.session.close()
        self._client.session = self._get_client_session(**kwargs)
        self._cache = TokenCache()

    def get_cached_access_token(self, scopes):
        # type: (Iterable[str]) -> Optional[AccessToken]
        """Return a cached AccessToken for ``scopes`` valid for at least five more minutes, or None."""
        tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes))
        for token in tokens:
            expires_on = int(token["expires_on"])
            if expires_on - 300 > int(time.time()):
                return AccessToken(token["secret"], expires_on)
        return None

    def get_cached_refresh_tokens(self, scopes):
        """Assumes all cached refresh tokens belong to the same user"""
        return self._cache.find(TokenCache.CredentialType.REFRESH_TOKEN, target=list(scopes))

    def obtain_token_by_authorization_code(self, code, redirect_uri, scopes, **kwargs):
        # type: (str, str, Iterable[str], **Any) -> AccessToken
        """Redeem an authorization code through MSAL's client; the subclass's _obtain_token performs the call."""
        fn = functools.partial(
            self._client.obtain_token_by_authorization_code, code=code, redirect_uri=redirect_uri, **kwargs
        )
        return self._obtain_token(scopes, fn, **kwargs)

    def obtain_token_by_refresh_token(self, refresh_token, scopes, **kwargs):
        # type: (str, Iterable[str], **Any) -> AccessToken
        """Redeem a cached refresh token entry; its "secret" field is the token sent to AAD."""
        fn = functools.partial(
            self._client.obtain_token_by_refresh_token,
            token_item=refresh_token,
            scope=scopes,
            rt_getter=lambda token: token["secret"],
            **kwargs
        )
        return self._obtain_token(scopes, fn, **kwargs)

    def _process_response(self, response, scopes, now):
        # type: (dict, Iterable[str], int) -> AccessToken
        """Validate a token response, cache it, and return it as an AccessToken.

        :param now: epoch seconds used both as the cache timestamp and the base for relative ``expires_in``
        :raises ClientAuthenticationError: if the response is an error or carries no expiry
        """
        _raise_for_error(response)

        # validate before caching, so a response that will be rejected is never written to the cache
        if "expires_on" in response:
            expires_on = int(response["expires_on"])
        elif "expires_in" in response:
            expires_on = now + int(response["expires_in"])
        else:
            _scrub_secrets(response)
            raise ClientAuthenticationError(
                message="Unexpected response from Azure Active Directory: {}".format(response)
            )

        # read the token before handing the response to the cache
        token = response["access_token"]
        self._cache.add(event={"response": response, "scope": scopes}, now=now)
        return AccessToken(token, expires_on)

    @abc.abstractmethod
    def _get_client_session(self, **kwargs):
        """Return the session MSAL's client should use; receives the constructor kwargs minus "authority"."""
        pass

    @abc.abstractmethod
    def _obtain_token(self, scopes, fn, **kwargs):
        # type: (Iterable[str], Callable, **Any) -> AccessToken
        """Invoke ``fn`` (a partially-applied MSAL client method) and return the resulting AccessToken."""
        pass
class AuthnClientBase(object):
    """Sans I/O authentication client methods"""

    def __init__(self, auth_url, **kwargs):
        # type: (str, Mapping[str, Any]) -> None
        if not auth_url:
            raise ValueError("auth_url should be the URL of an OAuth endpoint")
        super(AuthnClientBase, self).__init__()
        self._auth_url = auth_url
        self._cache = TokenCache()

    def get_cached_token(self, scopes):
        # type: (Iterable[str]) -> Optional[AccessToken]
        """Return a cached AccessToken covering ``scopes`` if one exists, or None.

        A token qualifies only if it is valid for at least five more minutes.
        """
        # materialize once: ``scopes`` is read twice below and may be a one-shot iterator
        scopes = list(scopes)
        tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, scopes)
        for token in tokens:
            if all(scope in token["target"] for scope in scopes):
                expires_on = int(token["expires_on"])
                if expires_on - 300 > int(time.time()):
                    return AccessToken(token["secret"], expires_on)
        return None

    def _deserialize_and_cache_token(self, response, scopes, request_time):
        # type: (PipelineResponse, Iterable[str], int) -> AccessToken
        """Build an AccessToken from an authentication response and add the response to the cache.

        :param response: pipeline response; if its context holds "deserialized_data" that is used,
            otherwise the HTTP body text is parsed as JSON
        :param scopes: scopes the token was requested for, recorded with the cache entry
        :param request_time: epoch seconds when the request was sent; the base for relative ``expires_in``
        :raises AuthenticationError: if the response has no access token or can't be processed
        """
        import json

        payload = None
        try:
            if "deserialized_data" in response.context:
                payload = response.context["deserialized_data"]
            else:
                # the raw body is text; parse it so keys can be looked up
                payload = json.loads(response.http_response.text())
            token = payload["access_token"]

            # these values are strings in IMDS responses but msal.TokenCache requires they be integers
            # https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/55
            expires_in = int(payload.get("expires_in", 0))
            if expires_in != 0:
                payload["expires_in"] = expires_in
            if "ext_expires_in" in payload:
                payload["ext_expires_in"] = int(payload["ext_expires_in"])

            self._cache.add({"response": payload, "scope": scopes})

            # AccessToken contains the token's expires_on time. There are four cases for setting it:
            # 1. response has expires_on -> AccessToken uses it
            # 2. response has expires_on and expires_in -> AccessToken uses expires_on
            # 3. response has only expires_in -> AccessToken uses expires_in + time of request
            # 4. response has neither expires_on or expires_in -> AccessToken's expires_on is the time of
            #    the request (not expecting this case; if it occurs, the token is effectively single-use)
            try:
                # expires_on may arrive as a string, like the other IMDS values; AccessToken needs an int
                expires_on = int(payload.get("expires_on", 0))
            except (TypeError, ValueError):
                # an expires_on format we can't convert (e.g. "06/19/2019 23:42:01 +00:00"); fall back to expires_in
                expires_on = 0
            return AccessToken(token, expires_on or expires_in + request_time)
        except KeyError:
            # redact a copy so the token doesn't leak into the message and the payload isn't altered
            shown = payload
            if isinstance(payload, dict) and "access_token" in payload:
                shown = dict(payload, access_token="****")
            raise AuthenticationError("Unexpected authentication response: {}".format(shown))
        except Exception as ex:
            raise AuthenticationError("Authentication failed: {}".format(str(ex)))

    # TODO: public, factor out of request_token
    def _prepare_request(self, method="POST", headers=None, form_data=None, params=None):
        # type: (Optional[str], Optional[Mapping[str, str]], Optional[Mapping[str, str]], Optional[Dict[str, str]]) -> HttpRequest
        """Build a request to the auth URL, form-encoding ``form_data`` and adding ``params`` to the query."""
        request = HttpRequest(method, self._auth_url, headers=headers)
        if form_data:
            request.headers["Content-Type"] = "application/x-www-form-urlencoded"
            request.set_formdata_body(form_data)
        if params:
            request.format_parameters(params)
        return request