async def test_single_account_matching_username():
    """A single cached account whose username matches the requested one should be used to authenticate."""
    username = "spam@eggs"
    cached_refresh_token = "refresh-token"
    requested_scope = "scope"
    access_token = "***"

    cache = populated_cache(
        get_account_event(uid="uid_a", utid="utid", username=username, refresh_token=cached_refresh_token)
    )

    # the transport expects one request carrying the cached refresh token and the requested scope
    expected_request = Request(required_data={"refresh_token": cached_refresh_token, "scope": requested_scope})
    aad_response = mock_response(json_payload=build_aad_response(access_token=access_token))
    transport = async_validating_transport(requests=[expected_request], responses=[aad_response])

    credential = SharedTokenCacheCredential(_cache=cache, transport=transport, username=username)
    result = await credential.get_token(requested_scope)

    assert result.token == access_token
async def test_app_service_user_assigned_identity():
    """With MSI_ENDPOINT and MSI_SECRET set, a client_id is sent as "clientid" using api-version 2017-09-01."""
    scope = "scope"
    secret = "expected-secret"
    client_id = "some-guid"
    endpoint = "http://localhost:42/token"
    access_token = "****"
    expires_on = 42

    token_request = Request(
        base_url=endpoint,
        method="GET",
        required_headers={"secret": secret, "User-Agent": USER_AGENT},
        required_params={"api-version": "2017-09-01", "clientid": client_id, "resource": scope},
    )
    # expires_on is delivered as a date string: 42 seconds past the epoch
    token_response = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_on": "01/01/1970 00:00:{} +00:00".format(expires_on),
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(requests=[token_request], responses=[token_response])

    environment = {EnvironmentVariables.MSI_ENDPOINT: endpoint, EnvironmentVariables.MSI_SECRET: secret}
    with mock.patch("os.environ", environment):
        credential = ManagedIdentityCredential(client_id=client_id, transport=transport)
        token = await credential.get_token(scope)

    assert token == AccessToken(access_token, expires_on)
async def test_imds():
    """With an empty environment, ManagedIdentityCredential should probe IMDS and then request a token from it."""
    scope = "scope"
    access_token = "****"
    expires_on = 42

    # first request should be the availability probe => match only the URL
    probe_request = Request(url=Endpoints.IMDS)
    token_request = Request(
        base_url=Endpoints.IMDS,
        method="GET",
        required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
        required_params={"api-version": "2018-02-01", "resource": scope},
    )

    # the probe receives an error response
    probe_response = mock_response(status_code=400, json_payload={"error": "this is an error message"})
    token_response = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_in": 42,
            "expires_on": expires_on,
            "ext_expires_in": 42,
            "not_before": int(time.time()),
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(
        requests=[probe_request, token_request],
        responses=[probe_response, token_response],
    )

    # ensure e.g. $MSI_ENDPOINT isn't set, so we get ImdsCredential
    with mock.patch.dict("os.environ", clear=True):
        credential = ManagedIdentityCredential(transport=transport)
        token = await credential.get_token(scope)

    assert token == AccessToken(access_token, expires_on)
async def test_no_user_settings():
    """Without readable VS Code settings, the credential should use Public Cloud and the "organizations" tenant."""
    default_authority = "https://{}/{}".format(AzureAuthorityHosts.AZURE_PUBLIC_CLOUD, "organizations")
    token_response = mock_response(json_payload=build_aad_response(access_token="**"))
    transport = async_validating_transport(
        requests=[Request(base_url=default_authority)],
        responses=[token_response],
    )

    credential = get_credential(transport=transport)

    # stub out the refresh token lookup so the credential has something to redeem
    with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"):
        await credential.get_token("scope")

    assert transport.send.call_count == 1
async def test_app_service():
    """With MSI_ENDPOINT and MSI_SECRET set, the credential should request a token from that endpoint."""
    scope = "scope"
    secret = "expected-secret"
    url = "http://localhost:42/token"
    access_token = "****"
    expires_on = 42

    token_request = Request(
        url,
        method="GET",
        required_headers={"Metadata": "true", "secret": secret, "User-Agent": USER_AGENT},
        required_params={"api-version": "2017-09-01", "resource": scope},
    )
    token_response = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_on": expires_on,
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(requests=[token_request], responses=[token_response])

    environment = {EnvironmentVariables.MSI_ENDPOINT: url, EnvironmentVariables.MSI_SECRET: secret}
    with mock.patch("os.environ", environment):
        credential = ManagedIdentityCredential(transport=transport)
        token = await credential.get_token(scope)

    assert token == AccessToken(access_token, expires_on)
async def test_imds_url_override():
    """ImdsCredential should send its token request to the URL given by AZURE_POD_IDENTITY_TOKEN_URL."""
    override_url = "https://localhost/token"
    access_token = "***"
    scope = "scope"
    issued_at = int(time.time())

    token_request = Request(
        base_url=override_url,
        method="GET",
        required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
        required_params={"api-version": "2018-02-01", "resource": scope},
    )
    token_response = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_in": 42,
            "expires_on": issued_at + 42,
            "ext_expires_in": 42,
            "not_before": issued_at,
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(requests=[token_request], responses=[token_response])

    environment = {EnvironmentVariables.AZURE_POD_IDENTITY_TOKEN_URL: override_url}
    with mock.patch.dict("os.environ", environment, clear=True):
        credential = ImdsCredential(transport=transport)
        token = await credential.get_token(scope)

    assert token.token == access_token
async def test_two_accounts_username_specified():
    """With two cached accounts and a username matching one of them, the matching account should be used."""
    scope = "scope"
    access_token = "***"
    matching_username = "a@foo"
    other_username = "b@foo"
    matching_refresh_token = "refresh-token-a"

    matching_account = get_account_event(
        username=matching_username, uid="uid_a", utid="utid", refresh_token=matching_refresh_token
    )
    other_account = get_account_event(
        username=other_username, uid="uid_b", utid="utid", refresh_token="refresh_token_b"
    )
    cache = populated_cache(matching_account, other_account)

    # the request must carry the matching account's refresh token
    expected_request = Request(required_data={"refresh_token": matching_refresh_token, "scope": scope})
    aad_response = mock_response(json_payload=build_aad_response(access_token=access_token))
    transport = async_validating_transport(requests=[expected_request], responses=[aad_response])

    credential = SharedTokenCacheCredential(username=matching_username, _cache=cache, transport=transport)
    token = await credential.get_token(scope)

    assert token.token == access_token
async def test_identity_config_cloud_shell():
    """With only MSI_ENDPOINT set, MsiCredential should send identity_config entries in its request's form data."""
    config_key, config_value = "foo", "bar"
    scope = "scope"
    endpoint = "http://localhost:42/token"
    access_token = "****"
    expires_on = 42

    token_request = Request(
        base_url=endpoint,
        method="POST",
        required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
        required_data={"resource": scope, config_key: config_value},
    )
    token_response = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_in": 0,
            "expires_on": expires_on,
            "not_before": int(time.time()),
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(requests=[token_request], responses=[token_response])

    # patch os.environ as seen from the module defining MsiCredential
    environ_target = MsiCredential.__module__ + ".os.environ"
    with mock.patch.dict(environ_target, {EnvironmentVariables.MSI_ENDPOINT: endpoint}, clear=True):
        credential = MsiCredential(identity_config={config_key: config_value}, transport=transport)
        token = await credential.get_token(scope)

    assert token == AccessToken(access_token, expires_on)
async def test_custom_hooks(environ):
    """Hooks passed as raw_request_hook / raw_response_hook should be invoked once per request the credential sends."""
    scope = "scope"
    access_token = "***"
    on_request = mock.Mock()
    on_response = mock.Mock()
    issued_at = int(time.time())

    token_response = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_in": 3600,
            "expires_on": issued_at + 3600,
            "ext_expires_in": 3600,
            "not_before": issued_at,
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    # allow up to two requests of any shape, each answered with the same response object
    transport = async_validating_transport(requests=[Request()] * 2, responses=[token_response] * 2)

    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environ, clear=True):
        credential = ManagedIdentityCredential(
            transport=transport, raw_request_hook=on_request, raw_response_hook=on_response
        )
        await credential.get_token(scope)

    # some environment variables set -> we're not mocking IMDS and should expect 1 request;
    # none set -> we're mocking IMDS and should expect 2 requests
    expected_calls = 1 if environ else 2
    assert on_request.call_count == expected_calls
    assert on_response.call_count == expected_calls

    # every response hook invocation receives a pipeline response wrapping the mocked HTTP response
    observed = [positional[0].http_response for positional, _ in on_response.call_args_list]
    assert observed == [token_response] * expected_calls
async def test_cloud_shell_user_assigned_identity():
    """With only MSI_ENDPOINT set, a client_id should be sent in the POST form data as "client_id"."""
    scope = "scope"
    client_id = "some-guid"
    endpoint = "http://localhost:42/token"
    access_token = "****"
    expires_on = 42

    token_request = Request(
        base_url=endpoint,
        method="POST",
        required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
        required_data={"client_id": client_id, "resource": scope},
    )
    token_response = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_in": 0,
            "expires_on": expires_on,
            "not_before": int(time.time()),
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(requests=[token_request], responses=[token_response])

    with mock.patch("os.environ", {EnvironmentVariables.MSI_ENDPOINT: endpoint}):
        credential = ManagedIdentityCredential(client_id=client_id, transport=transport)
        token = await credential.get_token(scope)

    assert token == AccessToken(access_token, expires_on)
async def test_reads_cloud_settings(cloud, authority):
    """When the application specifies neither, authority and tenant should come from the VS Code user settings."""
    tenant = "tenant-id"
    settings = {"azure.cloud": cloud, "azure.tenant": tenant}

    # the token request must target the authority for the configured cloud and the tenant from settings
    expected_request = Request(base_url="https://{}/{}".format(authority, tenant))
    token_response = mock_response(json_payload=build_aad_response(access_token="**"))
    transport = async_validating_transport(requests=[expected_request], responses=[token_response])

    credential = get_credential(settings, transport=transport)

    with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"):
        await credential.get_token("scope")

    assert transport.send.call_count == 1
async def test_identity_config():
    """ImdsCredential should add identity_config entries to the query parameters of its token request."""
    config_key, config_value = "foo", "bar"
    scope = "scope"
    client_id = "some-guid"
    access_token = "****"
    expires_on = 42

    # the first request is matched by URL only; the second must carry the identity_config parameter
    first_request = Request(base_url=IMDS_URL)
    token_request = Request(
        base_url=IMDS_URL,
        method="GET",
        required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
        required_params={"api-version": "2018-02-01", "resource": scope, config_key: config_value},
    )

    error_response = mock_response(status_code=400, json_payload={"error": "this is an error message"})
    token_response = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_in": 42,
            "expires_on": expires_on,
            "ext_expires_in": 42,
            "not_before": int(time.time()),
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(
        requests=[first_request, token_request],
        responses=[error_response, token_response],
    )

    credential = ImdsCredential(
        client_id=client_id, identity_config={config_key: config_value}, transport=transport
    )
    token = await credential.get_token(scope)

    assert token == AccessToken(access_token, expires_on)
async def test_imds_user_assigned_identity():
    """With an empty environment, a client_id should be sent to IMDS as the "client_id" query parameter."""
    scope = "scope"
    client_id = "some-guid"
    access_token = "****"
    expires_on = 42
    imds_token_url = IMDS_AUTHORITY + IMDS_TOKEN_PATH

    probe_request = Request(base_url=imds_token_url)
    token_request = Request(
        base_url=imds_token_url,
        method="GET",
        required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
        required_params={"api-version": "2018-02-01", "client_id": client_id, "resource": scope},
    )

    # probe receives error response
    probe_response = mock_response(status_code=400, json_payload={"error": "this is an error message"})
    token_response = mock_response(
        json_payload={
            "access_token": access_token,
            "client_id": client_id,
            "expires_in": 42,
            "expires_on": expires_on,
            "ext_expires_in": 42,
            "not_before": int(time.time()),
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(
        requests=[probe_request, token_request],
        responses=[probe_response, token_response],
    )

    # ensure e.g. $MSI_ENDPOINT isn't set, so we get ImdsCredential
    with mock.patch.dict("os.environ", clear=True):
        credential = ManagedIdentityCredential(client_id=client_id, transport=transport)
        token = await credential.get_token(scope)

    assert token == AccessToken(access_token, expires_on)
async def test_client_secret_environment_credential():
    """EnvironmentCredential should authenticate with the client ID, secret and tenant found in the environment."""
    tenant_id = "fake-tenant-id"
    client_id = "fake-client-id"
    secret = "fake-client-secret"
    access_token = "***"

    # the request URL must mention the tenant and the body must carry the client's ID and secret
    expected_request = Request(
        url_substring=tenant_id,
        required_data={"client_id": client_id, "client_secret": secret},
    )
    token_response = mock_response(
        json_payload={
            "token_type": "Bearer",
            "expires_in": 42,
            "ext_expires_in": 42,
            "access_token": access_token,
        }
    )
    transport = async_validating_transport(requests=[expected_request], responses=[token_response])

    environment = {
        EnvironmentVariables.AZURE_CLIENT_ID: client_id,
        EnvironmentVariables.AZURE_CLIENT_SECRET: secret,
        EnvironmentVariables.AZURE_TENANT_ID: tenant_id,
    }
    with patch("os.environ", environment):
        credential = EnvironmentCredential(transport=transport)
        token = await credential.get_token("scope")

    # not validating expires_on because doing so requires monkeypatching time, and this is tested elsewhere
    assert token.token == access_token
async def test_authentication_record():
    """Given an AuthenticationRecord, the credential should redeem the refresh token of the account it describes."""
    tenant_id = "tenant-id"
    client_id = "client-id"
    authority = "localhost"
    object_id = "object-id"
    username = "******"
    access_token = "****"
    refresh_token = "**"

    home_account_id = object_id + "." + tenant_id
    record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

    # cache an account matching the record's user, tenant, authority and client
    cached_account = get_account_event(
        username,
        object_id,
        tenant_id,
        authority=authority,
        client_id=client_id,
        refresh_token=refresh_token,
    )
    cache = populated_cache(cached_account)

    expected_request = Request(authority=authority, required_data={"refresh_token": refresh_token})
    aad_response = mock_response(json_payload=build_aad_response(access_token=access_token))
    transport = async_validating_transport(requests=[expected_request], responses=[aad_response])

    credential = SharedTokenCacheCredential(_authentication_record=record, transport=transport, _cache=cache)
    token = await credential.get_token("scope")

    assert token.token == access_token
async def test_same_tenant_different_usernames():
    """Two cached accounts in the same tenant with different usernames.

    The tenant alone is ambiguous, so the credential should raise
    CredentialUnavailableError without sending a request. With a username it
    should redeem the refresh token of the matching account.
    """
    access_token_a = "access-token-a"
    access_token_b = "access-token-b"
    refresh_token_a = "refresh-token-a"
    refresh_token_b = "refresh-token-b"

    upn_a = "spam@eggs"
    upn_b = "eggs@spam"
    tenant_id = "the-tenant"
    account_a = get_account_event(username=upn_a, uid="another-guid", utid=tenant_id, refresh_token=refresh_token_a)
    account_b = get_account_event(username=upn_b, uid="more-guid", utid=tenant_id, refresh_token=refresh_token_b)
    cache = populated_cache(account_a, account_b)

    # with no username specified the credential can't select an identity
    transport = Mock(side_effect=Exception())  # (so it shouldn't use the network)
    credential = SharedTokenCacheCredential(tenant_id=tenant_id, _cache=cache, transport=transport)
    with pytest.raises(CredentialUnavailableError) as ex:
        await credential.get_token("scope")

    # the message should begin with the fixed part of MULTIPLE_MATCHING_ACCOUNTS and mention the tenant
    assert ex.value.message.startswith(MULTIPLE_MATCHING_ACCOUNTS[: MULTIPLE_MATCHING_ACCOUNTS.index("{")])
    assert tenant_id in ex.value.message

    # with a username specified, the credential should auth the matching account
    scope = "scope"

    # Each account gets its own access token: all credentials here share one cache object, so
    # distinct tokens make sure each assertion can only pass with the response to its own request.
    transport = async_validating_transport(
        requests=[Request(required_data={"refresh_token": refresh_token_b, "scope": scope})],
        responses=[mock_response(json_payload=build_aad_response(access_token=access_token_b))],
    )
    credential = SharedTokenCacheCredential(username=upn_b, _cache=cache, transport=transport)
    token = await credential.get_token(scope)
    assert token.token == access_token_b

    transport = async_validating_transport(
        requests=[Request(required_data={"refresh_token": refresh_token_a, "scope": scope})],
        responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))],
    )
    credential = SharedTokenCacheCredential(username=upn_a, _cache=cache, transport=transport)
    token = await credential.get_token(scope)
    assert token.token == access_token_a
async def test_azure_arc(tmpdir):
    """Azure Arc 2019-11-01: the credential answers a 401 challenge with the content of the named key file."""
    scope = "scope"
    access_token = "****"
    expires_on = 42
    api_version = "2019-11-01"
    identity_endpoint = "http://localhost:42/token"
    imds_endpoint = "http://localhost:42"

    # write the secret key to a file the challenge response can point at
    secret_key = "XXXX"
    key_file = tmpdir.mkdir("key").join("key_file.key")
    key_file.write(secret_key)
    assert key_file.read() == secret_key
    key_path = os.path.join(key_file.dirname, key_file.basename)

    query = {"api-version": api_version, "resource": scope}
    initial_request = Request(
        base_url=identity_endpoint,
        method="GET",
        required_headers={"Metadata": "true"},
        required_params=query,
    )
    # the follow-up request must present the key file's content as Basic authorization
    authorized_request = Request(
        base_url=identity_endpoint,
        method="GET",
        required_headers={"Metadata": "true", "Authorization": "Basic {}".format(secret_key)},
        required_params=query,
    )

    # first response gives path to authentication key
    challenge = mock_response(
        status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_path)}
    )
    token_response = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_on": expires_on,
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(
        requests=[initial_request, authorized_request],
        responses=[challenge, token_response],
    )

    environment = {
        EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint,
        EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint,
    }
    with mock.patch("os.environ", environment):
        credential = ManagedIdentityCredential(transport=transport)
        token = await credential.get_token(scope)

    assert token.token == access_token
    assert token.expires_on == expires_on
async def test_azure_ml():
    """Azure ML: MSI_ENDPOINT and MSI_SECRET set (like App Service 2017-09-01, with a different response format)."""
    scope = "scope"
    client_id = "client"
    secret = "expected-secret"
    url = "http://localhost:42/token"
    expected_token = AccessToken("****", int(time.time()) + 3600)

    headers = {"secret": secret, "User-Agent": USER_AGENT}
    # first request: no identity specified; second request: client_id sent as "clientid"
    system_assigned_request = Request(
        url,
        method="GET",
        required_headers=headers,
        required_params={"api-version": "2017-09-01", "resource": scope},
    )
    user_assigned_request = Request(
        url,
        method="GET",
        required_headers=headers,
        required_params={"api-version": "2017-09-01", "resource": scope, "clientid": client_id},
    )
    token_response = mock_response(
        json_payload={
            "access_token": expected_token.token,
            "expires_in": 3600,
            "expires_on": expected_token.expires_on,
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(
        requests=[system_assigned_request, user_assigned_request],
        responses=[token_response] * 2,
    )

    environment = {EnvironmentVariables.MSI_ENDPOINT: url, EnvironmentVariables.MSI_SECRET: secret}
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environment, clear=True):
        # order matters: the transport expects the request without "clientid" first
        for credential_kwargs in ({}, {"client_id": client_id}):
            credential = ManagedIdentityCredential(transport=transport, **credential_kwargs)
            token = await credential.get_token(scope)
            assert token.token == expected_token.token
            assert token.expires_on == expected_token.expires_on
async def test_same_username_different_tenants():
    """Two cached accounts sharing a username but belonging to different tenants.

    The username alone is ambiguous, so the credential should raise without
    sending a request; a tenant ID selects the matching account.
    """
    upn = "spam@eggs"
    tenant_a = "tenant-a"
    tenant_b = "tenant-b"
    access_token_a = "access-token-a"
    access_token_b = "access-token-b"
    refresh_token_a = "refresh-token-a"
    refresh_token_b = "refresh-token-b"

    cache = populated_cache(
        get_account_event(username=upn, uid="another-guid", utid=tenant_a, refresh_token=refresh_token_a),
        get_account_event(username=upn, uid="more-guid", utid=tenant_b, refresh_token=refresh_token_b),
    )

    # with no tenant specified the credential can't select an identity
    failing_transport = Mock(side_effect=Exception())  # (so it shouldn't use the network)
    credential = SharedTokenCacheCredential(username=upn, _cache=cache, transport=failing_transport)
    with pytest.raises(ClientAuthenticationError) as ex:
        await credential.get_token("scope")

    # error message should indicate multiple matching accounts, and list discovered accounts
    message = ex.value.message
    fixed_prefix = MULTIPLE_MATCHING_ACCOUNTS[: MULTIPLE_MATCHING_ACCOUNTS.index("{")]
    assert message.startswith(fixed_prefix)
    discovered_accounts = message.splitlines()[-1]
    assert discovered_accounts.count(upn) == 2
    assert tenant_a in discovered_accounts and tenant_b in discovered_accounts

    # with tenant specified, the credential should auth the matching account
    scope = "scope"
    transport = async_validating_transport(
        requests=[Request(required_data={"refresh_token": refresh_token_a, "scope": scope})],
        responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))],
    )
    credential = SharedTokenCacheCredential(tenant_id=tenant_a, _cache=cache, transport=transport)
    token = await credential.get_token(scope)
    assert token.token == access_token_a

    transport = async_validating_transport(
        requests=[Request(required_data={"refresh_token": refresh_token_b, "scope": scope})],
        responses=[mock_response(json_payload=build_aad_response(access_token=access_token_b))],
    )
    credential = SharedTokenCacheCredential(tenant_id=tenant_b, _cache=cache, transport=transport)
    token = await credential.get_token(scope)
    assert token.token == access_token_b
async def test_app_service_user_assigned_identity():
    """App Service 2017-09-01 (MSI_ENDPOINT, MSI_SECRET set): client_id and identity_config reach the query string."""
    scope = "scope"
    secret = "expected-secret"
    client_id = "some-guid"
    endpoint = "http://localhost:42/token"
    access_token = "****"
    expires_on = 42
    config_key, config_value = "foo", "bar"

    headers = {"secret": secret, "User-Agent": USER_AGENT}
    # first request: client_id sent as "clientid"; second request: the identity_config entry is sent
    client_id_request = Request(
        base_url=endpoint,
        method="GET",
        required_headers=headers,
        required_params={"api-version": "2017-09-01", "clientid": client_id, "resource": scope},
    )
    identity_config_request = Request(
        base_url=endpoint,
        method="GET",
        required_headers=headers,
        required_params={"api-version": "2017-09-01", "resource": scope, config_key: config_value},
    )
    token_response = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_on": "01/01/1970 00:00:{} +00:00".format(expires_on),
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(
        requests=[client_id_request, identity_config_request],
        responses=[token_response] * 2,
    )

    environment = {EnvironmentVariables.MSI_ENDPOINT: endpoint, EnvironmentVariables.MSI_SECRET: secret}
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environment, clear=True):
        credential = ManagedIdentityCredential(client_id=client_id, transport=transport)
        token = await credential.get_token(scope)
        assert token.token == access_token
        assert token.expires_on == expires_on

        credential = ManagedIdentityCredential(
            client_id=client_id, transport=transport, identity_config={config_key: config_value}
        )
        token = await credential.get_token(scope)
        assert token.token == access_token
        assert token.expires_on == expires_on
async def test_app_service_2017_09_01():
    """Both platform-dependent expires_on string formats of App Service MSI 2017-09-01 should parse to the same time."""
    scope = "scope"
    secret = "expected-secret"
    url = "http://localhost:42/token"
    access_token = "****"
    expires_on = 42
    expected_token = AccessToken(access_token, expires_on)

    expires_on_formats = (
        "01/01/1970 00:00:{} +00:00",  # linux format
        "1/1/1970 12:00:{} AM +00:00",  # windows format
    )
    responses = [
        mock_response(
            json_payload={
                "access_token": access_token,
                "expires_on": date_format.format(expires_on),
                "resource": scope,
                "token_type": "Bearer",
            }
        )
        for date_format in expires_on_formats
    ]
    token_request = Request(
        url,
        method="GET",
        required_headers={"secret": secret, "User-Agent": USER_AGENT},
        required_params={"api-version": "2017-09-01", "resource": scope},
    )
    transport = async_validating_transport(requests=[token_request] * 2, responses=responses)

    environment = {EnvironmentVariables.MSI_ENDPOINT: url, EnvironmentVariables.MSI_SECRET: secret}
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environment, clear=True):
        # one fresh credential per response format
        for _ in expires_on_formats:
            token = await ManagedIdentityCredential(transport=transport).get_token(scope)
            assert token == expected_token
            assert token.expires_on == expires_on
async def test_prefers_app_service_2017_09_01():
    """When the environment is configured for both App Service versions, the credential should prefer 2017-09-01.

    Support for 2019-08-01 was removed due to https://github.com/Azure/azure-sdk-for-python/issues/14670.
    This test should be removed when that support is added back.
    """
    scope = "scope"
    secret = "expected-secret"
    url = "http://localhost:42/token"
    access_token = "****"
    expires_on = 42
    expected_token = AccessToken(access_token, expires_on)

    expires_on_formats = (
        "01/01/1970 00:00:{} +00:00",  # linux format
        "1/1/1970 12:00:{} AM +00:00",  # windows format
    )
    responses = [
        mock_response(
            json_payload={
                "access_token": access_token,
                "expires_on": date_format.format(expires_on),
                "resource": scope,
                "token_type": "Bearer",
            }
        )
        for date_format in expires_on_formats
    ]
    # both requests must use the 2017-09-01 shape: "secret" header and that api-version
    token_request = Request(
        url,
        method="GET",
        required_headers={"secret": secret, "User-Agent": USER_AGENT},
        required_params={"api-version": "2017-09-01", "resource": scope},
    )
    transport = async_validating_transport(requests=[token_request] * 2, responses=responses)

    environment = {
        EnvironmentVariables.IDENTITY_ENDPOINT: url,
        EnvironmentVariables.IDENTITY_HEADER: secret,
        EnvironmentVariables.MSI_ENDPOINT: url,
        EnvironmentVariables.MSI_SECRET: secret,
    }
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environment, clear=True):
        # one fresh credential per response format
        for _ in expires_on_formats:
            credential = ManagedIdentityCredential(transport=transport)
            token = await credential.get_token(scope)
            assert token == expected_token
            assert token.expires_on == expires_on
async def test_cloud_shell_user_assigned_identity():
    """Cloud Shell environment (only MSI_ENDPOINT set): client_id and identity_config are sent as form data."""
    scope = "scope"
    client_id = "some-guid"
    endpoint = "http://localhost:42/token"
    access_token = "****"
    expires_on = 42
    config_key, config_value = "foo", "bar"

    headers = {"Metadata": "true", "User-Agent": USER_AGENT}
    # first request: client_id in the form data; second request: the identity_config entry instead
    client_id_request = Request(
        base_url=endpoint,
        method="POST",
        required_headers=headers,
        required_data={"client_id": client_id, "resource": scope},
    )
    identity_config_request = Request(
        base_url=endpoint,
        method="POST",
        required_headers=headers,
        required_data={"resource": scope, config_key: config_value},
    )
    token_response = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_in": 0,
            "expires_on": expires_on,
            "not_before": int(time.time()),
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(
        requests=[client_id_request, identity_config_request],
        responses=[token_response] * 2,
    )

    environment = {EnvironmentVariables.MSI_ENDPOINT: endpoint}
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environment, clear=True):
        credential = ManagedIdentityCredential(client_id=client_id, transport=transport)
        token = await credential.get_token(scope)
        assert token.token == access_token
        assert token.expires_on == expires_on

        credential = ManagedIdentityCredential(
            transport=transport, identity_config={config_key: config_value}
        )
        token = await credential.get_token(scope)
        assert token.token == access_token
        assert token.expires_on == expires_on
async def test_token_exchange(tmpdir):
    """ManagedIdentityCredential should exchange the token in AZURE_FEDERATED_TOKEN_FILE for an access token.

    Covers three cases: client ID taken from AZURE_CLIENT_ID, client ID
    overridden by the client_id keyword argument, and AZURE_CLIENT_ID unset
    (the keyword argument is then required).
    """
    exchange_token = "exchange-token"
    token_file = tmpdir.join("token")
    token_file.write(exchange_token)

    access_token = "***"
    authority = "https://localhost"
    default_client_id = "default_client_id"
    nondefault_client_id = "non" + default_client_id
    tenant = "tenant_id"
    scope = "scope"

    # one response object serves all three exchanges
    success_response = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_in": 3600,
            "ext_expires_in": 3600,
            "expires_on": int(time.time()) + 3600,
            "not_before": int(time.time()),
            "resource": scope,
            "token_type": "Bearer",
        }
    )

    def transport_expecting(client_id):
        """Return a transport expecting one client-assertion token request for ``client_id``."""
        exchange_request = Request(
            base_url=authority,
            method="POST",
            required_data={
                "client_assertion": exchange_token,
                "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
                "client_id": client_id,
                "grant_type": "client_credentials",
                "scope": scope,
            },
        )
        return async_validating_transport(requests=[exchange_request], responses=[success_response])

    environ_without_client_id = {
        EnvironmentVariables.AZURE_AUTHORITY_HOST: authority,
        EnvironmentVariables.AZURE_TENANT_ID: tenant,
        EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: token_file.strpath,
    }
    mock_environ = dict(environ_without_client_id)
    mock_environ[EnvironmentVariables.AZURE_CLIENT_ID] = default_client_id

    # credential should default to AZURE_CLIENT_ID
    transport = transport_expecting(default_client_id)
    with mock.patch.dict("os.environ", mock_environ, clear=True):
        credential = ManagedIdentityCredential(transport=transport)
        token = await credential.get_token(scope)
    assert token.token == access_token

    # client_id kwarg should override AZURE_CLIENT_ID
    transport = transport_expecting(nondefault_client_id)
    with mock.patch.dict("os.environ", mock_environ, clear=True):
        credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport)
        token = await credential.get_token(scope)
    assert token.token == access_token

    # AZURE_CLIENT_ID may not have a value, in which case client_id is required
    transport = transport_expecting(nondefault_client_id)
    with mock.patch.dict("os.environ", environ_without_client_id, clear=True):
        with pytest.raises(ValueError):
            ManagedIdentityCredential()

        credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport)
        token = await credential.get_token(scope)
    assert token.token == access_token
async def test_auth_code_credential():
    """The credential redeems its auth code once, then serves the cached token, then redeems the refresh token."""
    client_id = "client id"
    tenant_id = "tenant"
    auth_code = "auth code"
    redirect_uri = "https://localhost"
    access_token = "access"
    refresh_token = "refresh"
    scope = "scope"

    # the first network request should redeem the auth code
    code_redemption = Request(
        url_substring=tenant_id,
        required_data={
            "client_id": client_id,
            "code": auth_code,
            "grant_type": "authorization_code",
            "redirect_uri": redirect_uri,
            "scope": scope,
        },
    )
    # the second network request (made by the third get_token call) should redeem the refresh token
    refresh_redemption = Request(
        url_substring=tenant_id,
        required_data={
            "client_id": client_id,
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "scope": scope,
        },
    )
    aad_payload = build_aad_response(access_token=access_token, refresh_token=refresh_token)
    transport = async_validating_transport(
        requests=[code_redemption, refresh_redemption],
        responses=[mock_response(json_payload=aad_payload)] * 2,
    )

    cache = msal.TokenCache()
    credential = AuthorizationCodeCredential(
        client_id=client_id,
        tenant_id=tenant_id,
        authorization_code=auth_code,
        redirect_uri=redirect_uri,
        transport=transport,
        cache=cache,
    )

    # first call should redeem the auth code
    token = await credential.get_token(scope)
    assert token.token == access_token
    assert transport.send.call_count == 1

    # no auth code -> credential should return cached token
    token = await credential.get_token(scope)
    assert token.token == access_token
    assert transport.send.call_count == 1

    # no auth code, no cached token -> credential should redeem refresh token
    stale_entry = cache.find(cache.CredentialType.ACCESS_TOKEN)[0]
    cache.remove_at(stale_entry)
    token = await credential.get_token(scope)
    assert token.token == access_token
    assert transport.send.call_count == 2