async def test_client_id_none():
    """ManagedIdentityCredential should drop client_id=None rather than send it."""
    access_token = "****"
    scope = "scope"

    async def send(request, **_):
        # IMDS: client_id must not appear in the query string
        assert "client_id" not in request.query
        # Cloud Shell: when the request carries a body, client_id must not appear there either
        if request.data:
            assert "client_id" not in request.body
        payload = build_aad_response(access_token=access_token, expires_on="42", resource=scope)
        return mock_response(json_payload=payload)

    environments = (
        {},  # no MSI variables set
        {EnvironmentVariables.MSI_ENDPOINT: "https://localhost"},  # only MSI_ENDPOINT set
    )
    for environ in environments:
        with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environ, clear=True):
            credential = ManagedIdentityCredential(client_id=None, transport=mock.Mock(send=send))
            token = await credential.get_token(scope)
            assert token.token == access_token
async def test_app_service_user_assigned_identity():
    """App Service 2017-09-01: MSI_ENDPOINT, MSI_SECRET set

    A user-assigned identity's client_id must be sent as the ``clientid`` query parameter,
    both on its own and when extra ``identity_config`` parameters are supplied.
    """
    expected_token = "****"
    expires_on = 42
    client_id = "some-guid"
    endpoint = "http://localhost:42/token"
    secret = "expected-secret"
    scope = "scope"
    param_name, param_value = "foo", "bar"

    transport = async_validating_transport(
        requests=[
            Request(
                base_url=endpoint,
                method="GET",
                required_headers={"secret": secret, "User-Agent": USER_AGENT},
                required_params={"api-version": "2017-09-01", "clientid": client_id, "resource": scope},
            ),
            # The second credential is also constructed with client_id, so its request must still
            # identify the user-assigned identity in addition to carrying the identity_config parameter.
            Request(
                base_url=endpoint,
                method="GET",
                required_headers={"secret": secret, "User-Agent": USER_AGENT},
                required_params={
                    "api-version": "2017-09-01",
                    "clientid": client_id,
                    "resource": scope,
                    param_name: param_value,
                },
            ),
        ],
        responses=[
            mock_response(
                json_payload={
                    "access_token": expected_token,
                    "expires_on": "01/01/1970 00:00:{} +00:00".format(expires_on),
                    "resource": scope,
                    "token_type": "Bearer",
                }
            )
        ]
        * 2,
    )

    with mock.patch.dict(
        MANAGED_IDENTITY_ENVIRON,
        {EnvironmentVariables.MSI_ENDPOINT: endpoint, EnvironmentVariables.MSI_SECRET: secret},
        clear=True,
    ):
        credential = ManagedIdentityCredential(client_id=client_id, transport=transport)
        token = await credential.get_token(scope)
        assert token.token == expected_token
        assert token.expires_on == expires_on

        credential = ManagedIdentityCredential(
            client_id=client_id, transport=transport, identity_config={param_name: param_value}
        )
        token = await credential.get_token(scope)
        assert token.token == expected_token
        assert token.expires_on == expires_on
async def test_app_service_2017_09_01():
    """When the environment for 2019-08-01 is not configured, 2017-09-01 should be used."""
    access_token = "****"
    expires_on = 42
    expected_token = AccessToken(access_token, expires_on)
    url = "http://localhost:42/token"
    secret = "expected-secret"
    scope = "scope"

    def token_response(expires_on_text):
        # App Service 2017-09-01 reports expires_on as a date string rather than epoch seconds
        return mock_response(
            json_payload={
                "access_token": access_token,
                "expires_on": expires_on_text,
                "resource": scope,
                "token_type": "Bearer",
            }
        )

    expected_request = Request(
        url,
        method="GET",
        required_headers={"secret": secret, "User-Agent": USER_AGENT},
        required_params={"api-version": "2017-09-01", "resource": scope},
    )
    transport = async_validating_transport(
        requests=[expected_request] * 2,
        responses=[
            token_response("01/01/1970 00:00:{} +00:00".format(expires_on)),  # linux format
            token_response("1/1/1970 12:00:{} AM +00:00".format(expires_on)),  # windows format
        ],
    )

    msi_environ = {
        EnvironmentVariables.MSI_ENDPOINT: url,
        EnvironmentVariables.MSI_SECRET: secret,
    }
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, msi_environ, clear=True):
        # two fresh credentials, two get_token calls: one per mocked response above
        for _ in range(2):
            credential = ManagedIdentityCredential(transport=transport)
            token = await credential.get_token(scope)
            assert token == expected_token
            assert token.expires_on == expires_on
async def test_azure_ml():
    """Azure ML: MSI_ENDPOINT, MSI_SECRET set (like App Service 2017-09-01 but with a different response format)"""
    expected_token = AccessToken("****", int(time.time()) + 3600)
    url = "http://localhost:42/token"
    secret = "expected-secret"
    scope = "scope"
    client_id = "client"

    headers = {"secret": secret, "User-Agent": USER_AGENT}
    base_params = {"api-version": "2017-09-01", "resource": scope}
    transport = async_validating_transport(
        requests=[
            # system-assigned identity: no clientid parameter
            Request(url, method="GET", required_headers=headers, required_params=base_params),
            # user-assigned identity: clientid added to the same parameters
            Request(
                url, method="GET", required_headers=headers, required_params=dict(base_params, clientid=client_id)
            ),
        ],
        responses=[
            mock_response(
                json_payload={
                    "access_token": expected_token.token,
                    "expires_in": 3600,
                    "expires_on": expected_token.expires_on,
                    "resource": scope,
                    "token_type": "Bearer",
                }
            )
        ]
        * 2,
    )

    msi_environ = {EnvironmentVariables.MSI_ENDPOINT: url, EnvironmentVariables.MSI_SECRET: secret}
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, msi_environ, clear=True):
        for extra_kwargs in ({}, {"client_id": client_id}):
            credential = ManagedIdentityCredential(transport=transport, **extra_kwargs)
            token = await credential.get_token(scope)
            assert token.token == expected_token.token
            assert token.expires_on == expected_token.expires_on
async def test_cloud_shell_user_assigned_identity():
    """Cloud Shell environment: only MSI_ENDPOINT set"""
    expected_token = "****"
    expires_on = 42
    client_id = "some-guid"
    endpoint = "http://localhost:42/token"
    scope = "scope"
    param_name, param_value = "foo", "bar"

    cloud_shell_headers = {"Metadata": "true", "User-Agent": USER_AGENT}
    transport = async_validating_transport(
        requests=[
            # client_id is sent in the POST form body
            Request(
                base_url=endpoint,
                method="POST",
                required_headers=cloud_shell_headers,
                required_data={"client_id": client_id, "resource": scope},
            ),
            # identity_config entries are sent in the POST form body
            Request(
                base_url=endpoint,
                method="POST",
                required_headers=cloud_shell_headers,
                required_data={"resource": scope, param_name: param_value},
            ),
        ],
        responses=[
            mock_response(
                json_payload={
                    "access_token": expected_token,
                    "expires_in": 0,
                    "expires_on": expires_on,
                    "not_before": int(time.time()),
                    "resource": scope,
                    "token_type": "Bearer",
                }
            )
        ]
        * 2,
    )

    credential_kwargs = (
        {"client_id": client_id},
        {"identity_config": {param_name: param_value}},
    )
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, {EnvironmentVariables.MSI_ENDPOINT: endpoint}, clear=True):
        for kwargs in credential_kwargs:
            credential = ManagedIdentityCredential(transport=transport, **kwargs)
            token = await credential.get_token(scope)
            assert token.token == expected_token
            assert token.expires_on == expires_on
async def test_managed_identity_live(live_managed_identity_config):
    """Live test: set and delete a Key Vault secret using a managed identity token.

    Both the credential and the SecretClient are closed on exit so their async
    transports are released instead of leaking an open session.
    """
    credential = ManagedIdentityCredential(client_id=live_managed_identity_config["client_id"])
    async with credential:
        # do something with Key Vault to verify the credential can get a valid token
        client = SecretClient(live_managed_identity_config["vault_url"], credential, logging_enable=True)
        async with client:
            secret = await client.set_secret("managed-identity-test-secret", "value")
            await client.delete_secret(secret.name)
async def test_client_id_none_app_service_2017_09_01():
    """The credential should ignore client_id=None.

    App Service 2017-09-01 must be tested separately due to its eccentric expires_on format.
    """
    access_token = "****"
    scope = "scope"

    async def send(request, **_):
        # neither spelling of the client id parameter may be sent
        for forbidden in ("client_id", "clientid"):
            assert forbidden not in request.query
        payload = build_aad_response(
            access_token=access_token, expires_on="01/01/1970 00:00:42 +00:00", resource=scope
        )
        return mock_response(json_payload=payload)

    app_service_environ = {
        EnvironmentVariables.MSI_ENDPOINT: "https://localhost",
        EnvironmentVariables.MSI_SECRET: "secret",
    }
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, app_service_environ, clear=True):
        credential = ManagedIdentityCredential(client_id=None, transport=mock.Mock(send=send))
        token = await credential.get_token(scope)
        assert token.token == access_token
async def test_close(environ):
    """close() on the credential should exit its transport exactly once."""
    mock_transport = AsyncMockTransport()
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environ, clear=True):
        await ManagedIdentityCredential(transport=mock_transport).close()
    assert mock_transport.__aexit__.call_count == 1
async def test_managed_identity_live(live_managed_identity_config):
    """Live test: list Key Vault secret properties using a managed identity token.

    Both the credential and the SecretClient are closed on exit so their async
    transports are released instead of leaking an open session.
    """
    credential = ManagedIdentityCredential(client_id=live_managed_identity_config["client_id"])
    async with credential:
        # do something with Key Vault to verify the credential can get a valid token
        client = SecretClient(live_managed_identity_config["vault_url"], credential, logging_enable=True)
        async with client:
            # iterating the pager forces the authenticated requests to be sent
            async for _ in client.list_properties_of_secrets():
                pass
async def test_cloud_shell_live(cloud_shell):
    """Live test: acquire a Key Vault token in Cloud Shell and use it to list secrets."""
    credential = ManagedIdentityCredential()
    async with credential:
        # get_token on the async credential is a coroutine; without ``await`` this would be a
        # coroutine object and ``token.token`` below would raise AttributeError
        token = await credential.get_token("https://vault.azure.net")

    # Validate the token by sending a request to the Key Vault. The request is manual because azure-keyvault-secrets
    # can't authenticate in Cloud Shell; the MSI endpoint there doesn't support AADv2 scopes.
    policies = [ContentDecodePolicy(), AsyncRedirectPolicy(), AsyncRetryPolicy(), HttpLoggingPolicy()]
    client = AsyncPipelineClient(cloud_shell["vault_url"], policies=policies)
    list_secrets = client.get(
        "secrets", headers={"Authorization": "Bearer " + token.token}, params={"api-version": "7.0"}
    )
    async with client:
        await client._pipeline.run(list_secrets)
async def test_context_manager(environ):
    """``async with`` on the credential should enter and exit its transport exactly once."""
    mock_transport = AsyncMockTransport()

    def transport_calls():
        # (enter count, exit count) observed on the mocked transport
        return mock_transport.__aenter__.call_count, mock_transport.__aexit__.call_count

    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environ, clear=True):
        credential = ManagedIdentityCredential(transport=mock_transport)

        async with credential:
            assert transport_calls() == (1, 0)

        assert transport_calls() == (1, 1)
async def test_azure_arc_client_id():
    """Azure Arc doesn't support user-assigned managed identity"""
    arc_environ = {
        EnvironmentVariables.IDENTITY_ENDPOINT: "http://localhost:42/token",
        EnvironmentVariables.IMDS_ENDPOINT: "http://localhost:42",
    }
    with mock.patch("os.environ", arc_environ):
        credential = ManagedIdentityCredential(client_id="some-guid")

    # requesting a token with a client_id in an Arc environment must fail
    with pytest.raises(ClientAuthenticationError):
        await credential.get_token("scope")
def _initialize_credentials(self):
    """Populate ``self.credentials`` with every credential that is not excluded.

    Credentials are appended in this order: token file, environment, managed
    identity, shared token cache, Visual Studio Code, Azure CLI, Azure PowerShell,
    interactive browser, device code.

    When both ``subscription_id`` and ``arm_base_url`` are set, any of the VS Code,
    shared-cache and interactive-browser tenant IDs that are still ``None`` are
    filled in from ``self._get_tenant_id``.
    """
    if self.subscription_id is not None and self.arm_base_url is not None:
        needs_tenant = (
            self.vscode_tenant_id is None
            or self.shared_cache_tenant_id is None
            or self.interactive_browser_tenant_id is None
        )
        if needs_tenant:
            # The lookup takes the same arguments for every credential type, so resolve
            # it once and reuse the result instead of repeating an identical call.
            tenant_id = self._get_tenant_id(
                arm_base_url=self.arm_base_url, subscription_id=self.subscription_id
            )
            if self.vscode_tenant_id is None:
                self.vscode_tenant_id = tenant_id
            if self.shared_cache_tenant_id is None:
                self.shared_cache_tenant_id = tenant_id
            if self.interactive_browser_tenant_id is None:
                self.interactive_browser_tenant_id = tenant_id

    credentials = []  # type: List[AsyncTokenCredential]
    if not self.exclude_token_file_credential:
        credentials.append(_TokenFileCredential())
    if not self.exclude_environment_credential:
        credentials.append(EnvironmentCredential(authority=self.authority))
    if not self.exclude_managed_identity_credential:
        credentials.append(ManagedIdentityCredential(client_id=self.managed_identity_client_id))
    if not self.exclude_shared_token_cache_credential and SharedTokenCacheCredential.supported():
        try:
            # username and/or tenant_id are only required when the cache contains tokens for multiple identities
            shared_cache = SharedTokenCacheCredential(
                username=self.shared_cache_username,
                tenant_id=self.shared_cache_tenant_id,
                authority=self.authority,
            )
            credentials.append(shared_cache)
        except Exception as ex:  # pylint:disable=broad-except
            # best effort: an unusable cache is logged and skipped rather than failing the chain
            _LOGGER.info("Shared token cache is unavailable: '%s'", ex)
    if not self.exclude_visual_studio_code_credential:
        credentials.append(VisualStudioCodeCredential(tenant_id=self.vscode_tenant_id))
    if not self.exclude_cli_credential:
        credentials.append(AzureCliCredential())
    if not self.exclude_powershell_credential:
        credentials.append(AzurePowerShellCredential())
    if not self.exclude_interactive_browser_credential:
        credentials.append(InteractiveBrowserCredential(tenant_id=self.interactive_browser_tenant_id))
    if not self.exclude_device_code_credential:
        # NOTE(review): device code reuses the interactive-browser tenant ID; confirm this is intended
        credentials.append(DeviceCodeCredential(tenant_id=self.interactive_browser_tenant_id))
    self.credentials = credentials
async def test_token_exchange_tenant_id(tmpdir):
    """With AZURE_FEDERATED_TOKEN_FILE set, the credential should exchange the file's contents
    for an access token (client_credentials grant, JWT-bearer client assertion) even when
    get_token is given a tenant_id."""
    exchange_token = "exchange-token"
    token_path = tmpdir.join("token")
    token_path.write(exchange_token)

    access_token = "***"
    authority = "https://localhost"
    default_client_id = "default_client_id"
    tenant = "tenant_id"
    scope = "scope"

    issued_at = int(time.time())
    token_payload = {
        "access_token": access_token,
        "expires_in": 3600,
        "ext_expires_in": 3600,
        "expires_on": int(time.time()) + 3600,
        "not_before": issued_at,
        "resource": scope,
        "token_type": "Bearer",
    }
    exchange_form = {
        "client_assertion": exchange_token,
        "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
        "client_id": default_client_id,
        "grant_type": "client_credentials",
        "scope": scope,
    }
    transport = async_validating_transport(
        requests=[Request(base_url=authority, method="POST", required_data=exchange_form)],
        responses=[mock_response(json_payload=token_payload)],
    )

    environ = {
        EnvironmentVariables.AZURE_AUTHORITY_HOST: authority,
        EnvironmentVariables.AZURE_CLIENT_ID: default_client_id,
        EnvironmentVariables.AZURE_TENANT_ID: tenant,
        EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: token_path.strpath,
    }
    # credential should default to AZURE_CLIENT_ID
    with mock.patch.dict("os.environ", environ, clear=True):
        credential = ManagedIdentityCredential(transport=transport)
        token = await credential.get_token(scope, tenant_id=tenant)
        assert token.token == access_token
async def test_custom_hooks(environ):
    """The credential's pipeline should include azure-core's CustomHookPolicy"""
    scope = "scope"
    expected_token = "***"
    request_hook = mock.Mock()
    response_hook = mock.Mock()

    now = int(time.time())
    expected_response = mock_response(
        json_payload={
            "access_token": expected_token,
            "expires_in": 3600,
            "expires_on": now + 3600,
            "ext_expires_in": 3600,
            "not_before": now,
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(requests=[Request()] * 2, responses=[expected_response] * 2)

    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environ, clear=True):
        credential = ManagedIdentityCredential(
            transport=transport, raw_request_hook=request_hook, raw_response_hook=response_hook
        )
        await credential.get_token(scope)

    # With some environment variables set we're not mocking IMDS and expect 1 request;
    # with none set we're mocking IMDS and expect 2 requests.
    expected_calls = 1 if environ else 2
    assert request_hook.call_count == expected_calls
    assert response_hook.call_count == expected_calls

    # each response hook call receives the pipeline response as its first positional argument
    seen = [call[0][0].http_response for call in response_hook.call_args_list]
    assert seen == [expected_response] * expected_calls
async def main(req: func.HttpRequest) -> func.HttpResponse:
    """Azure Function entry point: verify managed identity by writing and deleting a Key Vault secret.

    Returns "test passed" on success; otherwise returns the exception repr followed by the
    values of EXPECTED_VARIABLES for debugging. The credential and SecretClient are closed
    on exit so their async transports are released.
    """
    # capture interesting environment variables for debugging
    env = "\n".join(f"{var}: {os.environ.get(var)}" for var in EXPECTED_VARIABLES)

    try:
        credential = ManagedIdentityCredential(
            client_id=os.environ.get("AZURE_IDENTITY_TEST_MANAGED_IDENTITY_CLIENT_ID")
        )
        async with credential:
            # do something with Key Vault to verify the credential can get a valid token
            client = SecretClient(os.environ["AZURE_IDENTITY_TEST_VAULT_URL"], credential, logging_enable=True)
            async with client:
                secret = await client.set_secret("managed-identity-test-secret", "value")
                await client.delete_secret(secret.name)
        return func.HttpResponse("test passed")
    except Exception as ex:  # pylint:disable=broad-except
        # top-level boundary: report any failure in the HTTP response instead of raising
        return func.HttpResponse("test failed: " + repr(ex) + "\n" * 3 + env)
async def test_prefers_app_service_2017_09_01():
    """When the environment is configured for both App Service versions, the credential should prefer 2017-09-01

    Support for 2019-08-01 was removed due to https://github.com/Azure/azure-sdk-for-python/issues/14670. This test
    should be removed when that support is added back.
    """
    access_token = "****"
    expires_on = 42
    expected_token = AccessToken(access_token, expires_on)
    url = "http://localhost:42/token"
    secret = "expected-secret"
    scope = "scope"

    def token_response(expires_on_text):
        return mock_response(
            json_payload={
                "access_token": access_token,
                "expires_on": expires_on_text,
                "resource": scope,
                "token_type": "Bearer",
            }
        )

    # only the 2017-09-01 request shape (secret header, api-version param) is accepted
    legacy_request = Request(
        url,
        method="GET",
        required_headers={"secret": secret, "User-Agent": USER_AGENT},
        required_params={"api-version": "2017-09-01", "resource": scope},
    )
    transport = async_validating_transport(
        requests=[legacy_request] * 2,
        responses=[
            token_response("01/01/1970 00:00:{} +00:00".format(expires_on)),  # linux format
            token_response("1/1/1970 12:00:{} AM +00:00".format(expires_on)),  # windows format
        ],
    )

    both_versions_environ = {
        EnvironmentVariables.IDENTITY_ENDPOINT: url,
        EnvironmentVariables.IDENTITY_HEADER: secret,
        EnvironmentVariables.MSI_ENDPOINT: url,
        EnvironmentVariables.MSI_SECRET: secret,
    }
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, both_versions_environ, clear=True):
        # two fresh credentials, two get_token calls: one per mocked response above
        for _ in range(2):
            credential = ManagedIdentityCredential(transport=transport)
            token = await credential.get_token(scope)
            assert token == expected_token
            assert token.expires_on == expires_on
async def test_context_manager_incomplete_configuration():
    """Entering and exiting a default-constructed credential's async context should not raise."""
    credential = ManagedIdentityCredential()
    async with credential:
        pass