Ejemplo n.º 1
0
def test_serialization():
    """A record should survive a serialize/deserialize round trip with all of its values intact."""
    tenant_id = "tenant-id"
    client_id = "client-id"
    authority = "http://localhost"
    home_account_id = "object-id.tenant-id"
    username = "******"
    expected = {
        "authority": authority,
        "clientId": client_id,
        "homeAccountId": home_account_id,
        "tenantId": tenant_id,
        "username": username,
        "version": "1.0",
    }

    record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)
    serialized = record.serialize()

    # the JSON form uses camelCase keys and carries the format version
    assert json.loads(serialized) == expected

    deserialized = AuthenticationRecord.deserialize(serialized)

    # both records expose the same attribute names...
    assert sorted(vars(deserialized)) == sorted(vars(record))

    # ...and agree with each other and with the serialized form on every value
    attribute_to_key = (
        ("authority", "authority"),
        ("client_id", "clientId"),
        ("home_account_id", "homeAccountId"),
        ("tenant_id", "tenantId"),
        ("username", "username"),
    )
    for attribute, key in attribute_to_key:
        assert getattr(record, attribute) == getattr(deserialized, attribute) == expected[key]
Ejemplo n.º 2
0
def test_tenant_id_validation():
    """The credential should accept well-formed tenant IDs and raise ValueError for malformed ones,
    whether the tenant comes from the record alone or is also passed as 'tenant_id'."""

    def make_record(tenant):
        return AuthenticationRecord(tenant, "client-id", "authority", "home.account.id", "username")

    # a GUID, a domain name, and the two special multi-tenant values
    for tenant in {"c878a2ab-8ef4-413b-83a0-199afb84d7fb", "contoso.onmicrosoft.com", "organizations", "common"}:
        record = make_record(tenant)
        SharedTokenCacheCredential(authentication_record=record)
        SharedTokenCacheCredential(authentication_record=record, tenant_id=tenant)

    # empty, whitespace, underscore, path separators and quoted values are all rejected
    for tenant in {"", "my tenant", "my_tenant", "/", "\\", '"my-tenant"', "'my-tenant'"}:
        record = make_record(tenant)
        with pytest.raises(ValueError):
            SharedTokenCacheCredential(authentication_record=record)
        with pytest.raises(ValueError):
            SharedTokenCacheCredential(authentication_record=record, tenant_id=tenant)
Ejemplo n.º 3
0
def test_client_capabilities():
    """the credential should configure MSAL for capability CP1 unless AZURE_IDENTITY_DISABLE_CP1 is set"""
    def send(request, **_):
        # only the discovery requests made while creating an msal.PublicClientApplication are expected;
        # the cache is empty, so the credential shouldn't send a token request
        return get_discovery_response("https://localhost/tenant")

    record = AuthenticationRecord("tenant-id", "client_id", "authority", "home_account_id", "username")
    transport = Mock(send=send)

    def capabilities_given_to_msal(environ):
        """Call get_token on a fresh credential under *environ*; return MSAL's client_capabilities argument."""
        credential = SharedTokenCacheCredential(transport=transport, authentication_record=record, _cache=TokenCache())
        with patch("azure.identity._credentials.silent.PublicClientApplication") as PublicClientApplication:
            with patch.dict("os.environ", environ):
                with pytest.raises(ClientAuthenticationError):  # (cache is empty)
                    credential.get_token("scope")
        assert PublicClientApplication.call_count == 1
        _, kwargs = PublicClientApplication.call_args
        return kwargs["client_capabilities"]

    assert capabilities_given_to_msal({}) == ["CP1"]
    assert capabilities_given_to_msal({"AZURE_IDENTITY_DISABLE_CP1": "true"}) is None
Ejemplo n.º 4
0
def test_authentication_record_no_match():
    """get_token should raise CredentialUnavailableError when no cached account matches the record"""
    tenant_id, client_id, authority = "tenant-id", "client-id", "localhost"
    object_id, username = "object-id", "******"
    record = AuthenticationRecord(tenant_id, client_id, authority, "{}.{}".format(object_id, tenant_id), username)

    def send(request, **_):
        # MSAL's discovery GETs are the only traffic expected
        assert request.method == "GET"
        return get_discovery_response()

    # the sole cached account differs from the record in username, object ID, tenant and client
    mismatched_account = get_account_event(
        "not-" + username, "not-" + object_id, "different-" + tenant_id, client_id="not-" + client_id
    )
    credential = SharedTokenCacheCredential(
        authentication_record=record, transport=Mock(send=send), _cache=populated_cache(mismatched_account)
    )

    with pytest.raises(CredentialUnavailableError):
        credential.get_token("scope")
Ejemplo n.º 5
0
async def test_authentication_record_empty_cache():
    """get_token should raise CredentialUnavailableError, without sending any request, when the cache is empty"""
    record = AuthenticationRecord("tenant_id", "client_id", "authority", "home_account_id", "username")
    # The pipeline calls transport.send(), so the failing side effect must live on send. A side effect on the
    # transport Mock itself fires only if the transport object is called directly.
    transport = Mock(send=Mock(side_effect=Exception("the credential shouldn't send a request")))
    credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=TokenCache())

    with pytest.raises(CredentialUnavailableError):
        await credential.get_token("scope")

    # guard against the credential swallowing the send() exception into CredentialUnavailableError
    assert not transport.send.called
Ejemplo n.º 6
0
def test_auth_record_multiple_accounts_for_username():
    """When several cached accounts share the record's username, the credential should redeem the refresh
    token of the one whose tenant matches the record."""
    tenant_id = "tenant-id"
    client_id = "client-id"
    authority = "localhost"
    object_id = "object-id"
    username = "******"
    record = AuthenticationRecord(tenant_id, client_id, authority, object_id + "." + tenant_id, username)

    expected_access_token = "****"
    expected_refresh_token = "**"
    expected_account = get_account_event(
        username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token
    )
    # matches the record in everything but its tenant (and so holds a different refresh token)
    decoy_account = get_account_event(
        username,
        object_id,
        "different-" + tenant_id,
        authority=authority,
        client_id=client_id,
        refresh_token="not-" + expected_refresh_token,
    )
    cache = populated_cache(expected_account, decoy_account)

    token_request = Request(authority=authority, required_data={"refresh_token": expected_refresh_token})
    token_response = mock_response(json_payload=build_aad_response(access_token=expected_access_token))
    transport = msal_validating_transport(
        endpoint="https://{}/{}".format(authority, tenant_id), requests=[token_request], responses=[token_response]
    )
    credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)

    assert credential.get_token("scope").token == expected_access_token
Ejemplo n.º 7
0
def test_authentication_record_no_match():
    """get_token should raise CredentialUnavailableError, without sending a request, when no cached
    account matches the record"""
    tenant_id = "tenant-id"
    client_id = "client-id"
    authority = "localhost"
    object_id = "object-id"
    home_account_id = object_id + "." + tenant_id
    username = "******"
    record = AuthenticationRecord(tenant_id, client_id, authority,
                                  home_account_id, username)

    # The pipeline calls transport.send(), so the failing side effect must live on send. A side effect on the
    # transport Mock itself would never fire.
    transport = Mock(send=Mock(
        side_effect=Exception("the credential shouldn't send a request")))
    # the sole cached account differs from the record in username, object ID, tenant and client
    cache = populated_cache(
        get_account_event(
            "not-" + username,
            "not-" + object_id,
            "different-" + tenant_id,
            client_id="not-" + client_id,
        ))
    credential = SharedTokenCacheCredential(authentication_record=record,
                                            transport=transport,
                                            _cache=cache)

    with pytest.raises(CredentialUnavailableError):
        credential.get_token("scope")

    # guard against the credential swallowing the send() exception into CredentialUnavailableError
    assert not transport.send.called
Ejemplo n.º 8
0
def test_claims_challenge():
    """get_token should pass any claims challenge to MSAL token acquisition APIs"""

    # a well-formed JSON claims challenge, as AAD would issue (the braces must balance)
    expected_claims = '{"access_token": {"essential": "true"}}'

    record = AuthenticationRecord("tenant-id", "client_id", "authority",
                                  "home_account_id", "username")

    # MSAL is mocked: it knows the record's account and returns a token for it
    msal_app = Mock()
    msal_app.get_accounts.return_value = [{
        "home_account_id":
        record.home_account_id
    }]
    msal_app.acquire_token_silent_with_error.return_value = dict(
        build_aad_response(access_token="**", id_token=build_id_token()))

    transport = Mock(send=Mock(side_effect=Exception(
        "this test mocks MSAL, so no request should be sent")))
    credential = SharedTokenCacheCredential(transport=transport,
                                            authentication_record=record,
                                            _cache=TokenCache())
    with patch(
            SharedTokenCacheCredential.__module__ + ".PublicClientApplication",
            lambda *_, **__: msal_app):
        credential.get_token("scope", claims=expected_claims)

    assert msal_app.acquire_token_silent_with_error.call_count == 1
    _, kwargs = msal_app.acquire_token_silent_with_error.call_args
    assert kwargs["claims_challenge"] == expected_claims
Ejemplo n.º 9
0
 def _azure_interactive_auth(self):
     """Return a DeviceCodeCredential that uses the persistent token cache named by this object.

     If no serialized authentication record is stored yet, authenticate() is called on the new
     credential and the resulting record is serialized into self._authentication_record_json.
     Otherwise the stored record is deserialized and handed to the credential.
     """
     first_login = self._authentication_record_json is None
     extra_kwargs = {}
     if not first_login:
         # reuse the account from an earlier authentication
         extra_kwargs["authentication_record"] = AuthenticationRecord.deserialize(
             self._authentication_record_json)

     cache_options = TokenCachePersistenceOptions(
         name=self._azure_cred_cache_name,
         allow_unencrypted_storage=True,
         cache_location=self._azure_cred_cache_location)
     credential = DeviceCodeCredential(
         tenant_id=self._tenant_id,
         timeout=180,
         prompt_callback=None,
         _cache=load_persistent_cache(cache_options),
         **extra_kwargs)

     if first_login:
         # authenticate now and remember the record so later calls take the other path
         self._authentication_record_json = credential.authenticate().serialize()
     return credential
Ejemplo n.º 10
0
def test_client_capabilities():
    """MSAL should be configured with client capability CP1 unless AZURE_IDENTITY_DISABLE_CP1 is set"""
    record = AuthenticationRecord("tenant-id", "client_id", "authority", "home_account_id", "username")
    transport = Mock(send=Mock(side_effect=Exception("this test mocks MSAL, so no request should be sent")))

    def capabilities_after_initialize(environ):
        """Initialize a fresh credential under *environ*; return the client_capabilities given to MSAL."""
        credential = SharedTokenCacheCredential(transport=transport, authentication_record=record, _cache=TokenCache())
        with patch(SharedTokenCacheCredential.__module__ + ".PublicClientApplication") as PublicClientApplication:
            with patch.dict("os.environ", environ):
                credential._initialize()
        assert PublicClientApplication.call_count == 1
        _, kwargs = PublicClientApplication.call_args
        return kwargs["client_capabilities"]

    assert capabilities_after_initialize({}) == ["CP1"]
    assert capabilities_after_initialize({"AZURE_IDENTITY_DISABLE_CP1": "true"}) is None
Ejemplo n.º 11
0
def test_disable_automatic_authentication():
    """When silent auth fails the credential should raise, if it's configured not to authenticate automatically"""
    record = AuthenticationRecord("tenant-id", "client-id", "localhost", "object.tenant", "username")

    # MSAL finds the record's account, but silent acquisition yields only an error description
    msal_app = Mock()
    msal_app.acquire_token_silent_with_error = Mock(return_value={"error_description": "something went wrong"})
    msal_app.get_accounts = Mock(return_value=[{"home_account_id": record.home_account_id}])

    credential = MockCredential(
        authentication_record=record,
        disable_automatic_authentication=True,
        request_token=Mock(side_effect=Exception("credential shouldn't begin interactive authentication")),
    )

    scope, expected_claims = "scope", "..."
    with pytest.raises(AuthenticationRequiredError) as ex:
        with patch("msal.PublicClientApplication", lambda *_, **__: msal_app):
            credential.get_token(scope, claims=expected_claims)

    # the exception should carry the requested scopes and claims
    # (NOTE: the mocked AAD error description is not asserted here)
    assert ex.value.scopes == (scope,)
    assert ex.value.claims == expected_claims
Ejemplo n.º 12
0
def test_unknown_version(version):
    """deserialize should raise ValueError when the data doesn't contain a known version"""
    import re

    data = {
        "authority": "http://localhost",
        "clientId": "client-id",
        "homeAccountId": "object-id.tenant-id",
        "tenantId": "tenant-id",
        "username": "******",
    }

    # a falsy version means "omit the field entirely"
    if version:
        data["version"] = version

    # pytest.raises(match=...) applies re.search, so the version must be escaped: otherwise a value like
    # "1.0" would let "." match any character and the check would be looser than intended
    with pytest.raises(ValueError, match=re.escape("{}".format(version))) as ex:
        AuthenticationRecord.deserialize(json.dumps(data))
    # the message should also tell the user which versions are supported
    assert str(SUPPORTED_VERSIONS) in str(ex.value)
Ejemplo n.º 13
0
    def get_auth_record(self) -> AuthenticationRecord:
        """Load the authentication record saved by a previous login.

        Returns:
            The AuthenticationRecord deserialized from the file at AUTH_RECORD_LOCATION.

        Raises:
            CLIException: if the record file can't be read (e.g. no login has happened yet).
        """
        try:
            with open(AUTH_RECORD_LOCATION, 'r', encoding='utf-8') as file:
                serialized = file.read()
        except IOError as ex:
            raise CLIException('Login to run this command') from ex
        # only the file I/O is guarded; deserialization errors are not I/O errors
        return AuthenticationRecord.deserialize(serialized)
Ejemplo n.º 14
0
    def _save_auth_record(self, auth_record: AuthenticationRecord):
        """Persist *auth_record* to AUTH_RECORD_LOCATION so later commands can reuse the session.

        Raises:
            CLIException: if the record file can't be written.
        """
        record = auth_record.serialize()

        try:
            with open(AUTH_RECORD_LOCATION, 'w', encoding='utf-8') as file:
                file.write(record)
        except IOError as ex:
            # adjacent literals are joined at compile time, so the source indentation of the second
            # line does not leak into the message (a backslash continuation inside a literal would)
            raise CLIException(
                "Authentication session not saved, you'll be prompted "
                'to login when running a command') from ex
Ejemplo n.º 15
0
def _get_cache_args(token_path: Path):
    """Build credential keyword arguments that enable the 'parsedmarc' persistent token cache.

    If _load_token(token_path) returns a truthy serialized record, it is deserialized and added
    under 'authentication_record'. (What _load_token returns for a missing file isn't visible here.)
    """
    cache_args = dict(cache_persistence_options=TokenCachePersistenceOptions(name='parsedmarc'))
    serialized_record = _load_token(token_path)
    if not serialized_record:
        return cache_args
    cache_args['authentication_record'] = AuthenticationRecord.deserialize(serialized_record)
    return cache_args
Ejemplo n.º 16
0
def test_multitenant_authentication_auth_record():
    """get_token should honor a per-call 'tenant_id' and otherwise use the record's tenant.

    The mock transport answers token requests with a token that identifies the tenant in the
    request URL, so each assertion shows which tenant the credential authenticated in.
    """
    default_tenant = "organizations"
    first_token = "***"
    second_tenant = "second-tenant"
    # distinct from first_token so the two tenants' responses can be told apart
    second_token = first_token * 2

    authority = AzureAuthorityHosts.AZURE_PUBLIC_CLOUD
    object_id = "object-id"
    home_account_id = object_id + "." + default_tenant
    record = AuthenticationRecord(default_tenant, "client-id", authority,
                                  home_account_id, "user")

    def send(request, **_):
        # the first path segment of each request URL is taken to be the tenant
        parsed = urlparse(request.url)
        tenant_id = parsed.path.split("/")[1]
        # anything that isn't a token request (presumably MSAL discovery) gets a discovery
        # response for the same host and tenant
        if "/oauth2/v2.0/token" not in request.url:
            return get_discovery_response("https://{}/{}".format(
                parsed.netloc, tenant_id))

        # token requests may only target the two tenants this test uses
        assert tenant_id in (
            default_tenant,
            second_tenant), 'unexpected tenant "{}"'.format(tenant_id)
        return mock_response(json_payload=build_aad_response(
            access_token=second_token if tenant_id ==
            second_tenant else first_token,
            id_token_claims=id_token_claims(aud="...", iss="...", sub="..."),
        ))

    # the cache holds one account matching the record, with a refresh token to redeem
    expected_account = get_account_event(record.username,
                                         object_id,
                                         record.tenant_id,
                                         client_id=record.client_id,
                                         refresh_token="**")
    cache = populated_cache(expected_account)

    credential = SharedTokenCacheCredential(
        authority=authority,
        transport=Mock(send=send),
        authentication_record=record,
        _cache=cache,
    )
    # no tenant_id -> the record's tenant
    token = credential.get_token("scope")
    assert token.token == first_token

    # explicitly requesting the record's tenant
    token = credential.get_token("scope", tenant_id=default_tenant)
    assert token.token == first_token

    # a different tenant requested for this call only
    token = credential.get_token("scope", tenant_id=second_tenant)
    assert token.token == second_token

    # should still default to the first tenant
    token = credential.get_token("scope")
    assert token.token == first_token
Ejemplo n.º 17
0
def test_client_capabilities():
    """MSAL should be configured with client capability CP1, i.e. the ability to handle claims challenges"""
    record = AuthenticationRecord("tenant-id", "client_id", "authority", "home_account_id", "username")
    no_requests = Mock(side_effect=Exception("this test mocks MSAL, so no request should be sent"))
    credential = SharedTokenCacheCredential(
        transport=Mock(send=no_requests), authentication_record=record, _cache=TokenCache()
    )

    with patch(SharedTokenCacheCredential.__module__ + ".PublicClientApplication") as app_class:
        credential._initialize()

    assert app_class.call_count == 1
    # call_args[1] holds the keyword arguments of the (single) construction
    assert app_class.call_args[1]["client_capabilities"] == ["CP1"]
Ejemplo n.º 18
0
def test_serialization():
    """serialize and deserialize should round-trip every field of an AuthenticationRecord"""

    attrs = ("authority", "client_id", "home_account_id", "tenant_id",
             "username")
    # a distinct value per field, so any field mix-up during (de)serialization is detected
    record_values = {attr: index for index, attr in enumerate(attrs)}

    record = AuthenticationRecord(**record_values)
    serialized = record.serialize()

    # AuthenticationRecord's fields should have been serialized
    assert json.loads(serialized) == record_values

    deserialized = AuthenticationRecord.deserialize(serialized)

    # the deserialized record and the constructed record should have the same fields
    assert sorted(vars(deserialized)) == sorted(vars(record))

    # the constructed and deserialized records should have the same values
    assert all(
        getattr(deserialized, attr) == record_values[attr] for attr in attrs)
Ejemplo n.º 19
0
def test_authentication_record_empty_cache():
    """With an empty cache, get_token should raise CredentialUnavailableError"""
    def send(request, **_):
        # MSAL's discovery GETs are the only traffic expected
        assert request.method == "GET"
        return get_discovery_response()

    record = AuthenticationRecord("tenant-id", "client_id", "authority", "home_account_id", "username")
    empty_cache = TokenCache()
    credential = SharedTokenCacheCredential(authentication_record=record, transport=Mock(send=send), _cache=empty_cache)

    with pytest.raises(CredentialUnavailableError):
        credential.get_token("scope")
def test_authentication_record_authenticating_tenant():
    """when given a record and 'tenant_id', the credential should authenticate in the latter"""
    requested_tenant = "tenant-id"
    record = AuthenticationRecord("not- " + requested_tenant, "...", "...", "...", "...")

    with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client:
        SharedTokenCacheCredential(authentication_record=record, _cache=TokenCache(), tenant_id=requested_tenant)

    # exactly one auth client, created for the explicitly requested tenant rather than the record's
    assert get_auth_client.call_count == 1
    call_kwargs = get_auth_client.call_args[1]
    assert call_kwargs["tenant_id"] == requested_tenant
Ejemplo n.º 21
0
async def test_authentication_record_authenticating_tenant():
    """when given a record and 'tenant_id', the credential should authenticate in the latter"""
    requested_tenant = "tenant-id"
    record = AuthenticationRecord("not- " + requested_tenant, "...", "...", "...", "...")

    with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client:
        credential = SharedTokenCacheCredential(
            authentication_record=record, _cache=TokenCache(), tenant_id=requested_tenant
        )
        # the cache is empty, so get_token can't produce a token
        with pytest.raises(CredentialUnavailableError):
            await credential.get_token("scope")

    assert get_auth_client.call_count == 1
    call_kwargs = get_auth_client.call_args[1]
    assert call_kwargs["tenant_id"] == requested_tenant
def test_get_token_wraps_exceptions():
    """get_token shouldn't propagate exceptions from MSAL; it should wrap them in ClientAuthenticationError"""

    class CustomException(Exception):
        pass

    expected_message = "something went wrong"
    record = AuthenticationRecord("tenant-id", "client-id", "localhost", "object.tenant", "username")

    # MSAL knows the record's account, but silent acquisition blows up
    silent_auth = Mock(side_effect=CustomException(expected_message))
    accounts = Mock(return_value=[{"home_account_id": record.home_account_id}])
    msal_app = Mock(acquire_token_silent_with_error=silent_auth, get_accounts=accounts)

    credential = MockCredential(msal_app_factory=lambda *_, **__: msal_app, authentication_record=record)
    with pytest.raises(ClientAuthenticationError) as ex:
        credential.get_token("scope")

    # the original exception's message should surface in the wrapper
    assert expected_message in ex.value.message
    assert silent_auth.call_count == 1, "credential didn't attempt silent auth"
def test_authentication_record_argument():
    """The credential should initialize its msal.ClientApplication with values from a given record"""
    record = AuthenticationRecord("tenant-id", "client-id", "localhost", "object.tenant", "username")
    # msal.ClientApplication's 'authority' argument should be a URL of the form https://authority/tenant
    expected_authority = "https://{}/{}".format(record.authority, record.tenant_id)

    def validate_app_parameters(authority, client_id, **_):
        assert authority == expected_authority
        assert client_id == record.client_id
        # no cached accounts, so silent auth can't succeed
        return Mock(get_accounts=Mock(return_value=[]))

    app_factory = Mock(wraps=validate_app_parameters)
    credential = MockCredential(
        authentication_record=record, disable_automatic_authentication=True, msal_app_factory=app_factory,
    )
    # interactive auth is disabled, so the credential must raise instead of prompting
    with pytest.raises(AuthenticationRequiredError):
        credential.get_token("scope")

    assert app_factory.call_count == 1, "credential didn't create an msal application"
Ejemplo n.º 24
0
def test_authentication_record_authenticating_tenant():
    """when given a record and 'tenant_id', the credential should authenticate in the latter"""
    expected_tenant_id = "tenant-id"
    record = AuthenticationRecord("not- " + expected_tenant_id, "...", "localhost", "...", "...")
    expected_url_prefix = "https://localhost/" + expected_tenant_id

    def mock_send(request, **_):
        if request.body:
            # a request with a body is a token request, which must target the requested tenant
            assert request.url.startswith(expected_url_prefix)
            return mock_response(json_payload=build_aad_response(access_token="*"))
        return get_discovery_response()

    transport = Mock(send=Mock(wraps=mock_send))
    credential = SharedTokenCacheCredential(
        authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id, transport=transport
    )
    # the cache is empty, so get_token fails, but only after the credential has used the transport
    with pytest.raises(CredentialUnavailableError):
        credential.get_token("scope")

    assert transport.send.called
Ejemplo n.º 25
0
def test_tenant_argument_overrides_record():
    """The 'tenant_id' keyword argument should override a given record's value"""

    tenant_id = "some-guid"
    authority = "localhost"
    record = AuthenticationRecord(tenant_id, "client-id", authority,
                                  "object.tenant", "username")

    # any value different from the record's tenant will do
    expected_tenant = tenant_id[::-1]
    expected_authority = "https://{}/{}".format(authority, expected_tenant)

    def validate_authority(authority, **_):
        assert authority == expected_authority
        return Mock(get_accounts=Mock(return_value=[]))

    # wrapped so the test can confirm the validator actually ran; without this check the
    # authority assertion above could be skipped and the test would pass vacuously
    app_factory = Mock(wraps=validate_authority)
    credential = MockCredential(authentication_record=record,
                                tenant_id=expected_tenant,
                                disable_automatic_authentication=True)
    with pytest.raises(AuthenticationRequiredError):
        with patch("msal.PublicClientApplication", app_factory):
            credential.get_token("scope")

    assert app_factory.called, "credential didn't create an msal application"
Ejemplo n.º 26
0
def test_authentication_record():
    """get_token should redeem the refresh token of the cached account matching the record"""
    tenant_id = "tenant-id"
    client_id = "client-id"
    authority = "localhost"
    object_id = "object-id"
    username = "******"
    record = AuthenticationRecord(tenant_id, client_id, authority, "{}.{}".format(object_id, tenant_id), username)

    expected_access_token, expected_refresh_token = "****", "**"
    cache = populated_cache(
        get_account_event(
            username,
            object_id,
            tenant_id,
            authority=authority,
            client_id=client_id,
            refresh_token=expected_refresh_token,
        )
    )

    # exactly one token request is expected, and it must carry the cached refresh token
    token_request = Request(authority=authority, required_data={"refresh_token": expected_refresh_token})
    token_response = mock_response(json_payload=build_aad_response(access_token=expected_access_token))
    transport = validating_transport(requests=[token_request], responses=[token_response])

    credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)

    assert credential.get_token("scope").token == expected_access_token
# Authenticate once and keep the resulting AuthenticationRecord. `credential` is created before this
# excerpt (presumably an InteractiveBrowserCredential with cache_persistence_options — confirm upstream).
record = credential.authenticate()
print("\nAuthenticated first credential")

# The record contains no authentication secrets. You can serialize it to JSON for storage.
record_json = record.serialize()

# An authenticated credential is ready for use with a client. This request should succeed
# without prompting for authentication again.
client = SecretClient(VAULT_URL, credential)
# the listing just demonstrates a successful request; secret_names isn't used further in this excerpt
secret_names = [s.name for s in client.list_properties_of_secrets()]
print("\nCompleted request with first credential")

# An authentication record stored by your application enables other credentials to access data from
# past authentications. If the cache contains sufficient data, this eliminates the need for your
# application to prompt for authentication every time it runs.
deserialized_record = AuthenticationRecord.deserialize(record_json)
new_credential = InteractiveBrowserCredential(
    cache_persistence_options=TokenCachePersistenceOptions(),
    authentication_record=deserialized_record)

# This request should also succeed without prompting for authentication.
client = SecretClient(VAULT_URL, new_credential)
secret_names = [s.name for s in client.list_properties_of_secrets()]
print("\nCompleted request with credential using shared cache")

# To isolate the token cache from other applications, you can provide a cache name to TokenCachePersistenceOptions.
separate_cache_credential = InteractiveBrowserCredential(
    cache_persistence_options=TokenCachePersistenceOptions(name="my_app"),
    authentication_record=deserialized_record)

# A request made with separate_cache_credential should prompt for authentication, because that credential
# uses a separate cache ("my_app"). NOTE(review): no such request appears in this excerpt.
Ejemplo n.º 28
0
def _cache_auth_record(record: AuthenticationRecord, token_path: Path):
    """Write *record*'s serialized form to *token_path*, replacing any existing content."""
    token_path.write_text(record.serialize())