def test_serialization():
    """A record should serialize to the expected JSON and deserialize back to equal field values"""

    expected = {
        "authority": "http://localhost",
        "clientId": "client-id",
        "homeAccountId": "object-id.tenant-id",
        "tenantId": "tenant-id",
        "username": "******",
        "version": "1.0",
    }

    record = AuthenticationRecord(
        expected["tenantId"],
        expected["clientId"],
        expected["authority"],
        expected["homeAccountId"],
        expected["username"],
    )
    serialized = record.serialize()

    # the serialized form should be exactly the expected JSON document
    assert json.loads(serialized) == expected

    deserialized = AuthenticationRecord.deserialize(serialized)

    # both records should expose the same set of attributes
    assert sorted(vars(deserialized)) == sorted(vars(record))

    # (attribute name, serialized key) pairs: both records should agree with the expected value
    attribute_to_key = (
        ("authority", "authority"),
        ("client_id", "clientId"),
        ("home_account_id", "homeAccountId"),
        ("tenant_id", "tenantId"),
        ("username", "username"),
    )
    for attribute, key in attribute_to_key:
        assert getattr(record, attribute) == getattr(deserialized, attribute) == expected[key]
def test_tenant_id_validation():
    """The credential should raise ValueError when given an invalid tenant_id"""

    def record_for(tenant):
        # every record differs only in its tenant
        return AuthenticationRecord(tenant, "client-id", "authority", "home.account.id", "username")

    acceptable = {
        "c878a2ab-8ef4-413b-83a0-199afb84d7fb",
        "contoso.onmicrosoft.com",
        "organizations",
        "common",
    }
    for tenant in acceptable:
        record = record_for(tenant)
        # neither construction should raise
        SharedTokenCacheCredential(authentication_record=record)
        SharedTokenCacheCredential(authentication_record=record, tenant_id=tenant)

    rejected = {"", "my tenant", "my_tenant", "/", "\\", '"my-tenant"', "'my-tenant'"}
    for tenant in rejected:
        record = record_for(tenant)
        with pytest.raises(ValueError):
            SharedTokenCacheCredential(authentication_record=record)
        with pytest.raises(ValueError):
            SharedTokenCacheCredential(authentication_record=record, tenant_id=tenant)
def test_client_capabilities():
    """the credential should configure MSAL for capability CP1 unless AZURE_IDENTITY_DISABLE_CP1 is set"""

    def send(request, **_):
        # expecting only the discovery requests triggered by creating an msal.PublicClientApplication
        # because the cache is empty--the credential shouldn't send a token request
        return get_discovery_response("https://localhost/tenant")

    msal_target = "azure.identity._credentials.silent.PublicClientApplication"
    record = AuthenticationRecord("tenant-id", "client_id", "authority", "home_account_id", "username")
    transport = Mock(send=send)

    # default configuration: CP1 is requested
    credential = SharedTokenCacheCredential(transport=transport, authentication_record=record, _cache=TokenCache())
    with patch(msal_target) as PublicClientApplication, pytest.raises(ClientAuthenticationError):
        # (cache is empty)
        credential.get_token("scope")

    assert PublicClientApplication.call_count == 1
    _, kwargs = PublicClientApplication.call_args
    assert kwargs["client_capabilities"] == ["CP1"]

    # with the environment variable set, no capabilities are requested
    credential = SharedTokenCacheCredential(transport=transport, authentication_record=record, _cache=TokenCache())
    with patch(msal_target) as PublicClientApplication:
        with patch.dict("os.environ", {"AZURE_IDENTITY_DISABLE_CP1": "true"}), pytest.raises(ClientAuthenticationError):
            # (cache is empty)
            credential.get_token("scope")

    assert PublicClientApplication.call_count == 1
    _, kwargs = PublicClientApplication.call_args
    assert kwargs["client_capabilities"] is None
def test_authentication_record_no_match():
    """The credential should be unavailable when no cached account matches the record"""

    tenant_id, client_id, authority = "tenant-id", "client-id", "localhost"
    object_id = "object-id"
    username = "******"
    home_account_id = object_id + "." + tenant_id
    record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

    def send(request, **_):
        # expecting only MSAL discovery requests
        assert request.method == "GET"
        return get_discovery_response()

    # the only cached account differs from the record in every respect
    unrelated_account = get_account_event(
        "not-" + username,
        "not-" + object_id,
        "different-" + tenant_id,
        client_id="not-" + client_id,
    )
    cache = populated_cache(unrelated_account)

    credential = SharedTokenCacheCredential(authentication_record=record, transport=Mock(send=send), _cache=cache)
    with pytest.raises(CredentialUnavailableError):
        credential.get_token("scope")
async def test_authentication_record_empty_cache():
    """With an empty cache the credential should be unavailable and shouldn't send a request"""

    # calling this transport raises
    failing_transport = Mock(side_effect=Exception("the credential shouldn't send a request"))
    record = AuthenticationRecord("tenant_id", "client_id", "authority", "home_account_id", "username")

    credential = SharedTokenCacheCredential(
        authentication_record=record,
        transport=failing_transport,
        _cache=TokenCache(),
    )

    with pytest.raises(CredentialUnavailableError):
        await credential.get_token("scope")
def test_auth_record_multiple_accounts_for_username():
    """When several cached accounts share a username, the one matching the record's tenant should be used"""

    tenant_id, client_id, authority = "tenant-id", "client-id", "localhost"
    object_id = "object-id"
    username = "******"
    home_account_id = object_id + "." + tenant_id
    record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

    expected_access_token = "****"
    expected_refresh_token = "**"

    matching_account = get_account_event(
        username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token
    )
    # this account matches all but the record's tenant
    other_tenant_account = get_account_event(
        username,
        object_id,
        "different-" + tenant_id,
        authority=authority,
        client_id=client_id,
        refresh_token="not-" + expected_refresh_token,
    )
    cache = populated_cache(matching_account, other_tenant_account)

    # the single expected token request must carry the matching account's refresh token
    token_request = Request(authority=authority, required_data={"refresh_token": expected_refresh_token})
    token_response = mock_response(json_payload=build_aad_response(access_token=expected_access_token))
    transport = msal_validating_transport(
        endpoint="https://{}/{}".format(authority, tenant_id),
        requests=[token_request],
        responses=[token_response],
    )

    credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)
    assert credential.get_token("scope").token == expected_access_token
def test_authentication_record_no_match():
    """The credential should be unavailable, without sending a request, when no cached account matches"""

    tenant_id, client_id, authority = "tenant-id", "client-id", "localhost"
    object_id = "object-id"
    username = "******"
    home_account_id = object_id + "." + tenant_id
    record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

    # calling this transport raises
    failing_transport = Mock(side_effect=Exception("the credential shouldn't send a request"))

    # the only cached account differs from the record in every respect
    unrelated_account = get_account_event(
        "not-" + username,
        "not-" + object_id,
        "different-" + tenant_id,
        client_id="not-" + client_id,
    )
    cache = populated_cache(unrelated_account)

    credential = SharedTokenCacheCredential(authentication_record=record, transport=failing_transport, _cache=cache)
    with pytest.raises(CredentialUnavailableError):
        credential.get_token("scope")
def test_claims_challenge():
    """get_token should pass any claims challenge to MSAL token acquisition APIs"""

    expected_claims = '{"access_token": {"essential": "true"}'
    record = AuthenticationRecord("tenant-id", "client_id", "authority", "home_account_id", "username")

    # a stand-in MSAL application knowing one account (the record's) and always returning a token
    msal_app = Mock()
    msal_app.get_accounts.return_value = [{"home_account_id": record.home_account_id}]
    silent_result = build_aad_response(access_token="**", id_token=build_id_token())
    msal_app.acquire_token_silent_with_error.return_value = dict(silent_result)

    no_network = Exception("this test mocks MSAL, so no request should be sent")
    transport = Mock(send=Mock(side_effect=no_network))
    credential = SharedTokenCacheCredential(transport=transport, authentication_record=record, _cache=TokenCache())

    patch_target = SharedTokenCacheCredential.__module__ + ".PublicClientApplication"
    with patch(patch_target, lambda *_, **__: msal_app):
        credential.get_token("scope", claims=expected_claims)

    silent_auth = msal_app.acquire_token_silent_with_error
    assert silent_auth.call_count == 1
    _, kwargs = silent_auth.call_args
    assert kwargs["claims_challenge"] == expected_claims
def _azure_interactive_auth(self):
    """Create a DeviceCodeCredential that uses the persistent token cache.

    If no serialized authentication record is stored on the instance, the new
    credential is authenticated right away and the resulting record is
    serialized into ``self._authentication_record_json``. Otherwise the stored
    JSON is deserialized and passed to the credential as its
    ``authentication_record``.

    Returns:
        The constructed DeviceCodeCredential.
    """
    stored_json = self._authentication_record_json
    needs_login = stored_json is None

    # Only pass ``authentication_record`` when a stored record exists.
    record_kwargs = {}
    if not needs_login:
        record_kwargs["authentication_record"] = AuthenticationRecord.deserialize(stored_json)

    credential = DeviceCodeCredential(
        tenant_id=self._tenant_id,
        timeout=180,
        prompt_callback=None,
        _cache=load_persistent_cache(
            TokenCachePersistenceOptions(
                name=self._azure_cred_cache_name,
                allow_unencrypted_storage=True,
                cache_location=self._azure_cred_cache_location,
            )
        ),
        **record_kwargs
    )

    if needs_login:
        # Keep the serialized record so later calls take the stored-record path.
        self._authentication_record_json = credential.authenticate().serialize()

    return credential
def test_client_capabilities():
    """the credential should configure MSAL for capability CP1 unless AZURE_IDENTITY_DISABLE_CP1 is set"""

    msal_target = SharedTokenCacheCredential.__module__ + ".PublicClientApplication"
    record = AuthenticationRecord("tenant-id", "client_id", "authority", "home_account_id", "username")
    no_network = Exception("this test mocks MSAL, so no request should be sent")
    transport = Mock(send=Mock(side_effect=no_network))

    # default configuration: CP1 is requested
    credential = SharedTokenCacheCredential(transport=transport, authentication_record=record, _cache=TokenCache())
    with patch(msal_target) as PublicClientApplication:
        credential._initialize()

    assert PublicClientApplication.call_count == 1
    _, kwargs = PublicClientApplication.call_args
    assert kwargs["client_capabilities"] == ["CP1"]

    # with the environment variable set, no capabilities are requested
    credential = SharedTokenCacheCredential(transport=transport, authentication_record=record, _cache=TokenCache())
    with patch(msal_target) as PublicClientApplication, patch.dict("os.environ", {"AZURE_IDENTITY_DISABLE_CP1": "true"}):
        credential._initialize()

    assert PublicClientApplication.call_count == 1
    _, kwargs = PublicClientApplication.call_args
    assert kwargs["client_capabilities"] is None
def test_disable_automatic_authentication():
    """When silent auth fails the credential should raise, if it's configured not to authenticate automatically"""

    expected_details = "something went wrong"
    scope = "scope"
    expected_claims = "..."
    record = AuthenticationRecord("tenant-id", "client-id", "localhost", "object.tenant", "username")

    # silent authentication finds the record's account but returns an error
    failing_silent_auth = Mock(return_value={"error_description": expected_details})
    known_accounts = Mock(return_value=[{"home_account_id": record.home_account_id}])
    msal_app = Mock(acquire_token_silent_with_error=failing_silent_auth, get_accounts=known_accounts)

    credential = MockCredential(
        authentication_record=record,
        disable_automatic_authentication=True,
        request_token=Mock(side_effect=Exception("credential shouldn't begin interactive authentication")),
    )

    with pytest.raises(AuthenticationRequiredError) as ex, patch(
        "msal.PublicClientApplication", lambda *_, **__: msal_app
    ):
        credential.get_token(scope, claims=expected_claims)

    # the exception should carry the requested scopes and claims, and any error message from AAD
    raised = ex.value
    assert raised.scopes == (scope,)
    assert raised.claims == expected_claims
def test_unknown_version(version):
    """deserialize should raise ValueError when the data doesn't contain a known version"""

    data = {
        "authority": "http://localhost",
        "clientId": "client-id",
        "homeAccountId": "object-id.tenant-id",
        "tenantId": "tenant-id",
        "username": "******",
    }
    # a falsy version means "omit the key entirely"
    if version:
        data["version"] = version

    serialized = json.dumps(data)
    # the error message should mention the offending version...
    message_pattern = ".*{}.*".format(version)
    with pytest.raises(ValueError, match=message_pattern) as ex:
        AuthenticationRecord.deserialize(serialized)

    # ...and the supported versions
    assert str(SUPPORTED_VERSIONS) in str(ex.value)
def get_auth_record(self) -> AuthenticationRecord:
    """Load the authentication record stored at ``AUTH_RECORD_LOCATION``.

    Returns:
        The record deserialized from the file's contents.

    Raises:
        CLIException: if an IOError occurs while reading or deserializing.
    """
    try:
        with open(AUTH_RECORD_LOCATION, 'r') as record_file:
            serialized = record_file.read()
        return AuthenticationRecord.deserialize(serialized)
    except IOError as ex:
        raise CLIException('Login to run this command') from ex
def _save_auth_record(self, auth_record: AuthenticationRecord):
    """Serialize ``auth_record`` and write it to ``AUTH_RECORD_LOCATION``.

    Args:
        auth_record: the record to persist; its ``serialize()`` output is
            written to the file as-is.

    Raises:
        CLIException: if the file cannot be opened or written.
    """
    record = auth_record.serialize()
    try:
        with open(AUTH_RECORD_LOCATION, 'w') as file:
            file.write(record)
    except IOError as ex:
        # Adjacent string literals are joined at compile time. A backslash
        # continuation *inside* the quotes would instead embed the following
        # line's indentation in the user-facing message.
        raise CLIException(
            'Authentication session not saved, you\'ll be prompted '
            'to login when running a command') from ex
def _get_cache_args(token_path: Path):
    """Build the cache-related keyword arguments for a credential.

    The result always contains ``cache_persistence_options`` (a cache named
    'parsedmarc'). When ``_load_token(token_path)`` returns a truthy value, it
    is deserialized and added as ``authentication_record``.
    """
    persistence = TokenCachePersistenceOptions(name='parsedmarc')
    cache_args = {'cache_persistence_options': persistence}

    serialized_record = _load_token(token_path)
    if not serialized_record:
        return cache_args

    cache_args['authentication_record'] = AuthenticationRecord.deserialize(serialized_record)
    return cache_args
def test_multitenant_authentication_auth_record():
    """A 'tenant_id' passed to get_token should select the tenant for that request only"""

    default_tenant = "organizations"
    second_tenant = "second-tenant"
    first_token = "***"
    second_token = first_token * 2

    authority = AzureAuthorityHosts.AZURE_PUBLIC_CLOUD
    object_id = "object-id"
    home_account_id = object_id + "." + default_tenant
    record = AuthenticationRecord(default_tenant, "client-id", authority, home_account_id, "user")

    def send(request, **_):
        parsed = urlparse(request.url)
        tenant = parsed.path.split("/")[1]

        # anything that isn't a token request is treated as discovery
        if "/oauth2/v2.0/token" not in request.url:
            return get_discovery_response("https://{}/{}".format(parsed.netloc, tenant))

        assert tenant in (default_tenant, second_tenant), 'unexpected tenant "{}"'.format(tenant)
        # each tenant gets its own distinguishable access token
        if tenant == second_tenant:
            access_token = second_token
        else:
            access_token = first_token
        payload = build_aad_response(
            access_token=access_token,
            id_token_claims=id_token_claims(aud="...", iss="...", sub="..."),
        )
        return mock_response(json_payload=payload)

    expected_account = get_account_event(
        record.username, object_id, record.tenant_id, client_id=record.client_id, refresh_token="**"
    )
    cache = populated_cache(expected_account)

    credential = SharedTokenCacheCredential(
        authority=authority,
        transport=Mock(send=send),
        authentication_record=record,
        _cache=cache,
    )

    assert credential.get_token("scope").token == first_token
    assert credential.get_token("scope", tenant_id=default_tenant).token == first_token
    assert credential.get_token("scope", tenant_id=second_tenant).token == second_token

    # should still default to the first tenant
    assert credential.get_token("scope").token == first_token
def test_client_capabilities():
    """the credential should configure MSAL for capability CP1 (ability to handle claims challenges)"""

    no_network = Exception("this test mocks MSAL, so no request should be sent")
    transport = Mock(send=Mock(side_effect=no_network))
    record = AuthenticationRecord("tenant-id", "client_id", "authority", "home_account_id", "username")
    credential = SharedTokenCacheCredential(
        transport=transport,
        authentication_record=record,
        _cache=TokenCache(),
    )

    msal_target = SharedTokenCacheCredential.__module__ + ".PublicClientApplication"
    with patch(msal_target) as PublicClientApplication:
        credential._initialize()

    # exactly one MSAL application should have been created, with CP1 requested
    assert PublicClientApplication.call_count == 1
    _, kwargs = PublicClientApplication.call_args
    assert kwargs["client_capabilities"] == ["CP1"]
def test_serialization():
    """A record should survive a serialize/deserialize round trip with its fields and values intact"""

    attrs = ("authority", "client_id", "home_account_id", "tenant_id", "username")
    # give every attribute a distinct value: its position in ``attrs``
    record_values = {attr: position for position, attr in enumerate(attrs)}

    record = AuthenticationRecord(**record_values)
    serialized = record.serialize()

    # AuthenticationRecord's fields should have been serialized
    assert json.loads(serialized) == record_values

    deserialized = AuthenticationRecord.deserialize(serialized)

    # the deserialized record and the constructed record should have the same fields
    assert sorted(vars(deserialized)) == sorted(vars(record))

    # the constructed and deserialized records should have the same values
    matches = [getattr(deserialized, attr) == record_values[attr] for attr in attrs]
    assert all(matches)
def test_authentication_record_empty_cache():
    """With an empty cache the credential should be unavailable"""

    def send(request, **_):
        # expecting only MSAL discovery requests
        assert request.method == "GET"
        return get_discovery_response()

    record = AuthenticationRecord("tenant-id", "client_id", "authority", "home_account_id", "username")
    transport = Mock(send=send)
    credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=TokenCache())

    with pytest.raises(CredentialUnavailableError):
        credential.get_token("scope")
def test_authentication_record_authenticating_tenant():
    """when given a record and 'tenant_id', the credential should authenticate in the latter"""

    expected_tenant_id = "tenant-id"
    # the record deliberately names a different tenant
    record_tenant = "not- " + expected_tenant_id
    record = AuthenticationRecord(record_tenant, "...", "...", "...", "...")

    with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client:
        SharedTokenCacheCredential(
            authentication_record=record,
            _cache=TokenCache(),
            tenant_id=expected_tenant_id,
        )

    # constructing the credential should create one client, for the keyword argument's tenant
    assert get_auth_client.call_count == 1
    _, kwargs = get_auth_client.call_args
    assert kwargs["tenant_id"] == expected_tenant_id
async def test_authentication_record_authenticating_tenant():
    """when given a record and 'tenant_id', the credential should authenticate in the latter"""

    expected_tenant_id = "tenant-id"
    # the record deliberately names a different tenant
    record_tenant = "not- " + expected_tenant_id
    record = AuthenticationRecord(record_tenant, "...", "...", "...", "...")

    with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client:
        credential = SharedTokenCacheCredential(
            authentication_record=record,
            _cache=TokenCache(),
            tenant_id=expected_tenant_id,
        )
        with pytest.raises(CredentialUnavailableError):
            # this raises because the cache is empty
            await credential.get_token("scope")

    # one client should have been created, for the keyword argument's tenant
    assert get_auth_client.call_count == 1
    _, kwargs = get_auth_client.call_args
    assert kwargs["tenant_id"] == expected_tenant_id
def test_get_token_wraps_exceptions():
    """get_token shouldn't propagate exceptions from MSAL"""

    class CustomException(Exception):
        pass

    expected_message = "something went wrong"
    record = AuthenticationRecord("tenant-id", "client-id", "localhost", "object.tenant", "username")

    # silent authentication finds the record's account, then raises an unexpected exception
    silent_auth = Mock(side_effect=CustomException(expected_message))
    known_accounts = Mock(return_value=[{"home_account_id": record.home_account_id}])
    msal_app = Mock(acquire_token_silent_with_error=silent_auth, get_accounts=known_accounts)

    credential = MockCredential(msal_app_factory=lambda *_, **__: msal_app, authentication_record=record)
    with pytest.raises(ClientAuthenticationError) as ex:
        credential.get_token("scope")

    # the wrapping error should preserve the original message
    assert expected_message in ex.value.message
    assert msal_app.acquire_token_silent_with_error.call_count == 1, "credential didn't attempt silent auth"
def test_authentication_record_argument():
    """The credential should initialize its msal.ClientApplication with values from a given record"""

    record = AuthenticationRecord("tenant-id", "client-id", "localhost", "object.tenant", "username")
    expected_authority = "https://{}/{}".format(record.authority, record.tenant_id)

    def validate_app_parameters(authority, client_id, **_):
        # the 'authority' argument to msal.ClientApplication should be a URL of the form https://authority/tenant
        assert authority == expected_authority
        assert client_id == record.client_id
        # an application with no accounts
        return Mock(get_accounts=Mock(return_value=[]))

    app_factory = Mock(wraps=validate_app_parameters)
    credential = MockCredential(
        authentication_record=record,
        disable_automatic_authentication=True,
        msal_app_factory=app_factory,
    )

    with pytest.raises(AuthenticationRequiredError):
        credential.get_token("scope")

    assert app_factory.call_count == 1, "credential didn't create an msal application"
def test_authentication_record_authenticating_tenant():
    """when given a record and 'tenant_id', the credential should authenticate in the latter"""

    expected_tenant_id = "tenant-id"
    expected_url_prefix = "https://localhost/" + expected_tenant_id
    # the record deliberately names a different tenant
    record = AuthenticationRecord("not- " + expected_tenant_id, "...", "localhost", "...", "...")

    def mock_send(request, **_):
        # requests without a body get a discovery response
        if not request.body:
            return get_discovery_response()
        assert request.url.startswith(expected_url_prefix)
        return mock_response(json_payload=build_aad_response(access_token="*"))

    transport = Mock(send=Mock(wraps=mock_send))
    credential = SharedTokenCacheCredential(
        authentication_record=record,
        _cache=TokenCache(),
        tenant_id=expected_tenant_id,
        transport=transport,
    )

    with pytest.raises(CredentialUnavailableError):
        # this raises because the cache is empty
        credential.get_token("scope")

    assert transport.send.called
def test_tenant_argument_overrides_record():
    """The 'tenant_id' keyword argument should override a given record's value"""

    tenant_id = "some-guid"
    authority = "localhost"
    record = AuthenticationRecord(tenant_id, "client-id", authority, "object.tenant", "username")

    # reversing the record's tenant yields a different value
    expected_tenant = tenant_id[::-1]
    expected_authority = "https://{}/{}".format(authority, expected_tenant)

    def validate_authority(authority, **_):
        assert authority == expected_authority
        # an application with no accounts
        return Mock(get_accounts=Mock(return_value=[]))

    credential = MockCredential(
        authentication_record=record,
        tenant_id=expected_tenant,
        disable_automatic_authentication=True,
    )

    with pytest.raises(AuthenticationRequiredError), patch("msal.PublicClientApplication", validate_authority):
        credential.get_token("scope")
def test_authentication_record():
    """The credential should redeem the refresh token of the cached account matching the record"""

    tenant_id, client_id, authority = "tenant-id", "client-id", "localhost"
    object_id = "object-id"
    username = "******"
    home_account_id = object_id + "." + tenant_id
    record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

    expected_access_token = "****"
    expected_refresh_token = "**"

    account = get_account_event(
        username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token
    )
    cache = populated_cache(account)

    # the single expected request must carry the cached refresh token
    token_request = Request(authority=authority, required_data={"refresh_token": expected_refresh_token})
    token_response = mock_response(json_payload=build_aad_response(access_token=expected_access_token))
    transport = validating_transport(requests=[token_request], responses=[token_response])

    credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)
    assert credential.get_token("scope").token == expected_access_token
# Authenticate the first credential; authenticate() returns a record of the authentication.
# NOTE(review): `credential` is created before this excerpt -- confirm its configuration there.
record = credential.authenticate()
print("\nAuthenticated first credential")

# The record contains no authentication secrets. You can serialize it to JSON for storage.
record_json = record.serialize()

# An authenticated credential is ready for use with a client. This request should succeed
# without prompting for authentication again.
client = SecretClient(VAULT_URL, credential)
secret_names = [s.name for s in client.list_properties_of_secrets()]
print("\nCompleted request with first credential")

# An authentication record stored by your application enables other credentials to access data from
# past authentications. If the cache contains sufficient data, this eliminates the need for your
# application to prompt for authentication every time it runs.
deserialized_record = AuthenticationRecord.deserialize(record_json)
new_credential = InteractiveBrowserCredential(
    cache_persistence_options=TokenCachePersistenceOptions(),
    authentication_record=deserialized_record)

# This request should also succeed without prompting for authentication.
client = SecretClient(VAULT_URL, new_credential)
secret_names = [s.name for s in client.list_properties_of_secrets()]
print("\nCompleted request with credential using shared cache")

# To isolate the token cache from other applications, you can provide a cache name to TokenCachePersistenceOptions.
separate_cache_credential = InteractiveBrowserCredential(
    cache_persistence_options=TokenCachePersistenceOptions(name="my_app"),
    authentication_record=deserialized_record)

# This request should prompt for authentication since the credential is using a separate cache.
def _cache_auth_record(record: AuthenticationRecord, token_path: Path):
    """Write the serialized form of ``record`` to ``token_path``, replacing any existing content."""
    # write_text opens the file in text mode, writes the string, and closes the file
    token_path.write_text(record.serialize())