def test_single_account():
    """With one cached account and no username specified, the credential should auth that account."""
    scope = "scope"
    refresh_token = "refresh-token"

    cached_account = get_account_event(uid="uid_a", utid="utid", username="******", refresh_token=refresh_token)
    cache = populated_cache(cached_account)

    expected_token = "***"
    # the transport is given one expected request (refresh token + scope) and one token response
    expected_request = Request(required_data={"refresh_token": refresh_token, "scope": scope})
    token_response = mock_response(json_payload=build_aad_response(access_token=expected_token))
    transport = validating_transport(requests=[expected_request], responses=[token_response])

    credential = SharedTokenCacheCredential(_cache=cache, transport=transport)

    assert credential.get_token(scope).token == expected_token
def test_wsl_fallback(uname, is_wsl):
    """the credential should invoke powershell.exe to open a browser in WSL when webbrowser.open fails

    :param uname: value returned by the patched ``platform.uname`` in the credential's module
    :param is_wsl: whether the test expects ``uname`` to be treated as WSL
    """
    auth_uri = "http://localhost"
    expected_access_token = "**"
    msal_acquire_token_result = dict(
        build_aad_response(access_token=expected_access_token, id_token=build_id_token()),
        id_token_claims=id_token_claims("issuer", "subject", "audience", upn="upn"),
    )
    msal_app = Mock(
        initiate_auth_code_flow=Mock(return_value={"auth_uri": auth_uri}),
        acquire_token_by_auth_code_flow=Mock(return_value=msal_acquire_token_result),
    )
    transport = Mock(send=Mock(side_effect=Exception("this test mocks MSAL, so no request should be sent")))
    credential = InteractiveBrowserCredential(_server_class=Mock(), transport=transport)

    credential_module = InteractiveBrowserCredential.__module__
    with patch(credential_module + ".subprocess.call") as subprocess_call:
        subprocess_call.return_value = 0
        with patch(credential_module + ".platform.uname", lambda: uname):
            with patch.object(InteractiveBrowserCredential, "_get_app", lambda _: msal_app):
                # webbrowser.open reports failure, so only the fallback can open a browser
                with patch(WEBBROWSER_OPEN, lambda _: False):
                    try:
                        token = credential.get_token("scope")
                    except CredentialUnavailableError:
                        assert not is_wsl, "credential should invoke powershell.exe in WSL"
                        return

    assert is_wsl, "credential should raise CredentialUnavailableError when not in WSL"
    assert token.token == expected_access_token
    assert subprocess_call.call_count == 1
    args, kwargs = subprocess_call.call_args
    assert args[0][0] == "powershell.exe"
    assert auth_uri in args[0][-1]

    # Compare the version numerically: as strings, "3.10" sorts before "3.3", which would
    # silently skip this assertion on Python 3.10 and later.
    # NOTE(review): if the credential guards its timeout argument with the same string
    # comparison, this assertion will expose that on 3.10+ -- confirm against the implementation.
    major, minor = platform.python_version_tuple()[:2]
    if (int(major), int(minor)) >= (3, 3):
        assert "timeout" in kwargs
def test_auth_record_multiple_accounts_for_username():
    """Two cached accounts share a username; the expected request carries the refresh token of the
    account whose tenant matches the authentication record."""
    tenant_id = "tenant-id"
    client_id = "client-id"
    authority = "localhost"
    object_id = "object-id"
    username = "******"
    home_account_id = object_id + "." + tenant_id
    record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

    expected_access_token = "****"
    expected_refresh_token = "**"

    expected_account = get_account_event(
        username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token
    )
    # this account matches all but the record's tenant
    other_tenant_account = get_account_event(
        username,
        object_id,
        "different-" + tenant_id,
        authority=authority,
        client_id=client_id,
        refresh_token="not-" + expected_refresh_token,
    )
    cache = populated_cache(expected_account, other_tenant_account)

    expected_request = Request(authority=authority, required_data={"refresh_token": expected_refresh_token})
    token_response = mock_response(json_payload=build_aad_response(access_token=expected_access_token))
    transport = validating_transport(requests=[expected_request], responses=[token_response])

    credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)

    assert credential.get_token("scope").token == expected_access_token
def test_authenticate():
    """authenticate() should return a record built from the token response, after which get_token
    returns the access token for the same scope."""
    client_id = "client-id"
    environment = "localhost"
    tenant_id = "some-tenant"
    issuer = "https://" + environment
    authority = issuer + "/" + tenant_id
    access_token = "***"
    scope = "scope"

    # mock AAD response with id token
    object_id = "object-id"
    home_tenant = "home-tenant-id"
    username = "******"
    id_token = build_id_token(
        aud=client_id, iss=issuer, object_id=object_id, tenant_id=home_tenant, username=username
    )
    auth_response = build_aad_response(
        uid=object_id, utid=home_tenant, access_token=access_token, refresh_token="**", id_token=id_token
    )

    expected_requests = [Request(url_substring=issuer)] * 4
    canned_responses = [
        get_discovery_response(authority),  # instance discovery
        get_discovery_response(authority),  # tenant discovery
        mock_response(status_code=404),  # user realm discovery
        mock_response(json_payload=auth_response),  # token request following authenticate()
    ]
    transport = validating_transport(requests=expected_requests, responses=canned_responses)

    credential = UsernamePasswordCredential(
        username=username,
        password="******",
        authority=environment,
        client_id=client_id,
        tenant_id=tenant_id,
        transport=transport,
    )

    record = credential.authenticate(scopes=(scope,))
    assert record.authority == environment
    assert record.home_account_id == object_id + "." + home_tenant
    assert record.tenant_id == home_tenant
    assert record.username == username

    # credential should have a cached access token for the scope passed to authenticate
    assert credential.get_token(scope).token == access_token
def test_policies_configurable():
    """A policy passed through the ``policies`` keyword should have on_request called during get_token."""
    policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock())

    expected_requests = [Request()] * 3
    discovery_responses = [get_discovery_response()] * 2
    token_response = mock_response(json_payload=build_aad_response(access_token="**"))
    transport = validating_transport(
        requests=expected_requests,
        responses=discovery_responses + [token_response],
    )

    credential = UsernamePasswordCredential(
        "client-id", "username", "password", policies=[policy], transport=transport
    )
    credential.get_token("scope")

    assert policy.on_request.called
def test_multitenant_cache():
    """Cached tokens should only be returned to clients for the matching tenant, including the
    multitenant and legacy-tenant-selection cases exercised below."""
    client_id = "client-id"
    scope = "scope"
    expected_token = "***"
    tenant_a, tenant_b, tenant_c = "tenant-a", "tenant-b", "tenant-c"
    authority = "https://localhost/" + tenant_a

    cache = TokenCache()
    cached_event = {
        "response": build_aad_response(access_token=expected_token),
        "client_id": client_id,
        "scope": [scope],
        "token_endpoint": "/".join((authority, tenant_a, "oauth2/v2.0/token")),
    }
    cache.add(cached_event)

    common_args = {"authority": authority, "cache": cache, "client_id": client_id}
    client_a = AadClient(tenant_id=tenant_a, **common_args)
    client_b = AadClient(tenant_id=tenant_b, **common_args)

    # A has a cached token
    assert client_a.get_cached_access_token([scope]).token == expected_token

    # which B shouldn't return
    assert client_b.get_cached_access_token([scope]) is None

    # but C allows multitenant auth and should therefore return the token from tenant_a when appropriate
    client_c = AadClient(tenant_id=tenant_c, allow_multitenant_authentication=True, **common_args)
    assert client_c.get_cached_access_token([scope]) is None
    token_for_a = client_c.get_cached_access_token([scope], tenant_id=tenant_a)
    assert token_for_a.token == expected_token

    # with legacy tenant selection enabled, C no longer returns tenant_a's token
    legacy_environment = {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}
    with patch.dict("os.environ", legacy_environment, clear=True):
        assert client_c.get_cached_access_token([scope], tenant_id=tenant_a) is None
async def test_single_account_matching_tenant_and_username():
    """With one cached account matching both the specified tenant and username, the credential should
    auth that account."""
    upn = "spam@eggs"
    tenant_id = "tenant-id"
    scope = "scope"
    refresh_token = "refresh-token"

    cached_account = get_account_event(uid="uid_a", utid=tenant_id, username=upn, refresh_token=refresh_token)
    cache = populated_cache(cached_account)

    expected_token = "***"
    expected_request = Request(required_data={"refresh_token": refresh_token, "scope": scope})
    token_response = mock_response(json_payload=build_aad_response(access_token=expected_token))
    transport = async_validating_transport(requests=[expected_request], responses=[token_response])

    credential = SharedTokenCacheCredential(_cache=cache, transport=transport, tenant_id=tenant_id, username=upn)
    token = await credential.get_token(scope)

    assert token.token == expected_token
def test_no_user_settings():
    """Without readable VS Code settings, the credential should default to Public Cloud and the
    "organizations" tenant."""
    expected_base_url = "https://{}/{}".format(AzureAuthorityHosts.AZURE_PUBLIC_CLOUD, "organizations")
    token_response = mock_response(json_payload=build_aad_response(access_token="**"))
    transport = validating_transport(
        requests=[Request(base_url=expected_base_url)],
        responses=[token_response],
    )

    credential = get_credential(transport=transport)

    # GET_REFRESH_TOKEN is patched to return a placeholder refresh token
    with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"):
        credential.get_token("scope")

    assert transport.send.call_count == 1
def test_claims_challenge():
    """get_token and authenticate should pass any claims challenge to MSAL token acquisition APIs"""
    expected_claims = '{"access_token": {"essential": "true"}'
    msal_acquire_token_result = dict(
        build_aad_response(access_token="**", id_token=build_id_token()),
        id_token_claims=id_token_claims("issuer", "subject", "audience", upn="upn"),
    )

    failing_send = Mock(side_effect=Exception("this test mocks MSAL, so no request should be sent"))
    transport = Mock(send=failing_send)
    credential = UsernamePasswordCredential("client-id", "username", "password", transport=transport)

    with patch.object(UsernamePasswordCredential, "_get_app") as get_mock_app:
        msal_app = get_mock_app()
        by_password = msal_app.acquire_token_by_username_password
        by_password.return_value = msal_acquire_token_result

        # authenticate() forwards the claims
        credential.authenticate(scopes=["scope"], claims=expected_claims)
        assert by_password.call_count == 1
        _, kwargs = by_password.call_args
        assert kwargs["claims_challenge"] == expected_claims

        # get_token() forwards the claims when it acquires a token by username/password
        credential.get_token("scope", claims=expected_claims)
        assert by_password.call_count == 2
        _, kwargs = by_password.call_args
        assert kwargs["claims_challenge"] == expected_claims

        # get_token() forwards the claims on the silent path as well
        silent = msal_app.acquire_token_silent_with_error
        msal_app.get_accounts.return_value = [{"home_account_id": credential._auth_record.home_account_id}]
        silent.return_value = msal_acquire_token_result
        credential.get_token("scope", claims=expected_claims)
        assert silent.call_count == 1
        _, kwargs = silent.call_args
        assert kwargs["claims_challenge"] == expected_claims
def test_policies_configurable():
    """A policy passed through the ``policies`` keyword should have on_request called during get_token."""
    policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock())

    token_response = mock_response(json_payload=build_aad_response(access_token="**"))
    transport = msal_validating_transport(requests=[Request()], responses=[token_response])

    pipeline_policies = [ContentDecodePolicy(), policy]
    credential = CertificateCredential(
        "tenant-id", "client-id", PEM_CERT_PATH, policies=pipeline_policies, transport=transport
    )
    credential.get_token("scope")

    assert policy.on_request.called
def test_policies_configurable():
    """A policy passed through the ``policies`` keyword should have on_request called during get_token."""
    policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock())

    canned_responses = [
        get_discovery_response(),
        mock_response(json_payload=build_aad_response(access_token="**")),
    ]
    transport = validating_transport(requests=[Request()] * 2, responses=canned_responses)

    # mock local server fakes successful authentication by immediately returning a well-formed response
    oauth_state = "oauth-state"
    auth_code_response = {"code": "authorization-code", "state": [oauth_state]}
    mock_server = Mock(wait_for_redirect=lambda: auth_code_response)
    server_class = Mock(return_value=mock_server)

    credential = InteractiveBrowserCredential(policies=[policy], transport=transport, server_class=server_class)

    # uuid4 is patched so it returns the state the mock server echoes back
    with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: oauth_state):
        credential.get_token("scope")

    assert policy.on_request.called
def test_two_accounts_username_specified():
    """With two cached accounts and a username matching one of them, the credential should auth
    that account."""
    scope = "scope"
    expected_refresh_token = "refresh-token-a"
    upn_a, upn_b = "a@foo", "b@foo"

    account_a = get_account_event(username=upn_a, uid="uid_a", utid="utid", refresh_token=expected_refresh_token)
    account_b = get_account_event(username=upn_b, uid="uid_b", utid="utid", refresh_token="refresh_token_b")
    cache = populated_cache(account_a, account_b)

    expected_token = "***"
    expected_request = Request(required_data={"refresh_token": expected_refresh_token, "scope": scope})
    token_response = mock_response(json_payload=build_aad_response(access_token=expected_token))
    transport = validating_transport(requests=[expected_request], responses=[token_response])

    credential = SharedTokenCacheCredential(username=upn_a, _cache=cache, transport=transport)

    assert credential.get_token(scope).token == expected_token
def test_tenant_id():
    """get_token should accept a ``tenant_id`` keyword and return the token from the mocked device code flow."""
    client_id = "client-id"
    expected_token = "access-token"
    user_code = "user-code"
    verification_uri = "verification-uri"
    expires_in = 42

    transport = validating_transport(
        requests=[Request()] * 3,  # not validating requests because they're formed by MSAL
        responses=[
            # expected requests: discover tenant, start device code flow, poll for completion
            mock_response(json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}),
            mock_response(
                json_payload={
                    "device_code": "_",
                    "user_code": user_code,
                    "verification_uri": verification_uri,
                    "expires_in": expires_in,
                }
            ),
            mock_response(
                json_payload=dict(
                    build_aad_response(
                        access_token=expected_token,
                        expires_in=expires_in,
                        refresh_token="_",
                        id_token=build_id_token(aud=client_id),
                    ),
                    scope="scope",
                ),
            ),
        ],
    )

    callback = Mock()
    credential = DeviceCodeCredential(
        client_id=client_id,
        prompt_callback=callback,
        transport=transport,
        instance_discovery=False,
    )

    # (the unused ``now = datetime.datetime.utcnow()`` local was removed: nothing read it)
    token = credential.get_token("scope", tenant_id="tenant_id")
    assert token.token == expected_token
def test_user_agent():
    """The third request of the device code flow is required to carry the User-Agent header."""
    expected_requests = [Request()] * 2 + [Request(required_headers={"User-Agent": USER_AGENT})]
    device_code_payload = {
        "device_code": "_",
        "user_code": "user-code",
        "verification_uri": "verification-uri",
        "expires_in": 42,
    }
    canned_responses = [
        get_discovery_response(),
        mock_response(json_payload=device_code_payload),
        mock_response(json_payload=dict(build_aad_response(access_token="**"), scope="scope")),
    ]
    transport = validating_transport(requests=expected_requests, responses=canned_responses)

    credential = DeviceCodeCredential(client_id="client-id", prompt_callback=Mock(), transport=transport)

    credential.get_token("scope")
def test_claims_challenge():
    """get_token and authenticate should pass any claims challenge to MSAL token acquisition APIs"""
    expected_claims = '{"access_token": {"essential": "true"}'

    auth_code_response = {"code": "authorization-code", "state": ["..."]}
    mock_server = Mock(wait_for_redirect=lambda: auth_code_response)
    server_class = Mock(return_value=mock_server)

    msal_acquire_token_result = dict(
        build_aad_response(access_token="**", id_token=build_id_token()),
        id_token_claims=id_token_claims("issuer", "subject", "audience", upn="upn"),
    )

    failing_send = Mock(side_effect=Exception("this test mocks MSAL, so no request should be sent"))
    transport = Mock(send=failing_send)
    credential = InteractiveBrowserCredential(_server_class=server_class, transport=transport)

    with patch.object(InteractiveBrowserCredential, "_get_app") as get_mock_app:
        msal_app = get_mock_app()
        by_auth_code = msal_app.acquire_token_by_auth_code_flow
        msal_app.initiate_auth_code_flow.return_value = {"auth_uri": "http://localhost"}
        by_auth_code.return_value = msal_acquire_token_result

        # authenticate() forwards the claims
        with patch(WEBBROWSER_OPEN, lambda _: True):
            credential.authenticate(scopes=["scope"], claims=expected_claims)
        assert by_auth_code.call_count == 1
        _, kwargs = by_auth_code.call_args
        assert kwargs["claims_challenge"] == expected_claims

        # get_token() forwards the claims on the interactive path
        with patch(WEBBROWSER_OPEN, lambda _: True):
            credential.get_token("scope", claims=expected_claims)
        assert by_auth_code.call_count == 2
        _, kwargs = by_auth_code.call_args
        assert kwargs["claims_challenge"] == expected_claims

        # get_token() forwards the claims on the silent path as well
        silent = msal_app.acquire_token_silent_with_error
        msal_app.get_accounts.return_value = [{"home_account_id": credential._auth_record.home_account_id}]
        silent.return_value = msal_acquire_token_result
        credential.get_token("scope", claims=expected_claims)
        assert silent.call_count == 1
        _, kwargs = silent.call_args
        assert kwargs["claims_challenge"] == expected_claims
async def test_reads_cloud_settings(cloud, authority):
    """When the application specifies neither, the credential should take authority and tenant from
    VS Code settings.

    :param cloud: value stored under "azure.cloud" in the user settings
    :param authority: host expected in the token request's base URL for that cloud
    """
    expected_tenant = "tenant-id"
    user_settings = {"azure.cloud": cloud, "azure.tenant": expected_tenant}

    expected_base_url = "https://{}/{}".format(authority, expected_tenant)
    token_response = mock_response(json_payload=build_aad_response(access_token="**"))
    transport = async_validating_transport(
        requests=[Request(base_url=expected_base_url)],
        responses=[token_response],
    )

    credential = get_credential(user_settings, transport=transport)

    # GET_REFRESH_TOKEN is patched to return a placeholder refresh token
    with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"):
        await credential.get_token("scope")

    assert transport.send.call_count == 1
def get_account_event(
    username, uid, utid, authority=None, client_id="client-id", refresh_token="refresh-token", scopes=None, **kwargs
):
    """Build a dict with "response", "client_id", "token_endpoint" and "scope" keys describing an account.

    :param username: username passed to build_id_token
    :param uid: uid passed to build_aad_response
    :param utid: utid passed to build_aad_response; also the tenant segment of the token endpoint
    :param authority: host for the token endpoint; when falsy, get_default_authority() is used
    :param client_id: client id for the event and the id token's ``aud``
    :param refresh_token: refresh token included in the response
    :param scopes: scopes for the event; defaults to ``["scope"]`` when falsy
    :param kwargs: extra keyword arguments forwarded to build_aad_response
    :return: the event dict
    """
    if authority:
        token_endpoint = "https://" + authority + "/" + utid + "/path"
    else:
        token_endpoint = get_default_authority() + "/{}/path".format(utid)

    id_token = build_id_token(aud=client_id, username=username)
    response = build_aad_response(
        uid=uid, utid=utid, refresh_token=refresh_token, id_token=id_token, foci="1", **kwargs
    )

    event = {
        "response": response,
        "client_id": client_id,
        "token_endpoint": token_endpoint,
        "scope": scopes or ["scope"],
    }
    return event
def test_claims_challenge():
    """get_token should pass any claims challenge to MSAL token acquisition APIs"""
    expected_claims = '{"access_token": {"essential": "true"}'
    record = AuthenticationRecord("tenant-id", "client_id", "authority", "home_account_id", "username")

    # mock MSAL application with one account matching the record and a canned silent-auth result
    msal_app = Mock()
    silent = msal_app.acquire_token_silent_with_error
    msal_app.get_accounts.return_value = [{"home_account_id": record.home_account_id}]
    silent.return_value = dict(build_aad_response(access_token="**", id_token=build_id_token()))

    failing_send = Mock(side_effect=Exception("this test mocks MSAL, so no request should be sent"))
    transport = Mock(send=failing_send)
    credential = SharedTokenCacheCredential(transport=transport, authentication_record=record, _cache=TokenCache())

    application_path = SharedTokenCacheCredential.__module__ + ".PublicClientApplication"
    with patch(application_path, lambda *_, **__: msal_app):
        credential.get_token("scope", claims=expected_claims)

    assert silent.call_count == 1
    _, kwargs = silent.call_args
    assert kwargs["claims_challenge"] == expected_claims
def get_account_event(
    username,
    uid,
    utid,
    authority=None,
    client_id="client-id",
    refresh_token="refresh-token",
    scopes=None,
):
    """Build a dict with "response", "client_id", "token_endpoint" and "scope" keys describing an account.

    :param username: value passed to build_id_token as ``preferred_username``
    :param uid: uid passed to build_aad_response
    :param utid: utid passed to build_aad_response; also the tenant segment of the token endpoint
    :param authority: host for the token endpoint; when falsy, KnownAuthorities.AZURE_PUBLIC_CLOUD is used
    :param client_id: client id for the event and the id token's ``aud``
    :param refresh_token: refresh token included in the response
    :param scopes: scopes for the event; defaults to ``["scope"]`` when falsy
    :return: the event dict
    """
    id_token = build_id_token(aud=client_id, preferred_username=username)
    response = build_aad_response(uid=uid, utid=utid, refresh_token=refresh_token, id_token=id_token, foci="1")

    host = authority or KnownAuthorities.AZURE_PUBLIC_CLOUD
    # the last segment keeps its leading slash, so the joined path ends in "//path"
    token_endpoint = "https://" + "/".join([host, utid, "/path"])

    event = {
        "response": response,
        "client_id": client_id,
        "token_endpoint": token_endpoint,
        "scope": scopes or ["scope"],
    }
    return event
def test_user_agent():
    """The second request of the interactive flow is required to carry the User-Agent header."""
    client_id = "client-id"

    expected_requests = [Request(), Request(required_headers={"User-Agent": USER_AGENT})]
    token_payload = build_aad_response(access_token="**", id_token=build_id_token(aud=client_id))
    canned_responses = [get_discovery_response(), mock_response(json_payload=token_payload)]
    transport = validating_transport(requests=expected_requests, responses=canned_responses)

    # mock local server fakes successful authentication by immediately returning a well-formed response
    oauth_state = "oauth-state"
    auth_code_response = {"code": "authorization-code", "state": [oauth_state]}
    mock_server = Mock(wait_for_redirect=lambda: auth_code_response)
    server_class = Mock(return_value=mock_server)

    credential = InteractiveBrowserCredential(
        client_id=client_id, transport=transport, _server_class=server_class, _cache=TokenCache()
    )

    # uuid4 is patched so it returns the state the mock server echoes back
    with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: oauth_state):
        credential.get_token("scope")
def test_home_account_id_no_client_info():
    """When MSAL doesn't provide client_info, the credential should use the subject claim as
    home_account_id."""
    subject = "subject"
    id_token_claims = {
        "aud": "client-id",
        "iss": "https://localhost",
        "object_id": "some-guid",
        "tid": "some-tenant",
        "preferred_username": "******",
        "sub": subject,
    }
    msal_response = build_aad_response(access_token="***", refresh_token="**")
    msal_response["id_token_claims"] = id_token_claims

    class TestCredential(InteractiveCredential):
        """InteractiveCredential whose token request always returns the canned MSAL response."""

        def __init__(self, **kwargs):
            super(TestCredential, self).__init__(client_id="...", **kwargs)

        def _request_token(self, *_, **__):
            return msal_response

    credential = TestCredential()
    record = credential.authenticate()

    assert record.home_account_id == subject
def test_authentication_record():
    """Given an AuthenticationRecord matching the cached account, the credential should redeem that
    account's refresh token."""
    tenant_id = "tenant-id"
    client_id = "client-id"
    authority = "localhost"
    object_id = "object-id"
    username = "******"
    home_account_id = object_id + "." + tenant_id
    record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

    expected_access_token = "****"
    expected_refresh_token = "**"

    account = get_account_event(
        username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token
    )
    cache = populated_cache(account)

    endpoint = "https://{}/{}".format(authority, tenant_id)
    expected_request = Request(authority=authority, required_data={"refresh_token": expected_refresh_token})
    token_response = mock_response(json_payload=build_aad_response(access_token=expected_access_token))
    transport = msal_validating_transport(
        endpoint=endpoint,
        requests=[expected_request],
        responses=[token_response],
    )

    credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)

    assert credential.get_token("scope").token == expected_access_token
def test_home_account_id_client_info():
    """The credential should decode client_info returned by MSAL to get the home_account_id."""
    object_id = "object-id"
    home_tenant = "home-tenant-id"
    id_token_claims = {
        "aud": "client-id",
        "iss": "https://localhost",
        "object_id": object_id,
        "tid": home_tenant,
        "preferred_username": "******",
        "sub": "subject",
    }
    msal_response = build_aad_response(uid=object_id, utid=home_tenant, access_token="***", refresh_token="**")
    msal_response["id_token_claims"] = id_token_claims

    class TestCredential(InteractiveCredential):
        """InteractiveCredential whose token request always returns the canned MSAL response."""

        def __init__(self, **kwargs):
            super(TestCredential, self).__init__(client_id="...", **kwargs)

        def _request_token(self, *_, **__):
            return msal_response

    credential = TestCredential()
    record = credential._authenticate()

    expected_home_account_id = "{}.{}".format(object_id, home_tenant)
    assert record.home_account_id == expected_home_account_id
async def test_multitenant_cache():
    """Cached tokens should only be returned for the matching tenant, which a client for another
    tenant can request explicitly via ``tenant_id``."""
    client_id = "client-id"
    scope = "scope"
    expected_token = "***"
    tenant_a, tenant_b, tenant_c = "tenant-a", "tenant-b", "tenant-c"
    authority = "https://localhost/" + tenant_a

    cache = TokenCache()
    cached_event = {
        "response": build_aad_response(access_token=expected_token),
        "client_id": client_id,
        "scope": [scope],
        "token_endpoint": "/".join((authority, tenant_a, "oauth2/v2.0/token")),
    }
    cache.add(cached_event)

    common_args = {"authority": authority, "cache": cache, "client_id": client_id}
    client_a = AadClient(tenant_id=tenant_a, **common_args)
    client_b = AadClient(tenant_id=tenant_b, **common_args)

    # A has a cached token
    assert client_a.get_cached_access_token([scope]).token == expected_token

    # which B shouldn't return
    assert client_b.get_cached_access_token([scope]) is None

    # but C allows multitenant auth and should therefore return the token from tenant_a when appropriate
    client_c = AadClient(tenant_id=tenant_c, **common_args)
    assert client_c.get_cached_access_token([scope]) is None
    token_for_a = client_c.get_cached_access_token([scope], tenant_id=tenant_a)
    assert token_for_a.token == expected_token
async def send(*_, **__):
    """Ignore every argument and return a mock response whose payload carries an access token."""
    token_payload = build_aad_response(access_token="**")
    return mock_response(json_payload=token_payload)
def test_writes_to_cache():
    """the credential should write tokens it acquires to the cache"""
    scope = "scope"
    expected_access_token = "access token"
    first_refresh_token = "first refresh token"
    second_refresh_token = "second refresh token"
    username = "******"
    uid = "uid"
    utid = "utid"

    cache = TokenCache()
    cache.add(get_account_event(username=username, uid=uid, utid=utid, refresh_token=first_refresh_token))

    # credential redeems refresh token
    first_request = Request(required_data={"refresh_token": first_refresh_token})
    # AAD responds with an access token and new refresh token
    id_token = build_id_token(aud=DEVELOPER_SIGN_ON_CLIENT_ID, object_id=uid, tenant_id=utid, username=username)
    first_payload = build_aad_response(
        uid=uid,
        utid=utid,
        access_token=expected_access_token,
        refresh_token=second_refresh_token,
        id_token=id_token,
    )
    transport = validating_transport(
        requests=[first_request],
        responses=[mock_response(json_payload=first_payload)],
    )
    credential = SharedTokenCacheCredential(_cache=cache, transport=transport)
    assert credential.get_token(scope).token == expected_access_token

    # access token should be in the cache, and another instance should retrieve it
    offline_transport = Mock(send=Mock(side_effect=Exception("the credential should return a cached token")))
    credential = SharedTokenCacheCredential(_cache=cache, transport=offline_transport)
    assert credential.get_token(scope).token == expected_access_token

    # and the credential should have updated the cached refresh token
    second_access_token = "second access token"
    second_request = Request(required_data={"refresh_token": second_refresh_token})
    second_payload = build_aad_response(access_token=second_access_token)
    transport = validating_transport(
        requests=[second_request],
        responses=[mock_response(json_payload=second_payload)],
    )
    credential = SharedTokenCacheCredential(_cache=cache, transport=transport)
    assert credential.get_token("some other " + scope).token == second_access_token

    # verify the credential didn't add a new cache entry
    cached_refresh_tokens = cache.find(TokenCache.CredentialType.REFRESH_TOKEN)
    assert len(cached_refresh_tokens) == 1
def test_same_username_different_tenants():
    """two cached accounts, same username, different tenants"""
    access_token_a = "access-token-a"
    access_token_b = "access-token-b"
    refresh_token_a = "refresh-token-a"
    refresh_token_b = "refresh-token-b"

    upn = "spam@eggs"
    tenant_a = "tenant-a"
    tenant_b = "tenant-b"
    account_a = get_account_event(username=upn, uid="another-guid", utid=tenant_a, refresh_token=refresh_token_a)
    account_b = get_account_event(username=upn, uid="more-guid", utid=tenant_b, refresh_token=refresh_token_b)
    cache = populated_cache(account_a, account_b)

    # with no tenant specified the credential can't select an identity
    transport = Mock(side_effect=Exception())  # (so it shouldn't use the network)
    credential = SharedTokenCacheCredential(username=upn, _cache=cache, transport=transport)
    with pytest.raises(ClientAuthenticationError) as ex:
        credential.get_token("scope")

    # error message should indicate multiple matching accounts, and list discovered accounts
    message = ex.value.message
    expected_prefix = MULTIPLE_MATCHING_ACCOUNTS[: MULTIPLE_MATCHING_ACCOUNTS.index("{")]
    assert message.startswith(expected_prefix)
    discovered_accounts = message.splitlines()[-1]
    assert discovered_accounts.count(upn) == 2
    assert tenant_a in discovered_accounts
    assert tenant_b in discovered_accounts

    # with tenant specified, the credential should auth the matching account
    scope = "scope"
    cases = (
        (tenant_a, refresh_token_a, access_token_a),
        (tenant_b, refresh_token_b, access_token_b),
    )
    for tenant, refresh_token, access_token in cases:
        expected_request = Request(required_data={"refresh_token": refresh_token, "scope": scope})
        token_response = mock_response(json_payload=build_aad_response(access_token=access_token))
        transport = validating_transport(requests=[expected_request], responses=[token_response])
        credential = SharedTokenCacheCredential(tenant_id=tenant, _cache=cache, transport=transport)
        token = credential.get_token(scope)
        assert token.token == access_token
def test_same_tenant_different_usernames():
    """two cached accounts, same tenant, different usernames"""
    access_token_a = "access-token-a"
    access_token_b = "access-token-b"
    refresh_token_a = "refresh-token-a"
    refresh_token_b = "refresh-token-b"

    upn_a = "spam@eggs"
    upn_b = "eggs@spam"
    tenant_id = "the-tenant"
    account_a = get_account_event(username=upn_a, uid="another-guid", utid=tenant_id, refresh_token=refresh_token_a)
    account_b = get_account_event(username=upn_b, uid="more-guid", utid=tenant_id, refresh_token=refresh_token_b)
    cache = populated_cache(account_a, account_b)

    # with no username specified the credential can't select an identity
    transport = Mock(side_effect=Exception())  # (so it shouldn't use the network)
    credential = SharedTokenCacheCredential(tenant_id=tenant_id, _cache=cache, transport=transport)
    with pytest.raises(CredentialUnavailableError) as ex:
        credential.get_token("scope")

    # error message should indicate multiple matching accounts and mention the tenant
    assert ex.value.message.startswith(MULTIPLE_MATCHING_ACCOUNTS[: MULTIPLE_MATCHING_ACCOUNTS.index("{")])
    assert tenant_id in ex.value.message

    # with a username specified, the credential should auth the matching account
    scope = "scope"

    # Account b gets its own access token. Previously this step reused access_token_a, so the
    # test could not distinguish the two accounts' tokens and access_token_b was never used.
    transport = validating_transport(
        requests=[Request(required_data={"refresh_token": refresh_token_b, "scope": scope})],
        responses=[mock_response(json_payload=build_aad_response(access_token=access_token_b))],
    )
    credential = SharedTokenCacheCredential(username=upn_b, _cache=cache, transport=transport)
    token = credential.get_token(scope)
    assert token.token == access_token_b

    transport = validating_transport(
        requests=[Request(required_data={"refresh_token": refresh_token_a, "scope": scope})],
        responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))],
    )
    credential = SharedTokenCacheCredential(username=upn_a, _cache=cache, transport=transport)
    token = credential.get_token(scope)
    assert token.token == access_token_a
def mock_send(request, **_):
    """Return a discovery response for requests without a body; otherwise check the URL and return a token.

    ``expected_tenant_id`` is not defined in this function; it is read from the enclosing scope.

    :param request: the outgoing request; only ``body`` and ``url`` are read
    :return: a mock response
    """
    if request.body:
        # requests with a body must target the expected tenant
        assert request.url.startswith("https://localhost/" + expected_tenant_id)
        token_payload = build_aad_response(access_token="*")
        return mock_response(json_payload=token_payload)

    return get_discovery_response()
def test_interactive_credential(mock_open):
    """The credential should prompt once, then serve the cached token, then use the refresh token once
    the access token has expired.

    :param mock_open: mock supplied by the caller; its call count is used here to count prompts
    """
    mock_open.side_effect = _validate_auth_request_url

    oauth_state = "state"
    client_id = "client-id"
    expected_refresh_token = "refresh-token"
    expected_token = "access-token"
    expires_in = 3600
    authority = "authority"
    tenant_id = "tenant_id"
    endpoint = "https://{}/{}".format(authority, tenant_id)

    discovery_payload = dict.fromkeys(
        ("authorization_endpoint", "token_endpoint", "tenant_discovery_endpoint"), endpoint
    )
    discovery_response = mock_response(json_payload=discovery_payload)

    refresh_request = Request(
        authority=authority, url_substring=endpoint, required_data={"refresh_token": expected_refresh_token}
    )
    expected_requests = [Request(url_substring=endpoint)] * 3 + [refresh_request]

    first_token_payload = build_aad_response(
        access_token=expected_token,
        expires_in=expires_in,
        refresh_token=expected_refresh_token,
        uid="uid",
        utid="utid",
        token_type="Bearer",
    )
    refreshed_token_payload = build_aad_response(
        access_token=expected_token, expires_in=expires_in, token_type="Bearer"
    )
    canned_responses = [
        discovery_response,  # instance discovery
        discovery_response,  # tenant discovery
        mock_response(json_payload=first_token_payload),
        mock_response(json_payload=refreshed_token_payload),
    ]
    transport = validating_transport(requests=expected_requests, responses=canned_responses)

    # mock local server fakes successful authentication by immediately returning a well-formed response
    auth_code_response = {"code": "authorization-code", "state": [oauth_state]}
    mock_server = Mock(wait_for_redirect=lambda: auth_code_response)
    server_class = Mock(return_value=mock_server)

    credential = InteractiveBrowserCredential(
        authority=authority,
        tenant_id=tenant_id,
        client_id=client_id,
        client_secret="secret",
        server_class=server_class,
        transport=transport,
        instance_discovery=False,
        validate_authority=False,
    )

    # The credential's auth code request includes a uuid which must be included in the redirect. Patching to
    # set the uuid requires less code here than a proper mock server.
    with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: oauth_state):
        token = credential.get_token("scope")
    assert token.token == expected_token
    assert mock_open.call_count == 1

    # token should be cached, get_token shouldn't prompt again
    token = credential.get_token("scope")
    assert token.token == expected_token
    assert mock_open.call_count == 1

    # As of MSAL 1.0.0, applications build a new client every time they redeem a refresh token.
    # Here we patch the private method they use for the sake of test coverage.
    # TODO: this will probably break when this MSAL behavior changes
    app = credential._get_app()
    app._build_client = lambda *_: app.client  # pylint:disable=protected-access

    # expired access token -> credential should use refresh token instead of prompting again
    now = time.time()
    with patch("time.time", lambda: now + expires_in):
        token = credential.get_token("scope")
    assert token.token == expected_token
    assert mock_open.call_count == 1

    # ensure all expected requests were sent
    assert transport.send.call_count == 4