def test_interactive_credential():
    """InteractiveBrowserCredential should return the token obtained after a mocked browser sign-in."""
    state = "state"
    access_token = "access-token"

    # MSAL forms these requests, so they aren't validated; expect tenant discovery, then a token request
    tenant_discovery = mock_response(
        json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}
    )
    token_response = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_in": 42,
            "token_type": "Bearer",
            "ext_expires_in": 42,
        }
    )
    transport = validating_transport(requests=[Request()] * 2, responses=[tenant_discovery, token_response])

    # the mocked local server "completes" authentication at once by returning a well-formed redirect
    redirect_result = {"code": "authorization-code", "state": [state]}
    fake_server = Mock(wait_for_redirect=lambda: redirect_result)
    server_class = Mock(return_value=fake_server)

    credential = InteractiveBrowserCredential(
        client_id="guid",
        client_secret="secret",
        server_class=server_class,
        transport=transport,
        # forwarded to MSAL; prevents an AAD verification request
        instance_discovery=False,
    )

    # pin the value the credential uses as OAuth state so it matches the mocked redirect
    with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: state):
        token = credential.get_token("scope")

    assert token.token == access_token
async def test_identity_config_cloud_shell():
    """Cloud Shell: entries of _identity_config should be posted as form data next to the resource."""
    extra_key, extra_value = "foo", "bar"
    token_value = "****"
    expiry = 42
    endpoint = "http://localhost:42/token"
    scope = "scope"
    expected = AccessToken(token_value, expiry)

    expected_request = Request(
        base_url=endpoint,
        method="POST",
        required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
        required_data={"resource": scope, extra_key: extra_value},
    )
    payload = {
        "access_token": token_value,
        "expires_in": 0,
        "expires_on": expiry,
        "not_before": int(time.time()),
        "resource": scope,
        "token_type": "Bearer",
    }
    transport = async_validating_transport(requests=[expected_request], responses=[mock_response(json_payload=payload)])

    fake_environ = {EnvironmentVariables.MSI_ENDPOINT: endpoint}
    with mock.patch.dict(MsiCredential.__module__ + ".os.environ", fake_environ, clear=True):
        credential = MsiCredential(_identity_config={extra_key: extra_value}, transport=transport)
        token = await credential.get_token(scope)

    assert token == expected
def test_two_accounts_username_specified():
    """two cached accounts, username specified, one account matches -> credential should auth that account"""
    scope = "scope"
    matching_upn, other_upn = "a@foo", "b@foo"
    matching_refresh_token = "refresh-token-a"
    access_token = "***"

    matching_account = get_account_event(
        username=matching_upn, uid="uid_a", utid="utid", refresh_token=matching_refresh_token
    )
    other_account = get_account_event(username=other_upn, uid="uid_b", utid="utid", refresh_token="refresh_token_b")
    cache = populated_cache(matching_account, other_account)

    # the only acceptable request redeems the matching account's refresh token
    expected_request = Request(required_data={"refresh_token": matching_refresh_token, "scope": scope})
    aad_response = mock_response(json_payload=build_aad_response(access_token=access_token))
    transport = validating_transport(requests=[expected_request], responses=[aad_response])

    token = SharedTokenCacheCredential(username=matching_upn, _cache=cache, transport=transport).get_token(scope)
    assert token.token == access_token
def test_policies_configurable():
    """Policies passed to UsernamePasswordCredential should see the credential's requests."""
    policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock())

    discovery_responses = [get_discovery_response()] * 2
    token_payload = build_aad_response(access_token="**", id_token=build_id_token())
    transport = validating_transport(
        requests=[Request()] * 3,
        responses=discovery_responses + [mock_response(json_payload=token_payload)],
    )

    credential = UsernamePasswordCredential(
        "client-id", "username", "password", policies=[policy], transport=transport
    )
    credential.get_token("scope")

    assert policy.on_request.called
def test_device_code_credential():
    """DeviceCodeCredential should prompt once with the flow details and return the polled token."""
    access_token = "access-token"
    user_code = "user-code"
    verification_uri = "verification-uri"
    expires_in = 42

    # MSAL forms these requests, so they aren't validated.
    # Expected sequence: discover tenant, start device code flow, poll for completion.
    discovery = mock_response(json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"})
    flow_start = mock_response(
        json_payload={
            "device_code": "_",
            "user_code": user_code,
            "verification_uri": verification_uri,
            "expires_in": expires_in,
        }
    )
    poll_result = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_in": expires_in,
            "scope": "scope",
            "token_type": "Bearer",
            "refresh_token": "_",
        }
    )
    transport = validating_transport(requests=[Request()] * 3, responses=[discovery, flow_start, poll_result])

    prompt = Mock()
    credential = DeviceCodeCredential(
        client_id="_", prompt_callback=prompt, transport=transport, instance_discovery=False
    )

    token = credential.get_token("scope")
    assert token.token == access_token

    # the callback should be invoked exactly once with (verification_uri, user_code, expires_in)
    assert prompt.call_count == 1
    assert prompt.call_args[0] == (verification_uri, user_code, expires_in)
def test_tenant_id():
    """DeviceCodeCredential.get_token should accept a tenant_id keyword argument and return the token."""
    client_id = "client-id"
    expected_token = "access-token"
    user_code = "user-code"
    verification_uri = "verification-uri"
    expires_in = 42

    transport = validating_transport(
        requests=[Request()] * 3,  # not validating requests because they're formed by MSAL
        responses=[
            # expected requests: discover tenant, start device code flow, poll for completion
            mock_response(json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}),
            mock_response(
                json_payload={
                    "device_code": "_",
                    "user_code": user_code,
                    "verification_uri": verification_uri,
                    "expires_in": expires_in,
                }
            ),
            mock_response(
                json_payload=dict(
                    build_aad_response(
                        access_token=expected_token,
                        expires_in=expires_in,
                        refresh_token="_",
                        id_token=build_id_token(aud=client_id),
                    ),
                    scope="scope",
                ),
            ),
        ],
    )
    callback = Mock()
    credential = DeviceCodeCredential(
        client_id=client_id,
        prompt_callback=callback,
        transport=transport,
        instance_discovery=False,
    )

    # (an unused, deprecated `datetime.datetime.utcnow()` local was removed here)
    token = credential.get_token("scope", tenant_id="tenant_id")
    assert token.token == expected_token
def test_custom_hooks(environ):
    """The credential's pipeline should include azure-core's CustomHookPolicy"""
    scope = "scope"
    access_token = "***"
    on_request = mock.Mock()
    on_response = mock.Mock()

    issued_at = int(time.time())
    response = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_in": 3600,
            "expires_on": issued_at + 3600,
            "ext_expires_in": 3600,
            "not_before": issued_at,
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = validating_transport(requests=[Request()] * 2, responses=[response] * 2)

    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environ, clear=True):
        credential = ManagedIdentityCredential(
            transport=transport, raw_request_hook=on_request, raw_response_hook=on_response
        )
    credential.get_token(scope)

    if not environ:
        # no environment configured: we're mocking IMDS and should expect 2 requests
        assert on_request.call_count == 2
        assert on_response.call_count == 2
        seen = [positional[0].http_response for positional, _ in on_response.call_args_list]
        assert seen == [response] * 2
        return

    # some environment variables are set, so we're not mocking IMDS and should expect 1 request
    assert on_request.call_count == 1
    assert on_response.call_count == 1
    pipeline_response = on_response.call_args[0][0]
    assert pipeline_response.http_response == response
def test_cloud_shell_user_assigned_identity():
    """Cloud Shell environment: only MSI_ENDPOINT set"""
    token_value, expiry = "****", 42
    client_id = "some-guid"
    endpoint = "http://localhost:42/token"
    scope = "scope"

    # a user-assigned identity is selected by posting its client_id alongside the resource
    expected_request = Request(
        base_url=endpoint,
        method="POST",
        required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
        required_data={"client_id": client_id, "resource": scope},
    )
    payload = {
        "access_token": token_value,
        "expires_in": 0,
        "expires_on": expiry,
        "not_before": int(time.time()),
        "resource": scope,
        "token_type": "Bearer",
    }
    transport = validating_transport(requests=[expected_request], responses=[mock_response(json_payload=payload)])

    with mock.patch("os.environ", {EnvironmentVariables.MSI_ENDPOINT: endpoint}):
        credential = ManagedIdentityCredential(client_id=client_id, transport=transport)
        token = credential.get_token(scope)

    assert token == AccessToken(token_value, expiry)
def test_app_service():
    """App Service environment: MSI_ENDPOINT, MSI_SECRET set"""
    token_value, expiry = "****", 42
    url = "http://localhost:42/token"
    secret = "expected-secret"
    scope = "scope"

    # expected: a GET carrying the secret header, with api-version and resource as query parameters
    expected_request = Request(
        url,
        method="GET",
        required_headers={"Metadata": "true", "secret": secret},
        required_params={"api-version": "2017-09-01", "resource": scope},
    )
    payload = {
        "access_token": token_value,
        "expires_on": expiry,
        "resource": scope,
        "token_type": "Bearer",
    }
    transport = validating_transport(requests=[expected_request], responses=[mock_response(json_payload=payload)])

    fake_environ = {EnvironmentVariables.MSI_ENDPOINT: url, EnvironmentVariables.MSI_SECRET: secret}
    with mock.patch("os.environ", fake_environ):
        token = ManagedIdentityCredential(transport=transport).get_token(scope)

    assert token == AccessToken(token_value, expiry)
def test_policies_configurable():
    """Policies passed to ClientSecretCredential should see the credential's requests."""
    policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock())

    token_response = mock_response(json_payload=build_aad_response(access_token="**"))
    transport = msal_validating_transport(requests=[Request()], responses=[token_response])

    custom_policies = [ContentDecodePolicy(), policy]
    credential = ClientSecretCredential(
        "tenant-id", "client-id", "client-secret", policies=custom_policies, transport=transport
    )
    credential.get_token("scope")

    assert policy.on_request.called
def test_username_password_environment_credential(monkeypatch):
    """EnvironmentCredential should authenticate a user from environment variables, with and without a tenant."""
    client_id = "fake-client-id"
    username = "******"
    password = "******"
    expected_token = "***"

    # MSAL forms these requests, so they aren't validated
    msal_responses = [
        # tenant discovery
        mock_response(json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}),
        # user realm discovery; interests MSAL only when the body contains account_type == "Federated"
        mock_response(json_payload={}),
        # token request
        mock_response(
            json_payload={
                "access_token": expected_token,
                "expires_in": 42,
                "token_type": "Bearer",
                "ext_expires_in": 42,
            }
        ),
    ]
    # both transports built below share these request/response objects
    new_transport = functools.partial(validating_transport, requests=[Request()] * 3, responses=msal_responses)

    for variable, value in (
        (EnvironmentVariables.AZURE_CLIENT_ID, client_id),
        (EnvironmentVariables.AZURE_USERNAME, username),
        (EnvironmentVariables.AZURE_PASSWORD, password),
    ):
        monkeypatch.setenv(variable, value)

    # expires_on isn't checked: that would require patching time, and it's covered elsewhere
    token = EnvironmentCredential(transport=new_transport()).get_token("scope")
    assert token.token == expected_token

    # repeat with a tenant id configured
    monkeypatch.setenv(EnvironmentVariables.AZURE_TENANT_ID, "tenant_id")
    token = EnvironmentCredential(transport=new_transport()).get_token("scope")
    assert token.token == expected_token
def test_policies_configurable():
    """Policies passed to InteractiveBrowserCredential should see the credential's requests."""
    policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock())

    token_response = mock_response(json_payload=build_aad_response(access_token="**"))
    transport = validating_transport(
        requests=[Request()] * 2,
        responses=[get_discovery_response(), token_response],
    )

    # the mocked local server fakes a successful sign-in by immediately returning a well-formed redirect
    state = "oauth-state"
    redirect_result = {"code": "authorization-code", "state": [state]}
    server_class = Mock(return_value=Mock(wait_for_redirect=lambda: redirect_result))

    credential = InteractiveBrowserCredential(
        policies=[policy], transport=transport, server_class=server_class, _cache=TokenCache()
    )
    with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: state):
        credential.get_token("scope")

    assert policy.on_request.called
async def test_reads_cloud_settings(cloud, authority):
    """the credential should read authority and tenant from VS Code settings when an application doesn't specify them"""
    tenant = "tenant-id"
    settings = {"azure.cloud": cloud, "azure.tenant": tenant}

    # the single request must target the authority and tenant taken from the settings
    expected_request = Request(base_url=f"https://{authority}/{tenant}")
    token_response = mock_response(json_payload=build_aad_response(access_token="**"))
    transport = async_validating_transport(requests=[expected_request], responses=[token_response])

    credential = get_credential(settings, transport=transport)
    with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"):
        await credential.get_token("scope")

    assert transport.send.call_count == 1
def get_credential_for_shared_cache_test(expected_refresh_token, expected_access_token, cache, **kwargs):
    """Build a DefaultAzureCredential whose CLI, environment, managed identity and PowerShell credentials
    are excluded, backed by ``cache`` and a transport expecting one refresh-token redemption.

    Extra ``kwargs`` are forwarded to DefaultAzureCredential and override the exclusions above.
    """
    options = dict.fromkeys(
        (
            "exclude_cli_credential",
            "exclude_environment_credential",
            "exclude_managed_identity_credential",
            "exclude_powershell_credential",
        ),
        True,
    )
    options.update(kwargs)

    # validating transport will raise if the shared cache credential isn't used, or selects the wrong refresh token
    redemption = Request(required_data={"refresh_token": expected_refresh_token})
    token_response = mock_response(json_payload=build_aad_response(access_token=expected_access_token))
    transport = validating_transport(requests=[redemption], responses=[token_response])

    # the credential uses a mock shared cache, so patching "supported" lets it be built on any platform
    with patch.object(SharedTokenCacheCredential, "supported"):
        return DefaultAzureCredential(_cache=cache, transport=transport, **options)
def test_prefers_app_service_2019_08_01():
    """When the environment is configured for both App Service versions, the credential should prefer the most recent"""
    token_value, expiry = "****", 42
    endpoint = "http://localhost:42/token"
    secret = "expected-secret"
    scope = "scope"

    # only a 2019-08-01 style request (X-IDENTITY-HEADER, api-version 2019-08-01) is acceptable
    expected_request = Request(
        base_url=endpoint,
        method="GET",
        required_headers={"X-IDENTITY-HEADER": secret, "User-Agent": USER_AGENT},
        required_params={"api-version": "2019-08-01", "resource": scope},
    )
    payload = {
        "access_token": token_value,
        "expires_on": str(expiry),
        "resource": scope,
        "token_type": "Bearer",
    }
    transport = validating_transport(requests=[expected_request], responses=[mock_response(json_payload=payload)])

    both_versions = {
        EnvironmentVariables.IDENTITY_ENDPOINT: endpoint,
        EnvironmentVariables.IDENTITY_HEADER: secret,
        EnvironmentVariables.MSI_ENDPOINT: endpoint,
        EnvironmentVariables.MSI_SECRET: secret,
    }
    with mock.patch.dict("os.environ", both_versions, clear=True):
        token = ManagedIdentityCredential(transport=transport).get_token(scope)

    assert token.token == token_value
    assert token.expires_on == expiry
def test_identity_config_app_service():
    """App Service: entries of identity_config should be sent as query parameters next to the resource."""
    extra_key, extra_value = "foo", "bar"
    token_value, expiry = "****", 42
    endpoint = "http://localhost:42/token"
    secret = "expected-secret"
    scope = "scope"

    expected_request = Request(
        base_url=endpoint,
        method="GET",
        required_headers={"Metadata": "true", "secret": secret, "User-Agent": USER_AGENT},
        required_params={"api-version": "2017-09-01", "resource": scope, extra_key: extra_value},
    )
    payload = {
        "access_token": token_value,
        "expires_on": expiry,
        "resource": scope,
        "token_type": "Bearer",
    }
    transport = validating_transport(requests=[expected_request], responses=[mock_response(json_payload=payload)])

    fake_environ = {EnvironmentVariables.MSI_ENDPOINT: endpoint, EnvironmentVariables.MSI_SECRET: secret}
    with mock.patch.dict(MsiCredential.__module__ + ".os.environ", fake_environ, clear=True):
        credential = MsiCredential(identity_config={extra_key: extra_value}, transport=transport)
        token = credential.get_token(scope)

    assert token == AccessToken(token_value, expiry)
def test_timeout():
    """DeviceCodeCredential should raise once its timeout elapses while authorization is still pending."""
    # MSAL forms these requests, so they aren't validated.
    # Expected sequence: discover tenant, start device code flow, poll (still pending).
    responses = [
        mock_response(json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}),
        mock_response(json_payload={"device_code": "_", "user_code": "_", "verification_uri": "_"}),
        mock_response(json_payload={"error": "authorization_pending"}),
    ]
    transport = validating_transport(requests=[Request()] * 3, responses=responses)

    credential = DeviceCodeCredential(
        client_id="_",
        prompt_callback=Mock(),
        transport=transport,
        timeout=0.01,
        instance_discovery=False,
        _cache=TokenCache(),
    )

    with pytest.raises(ClientAuthenticationError) as caught:
        credential.get_token("scope")

    assert "timed out" in caught.value.message.lower()
def test_get_error_response():
    """MsalClient should map an error result back to the HTTP response that carried it."""
    url = "https://localhost"
    first_result = {"error": "first"}
    second_result = {"error": "second"}
    first_response = mock_response(401, json_payload=first_result)
    second_response = mock_response(401, json_payload=second_result)

    transport = validating_transport(requests=[Request(url=url)] * 2, responses=[first_response, second_response])
    client = MsalClient(transport=transport)

    # nothing has been sent yet, so neither result maps to a response
    assert not any(client.get_error_response(result) for result in (first_result, second_result))

    client.get(url)
    assert client.get_error_response(first_result) is first_response

    client.post(url)
    assert client.get_error_response(second_result) is second_response
    # after the second request, the first result no longer maps to a response
    assert not client.get_error_response(first_result)
def test_authentication_record():
    """SharedTokenCacheCredential should redeem the refresh token of the account named by its record."""
    tenant_id = "tenant-id"
    client_id = "client-id"
    authority = "localhost"
    object_id = "object-id"
    username = "******"
    access_token = "****"
    refresh_token = "**"

    home_account_id = object_id + "." + tenant_id
    record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

    cached_account = get_account_event(
        username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=refresh_token
    )
    cache = populated_cache(cached_account)

    redemption = Request(authority=authority, required_data={"refresh_token": refresh_token})
    token_response = mock_response(json_payload=build_aad_response(access_token=access_token))
    transport = msal_validating_transport(
        endpoint="https://{}/{}".format(authority, tenant_id),
        requests=[redemption],
        responses=[token_response],
    )

    credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)
    assert credential.get_token("scope").token == access_token
async def test_cloud_shell():
    """Cloud Shell environment: only MSI_ENDPOINT set"""
    token_value, expiry = "****", 42
    url = "http://localhost:42/token"
    scope = "scope"

    expected_request = Request(
        url,
        method="POST",
        required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
        required_data={"resource": scope},
    )
    payload = {
        "access_token": token_value,
        "expires_in": 0,
        "expires_on": expiry,
        "not_before": int(time.time()),
        "resource": scope,
        "token_type": "Bearer",
    }
    transport = async_validating_transport(requests=[expected_request], responses=[mock_response(json_payload=payload)])

    with mock.patch("os.environ", {EnvironmentVariables.MSI_ENDPOINT: url}):
        credential = ManagedIdentityCredential(transport=transport)
        token = await credential.get_token(scope)

    assert token == AccessToken(token_value, expiry)
def test_client_secret_credential():
    """ClientSecretCredential should post its client id and secret to its tenant and return the token."""
    client_id = "fake-client-id"
    secret = "fake-client-secret"
    tenant_id = "fake-tenant-id"
    access_token = "***"

    expected_request = Request(url_substring=tenant_id, required_data={"client_id": client_id, "client_secret": secret})
    payload = {
        "token_type": "Bearer",
        "expires_in": 42,
        "ext_expires_in": 42,
        "access_token": access_token,
    }
    transport = validating_transport(requests=[expected_request], responses=[mock_response(json_payload=payload)])

    credential = ClientSecretCredential(tenant_id, client_id, secret, transport=transport)
    token = credential.get_token("scope")

    # expires_on isn't checked: that would require patching time, and it's covered elsewhere
    assert token.token == access_token
def test_app_service_user_assigned_identity():
    """App Service 2017-09-01: MSI_ENDPOINT, MSI_SECRET set"""
    token_value, expiry = "****", 42
    client_id = "some-guid"
    endpoint = "http://localhost:42/token"
    secret = "expected-secret"
    scope = "scope"

    # a user-assigned identity is selected with the "clientid" query parameter
    expected_request = Request(
        base_url=endpoint,
        method="GET",
        required_headers={"secret": secret, "User-Agent": USER_AGENT},
        required_params={"api-version": "2017-09-01", "clientid": client_id, "resource": scope},
    )
    # this response carries expires_on as a date string rather than an epoch number
    payload = {
        "access_token": token_value,
        "expires_on": "01/01/1970 00:00:{} +00:00".format(expiry),
        "resource": scope,
        "token_type": "Bearer",
    }
    transport = validating_transport(requests=[expected_request], responses=[mock_response(json_payload=payload)])

    fake_environ = {EnvironmentVariables.MSI_ENDPOINT: endpoint, EnvironmentVariables.MSI_SECRET: secret}
    with mock.patch("os.environ", fake_environ):
        credential = ManagedIdentityCredential(client_id=client_id, transport=transport)
        token = credential.get_token(scope)

    assert token == AccessToken(token_value, expiry)
async def test_client_secret_environment_credential():
    """EnvironmentCredential should use client secret settings from the environment."""
    client_id = "fake-client-id"
    secret = "fake-client-secret"
    tenant_id = "fake-tenant-id"
    access_token = "***"

    expected_request = Request(url_substring=tenant_id, required_data={"client_id": client_id, "client_secret": secret})
    payload = {
        "token_type": "Bearer",
        "expires_in": 42,
        "ext_expires_in": 42,
        "access_token": access_token,
    }
    transport = async_validating_transport(requests=[expected_request], responses=[mock_response(json_payload=payload)])

    fake_environ = {
        EnvironmentVariables.AZURE_CLIENT_ID: client_id,
        EnvironmentVariables.AZURE_CLIENT_SECRET: secret,
        EnvironmentVariables.AZURE_TENANT_ID: tenant_id,
    }
    with patch("os.environ", fake_environ):
        credential = EnvironmentCredential(transport=transport)
        token = await credential.get_token("scope")

    # expires_on isn't checked: that would require patching time, and it's covered elsewhere
    assert token.token == access_token
def test_username_password_credential():
    """UsernamePasswordCredential should return the token from MSAL's token request."""
    access_token = "access-token"

    # MSAL forms these requests, so they aren't validated
    tenant_discovery = mock_response(
        json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}
    )
    # user realm discovery; interests MSAL only when the body contains account_type == "Federated"
    user_realm = mock_response(json_payload={})
    token_response = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_in": 42,
            "token_type": "Bearer",
            "ext_expires_in": 42,
        }
    )
    transport = validating_transport(
        requests=[Request()] * 3, responses=[tenant_discovery, user_realm, token_response]
    )

    credential = UsernamePasswordCredential(
        client_id="some-guid",
        username="******",
        password="******",
        transport=transport,
        # forwarded to MSAL; prevents an AAD verification request
        instance_discovery=False,
    )

    assert credential.get_token("scope").token == access_token
def test_writes_to_cache():
    """the credential should write tokens it acquires to the cache"""
    scope = "scope"
    username = "******"
    uid, utid = "uid", "utid"
    first_access_token = "access token"
    first_refresh_token = "first refresh token"
    rotated_refresh_token = "second refresh token"

    cache = TokenCache()
    cache.add(get_account_event(username=username, uid=uid, utid=utid, refresh_token=first_refresh_token))

    # 1) the credential redeems the cached refresh token; AAD answers with an access token and a new refresh token
    redemption = Request(required_data={"refresh_token": first_refresh_token})
    id_token = build_id_token(aud=DEVELOPER_SIGN_ON_CLIENT_ID, object_id=uid, tenant_id=utid, username=username)
    aad_payload = build_aad_response(
        uid=uid,
        utid=utid,
        access_token=first_access_token,
        refresh_token=rotated_refresh_token,
        id_token=id_token,
    )
    transport = validating_transport(requests=[redemption], responses=[mock_response(json_payload=aad_payload)])
    token = SharedTokenCacheCredential(_cache=cache, transport=transport).get_token(scope)
    assert token.token == first_access_token

    # 2) another instance sharing the cache should return the cached access token without sending anything
    offline = Mock(send=Mock(side_effect=Exception("the credential should return a cached token")))
    token = SharedTokenCacheCredential(_cache=cache, transport=offline).get_token(scope)
    assert token.token == first_access_token

    # 3) a request for a different scope must redeem the refresh token written in step 1
    second_access_token = "second access token"
    transport = validating_transport(
        requests=[Request(required_data={"refresh_token": rotated_refresh_token})],
        responses=[mock_response(json_payload=build_aad_response(access_token=second_access_token))],
    )
    token = SharedTokenCacheCredential(_cache=cache, transport=transport).get_token("some other " + scope)
    assert token.token == second_access_token

    # the cached refresh token was updated rather than added as a new entry
    assert len(cache.find(TokenCache.CredentialType.REFRESH_TOKEN)) == 1
def test_interactive_credential(mock_open, redirect_url):
    """The browser should be opened once; later calls use the cached token, then the refresh token after expiry."""
    mock_open.side_effect = _validate_auth_request_url
    state = "state"
    client_id = "client-id"
    refresh_token = "refresh-token"
    access_token = "access-token"
    expires_in = 3600
    authority = "authority"
    tenant_id = "tenant-id"
    endpoint = "https://{}/{}".format(authority, tenant_id)

    auth_code_redemption = Request(url_substring=endpoint)
    refresh_redemption = Request(
        authority=authority, url_substring=endpoint, required_data={"refresh_token": refresh_token}
    )
    first_payload = build_aad_response(
        access_token=access_token,
        expires_in=expires_in,
        refresh_token=refresh_token,
        uid="uid",
        utid=tenant_id,
        id_token=build_id_token(aud=client_id, object_id="uid", tenant_id=tenant_id, iss=endpoint),
        token_type="Bearer",
    )
    refreshed_payload = build_aad_response(access_token=access_token, expires_in=expires_in, token_type="Bearer")
    transport = msal_validating_transport(
        endpoint=endpoint,
        requests=[auth_code_redemption, refresh_redemption],
        responses=[mock_response(json_payload=first_payload), mock_response(json_payload=refreshed_payload)],
    )

    # the mocked local server fakes a successful sign-in by immediately returning a well-formed redirect
    redirect_result = {"code": "authorization-code", "state": [state]}
    server_class = Mock(return_value=Mock(wait_for_redirect=lambda: redirect_result))

    options = {
        "authority": authority,
        "tenant_id": tenant_id,
        "client_id": client_id,
        "transport": transport,
        "_cache": TokenCache(),
        "_server_class": server_class,
    }
    if redirect_url:  # avoid passing redirect_uri=None
        options["redirect_uri"] = redirect_url
    credential = InteractiveBrowserCredential(**options)

    # the auth code request embeds a uuid that the redirect must echo; patching it is simpler than a real server
    with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: state):
        token = credential.get_token("scope")
    assert token.token == access_token
    assert mock_open.call_count == 1
    assert server_class.call_count == 1
    if redirect_url:
        server_class.assert_called_once_with(redirect_url, timeout=ANY)

    # the token should be cached, so this call shouldn't prompt again
    token = credential.get_token("scope")
    assert token.token == access_token
    assert mock_open.call_count == 1
    assert server_class.call_count == 1

    # once the access token expires, the credential should redeem the refresh token instead of prompting
    start = time.time()
    with patch("time.time", lambda: start + expires_in):
        token = credential.get_token("scope")
    assert token.token == access_token
    assert mock_open.call_count == 1

    # ensure all expected requests were sent
    assert transport.send.call_count == 4
def test_same_tenant_different_usernames():
    """two cached accounts, same tenant, different usernames"""
    access_token_a = "access-token-a"
    access_token_b = "access-token-b"
    refresh_token_a = "refresh-token-a"
    refresh_token_b = "refresh-token-b"
    upn_a = "spam@eggs"
    upn_b = "eggs@spam"
    tenant_id = "the-tenant"
    account_a = get_account_event(username=upn_a, uid="another-guid", utid=tenant_id, refresh_token=refresh_token_a)
    account_b = get_account_event(username=upn_b, uid="more-guid", utid=tenant_id, refresh_token=refresh_token_b)
    cache = populated_cache(account_a, account_b)

    # With no username specified the credential can't select an identity, so it shouldn't use the network.
    # The guard must be on send(): Mock(side_effect=...) on the transport itself only raises when the
    # transport object is called directly, while transport.send(...) would silently return a child Mock.
    transport = Mock(send=Mock(side_effect=Exception()))
    credential = SharedTokenCacheCredential(tenant_id=tenant_id, _cache=cache, transport=transport)
    with pytest.raises(CredentialUnavailableError) as ex:
        credential.get_token("scope")

    assert ex.value.message.startswith(MULTIPLE_MATCHING_ACCOUNTS[: MULTIPLE_MATCHING_ACCOUNTS.index("{")])
    assert tenant_id in ex.value.message

    # with a username specified, the credential should auth the matching account
    scope = "scope"
    transport = validating_transport(
        requests=[Request(required_data={"refresh_token": refresh_token_b, "scope": scope})],
        # a token distinct from account a's, so a cross-account cache hit below can't go unnoticed
        responses=[mock_response(json_payload=build_aad_response(access_token=access_token_b))],
    )
    credential = SharedTokenCacheCredential(username=upn_b, _cache=cache, transport=transport)
    token = credential.get_token(scope)
    assert token.token == access_token_b

    transport = validating_transport(
        requests=[Request(required_data={"refresh_token": refresh_token_a, "scope": scope})],
        responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))],
    )
    credential = SharedTokenCacheCredential(username=upn_a, _cache=cache, transport=transport)
    token = credential.get_token(scope)
    assert token.token == access_token_a
def test_authenticate():
    """_authenticate should return a record for the signed-in account and leave a usable token cached."""
    client_id = "client-id"
    environment = "localhost"
    issuer = "https://" + environment
    tenant_id = "some-tenant"
    authority = issuer + "/" + tenant_id
    access_token = "***"
    scope = "scope"

    # the mocked AAD response carries an id token for a user whose home tenant differs from tenant_id
    object_id = "object-id"
    home_tenant = "home-tenant-id"
    username = "******"
    id_token = build_id_token(aud=client_id, iss=issuer, object_id=object_id, tenant_id=home_tenant, username=username)
    auth_response = build_aad_response(
        uid=object_id, utid=home_tenant, access_token=access_token, refresh_token="**", id_token=id_token
    )
    transport = validating_transport(
        requests=[Request(url_substring=issuer)] * 3,
        responses=[get_discovery_response(authority)] * 2 + [mock_response(json_payload=auth_response)],
    )

    # the mocked local server fakes a successful sign-in by immediately returning a well-formed redirect
    state = "state"
    redirect_result = {"code": "authorization-code", "state": [state]}
    server_class = Mock(return_value=Mock(wait_for_redirect=lambda: redirect_result))

    with patch(InteractiveBrowserCredential.__module__ + ".uuid.uuid4", lambda: state), patch(
        WEBBROWSER_OPEN, lambda _: True
    ):
        credential = InteractiveBrowserCredential(
            _cache=TokenCache(),
            authority=environment,
            client_id=client_id,
            _server_class=server_class,
            tenant_id=tenant_id,
            transport=transport,
        )
        record = credential._authenticate(scopes=(scope,))

    assert record.authority == environment
    assert record.home_account_id == object_id + "." + home_tenant
    assert record.tenant_id == home_tenant
    assert record.username == username

    # the access token for the scope used in _authenticate should be cached, so no browser is opened
    no_browser = Mock(side_effect=Exception("credential should authenticate silently"))
    with patch(WEBBROWSER_OPEN, no_browser):
        token = credential.get_token(scope)
    assert token.token == access_token
def test_interactive_credential(mock_open):
    """The browser should be opened once; later calls use the cached token, then the refresh token after expiry."""
    mock_open.side_effect = _validate_auth_request_url
    state = "state"
    client_id = "client-id"
    refresh_token = "refresh-token"
    access_token = "access-token"
    expires_in = 3600
    authority = "authority"
    tenant_id = "tenant_id"
    endpoint = "https://{}/{}".format(authority, tenant_id)

    # one response object serves both instance discovery and tenant discovery
    discovery_keys = ("authorization_endpoint", "token_endpoint", "tenant_discovery_endpoint")
    discovery_response = mock_response(json_payload=dict.fromkeys(discovery_keys, endpoint))

    first_payload = build_aad_response(
        access_token=access_token,
        expires_in=expires_in,
        refresh_token=refresh_token,
        uid="uid",
        utid="utid",
        token_type="Bearer",
    )
    refreshed_payload = build_aad_response(access_token=access_token, expires_in=expires_in, token_type="Bearer")
    refresh_redemption = Request(
        authority=authority, url_substring=endpoint, required_data={"refresh_token": refresh_token}
    )
    transport = validating_transport(
        requests=[Request(url_substring=endpoint)] * 3 + [refresh_redemption],
        responses=[
            discovery_response,  # instance discovery
            discovery_response,  # tenant discovery
            mock_response(json_payload=first_payload),
            mock_response(json_payload=refreshed_payload),
        ],
    )

    # the mocked local server fakes a successful sign-in by immediately returning a well-formed redirect
    redirect_result = {"code": "authorization-code", "state": [state]}
    server_class = Mock(return_value=Mock(wait_for_redirect=lambda: redirect_result))

    credential = InteractiveBrowserCredential(
        authority=authority,
        tenant_id=tenant_id,
        client_id=client_id,
        client_secret="secret",
        server_class=server_class,
        transport=transport,
        instance_discovery=False,
        validate_authority=False,
    )

    # the auth code request embeds a uuid that the redirect must echo; patching it is simpler than a real server
    with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: state):
        token = credential.get_token("scope")
    assert token.token == access_token
    assert mock_open.call_count == 1

    # the token should be cached, so this call shouldn't prompt again
    token = credential.get_token("scope")
    assert token.token == access_token
    assert mock_open.call_count == 1

    # As of MSAL 1.0.0, applications build a new client every time they redeem a refresh token.
    # Patch the private method they use so the refresh goes through this transport.
    # TODO: this will probably break when this MSAL behavior changes
    msal_app = credential._get_app()
    msal_app._build_client = lambda *_: msal_app.client  # pylint:disable=protected-access

    # once the access token expires, the credential should redeem the refresh token instead of prompting
    start = time.time()
    with patch("time.time", lambda: start + expires_in):
        token = credential.get_token("scope")
    assert token.token == access_token
    assert mock_open.call_count == 1

    # ensure all expected requests were sent
    assert transport.send.call_count == 4
def test_same_username_different_tenants():
    """two cached accounts, same username, different tenants"""
    access_token_a = "access-token-a"
    access_token_b = "access-token-b"
    refresh_token_a = "refresh-token-a"
    refresh_token_b = "refresh-token-b"
    upn = "spam@eggs"
    tenant_a = "tenant-a"
    tenant_b = "tenant-b"
    account_a = get_account_event(username=upn, uid="another-guid", utid=tenant_a, refresh_token=refresh_token_a)
    account_b = get_account_event(username=upn, uid="more-guid", utid=tenant_b, refresh_token=refresh_token_b)
    cache = populated_cache(account_a, account_b)

    # With no tenant specified the credential can't select an identity, so it shouldn't use the network.
    # The guard must be on send(): Mock(side_effect=...) on the transport itself only raises when the
    # transport object is called directly, while transport.send(...) would silently return a child Mock.
    transport = Mock(send=Mock(side_effect=Exception()))
    credential = SharedTokenCacheCredential(username=upn, _cache=cache, transport=transport)
    with pytest.raises(ClientAuthenticationError) as ex:
        credential.get_token("scope")

    # error message should indicate multiple matching accounts, and list discovered accounts
    assert ex.value.message.startswith(MULTIPLE_MATCHING_ACCOUNTS[: MULTIPLE_MATCHING_ACCOUNTS.index("{")])
    discovered_accounts = ex.value.message.splitlines()[-1]
    assert discovered_accounts.count(upn) == 2
    assert tenant_a in discovered_accounts and tenant_b in discovered_accounts

    # with tenant specified, the credential should auth the matching account
    scope = "scope"
    transport = validating_transport(
        requests=[Request(required_data={"refresh_token": refresh_token_a, "scope": scope})],
        responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))],
    )
    credential = SharedTokenCacheCredential(tenant_id=tenant_a, _cache=cache, transport=transport)
    token = credential.get_token(scope)
    assert token.token == access_token_a

    transport = validating_transport(
        requests=[Request(required_data={"refresh_token": refresh_token_b, "scope": scope})],
        responses=[mock_response(json_payload=build_aad_response(access_token=access_token_b))],
    )
    credential = SharedTokenCacheCredential(tenant_id=tenant_b, _cache=cache, transport=transport)
    token = credential.get_token(scope)
    assert token.token == access_token_b