def test_persistent_cache_multiple_clients():
    """the credential shouldn't use tokens issued to other service principals

    Two ClientSecretCredential instances with different client IDs are created while the persistent
    cache loader is patched to hand each of them a wrapper around the same TokenCache. Each credential
    must then obtain its own access token through its own transport, and both tokens must end up in
    the shared cache.
    """

    access_token_a = "token a"
    access_token_b = "not " + access_token_a
    transport_a = msal_validating_transport(
        requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))]
    )
    transport_b = msal_validating_transport(
        requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token=access_token_b))]
    )

    # one real cache, shared by both credentials through the patched loader
    cache = TokenCache()
    with patch("azure.identity._internal.persistent_cache._load_persistent_cache") as mock_cache_loader:
        mock_cache_loader.return_value = Mock(wraps=cache)
        credential_a = ClientSecretCredential(
            "tenant-id", "client-a", "...", _enable_persistent_cache=True, transport=transport_a
        )
        assert mock_cache_loader.call_count == 1, "credential should load the persistent cache"
        credential_b = ClientSecretCredential(
            "tenant-id", "client-b", "...", _enable_persistent_cache=True, transport=transport_b
        )
        assert mock_cache_loader.call_count == 2, "credential should load the persistent cache"

    # A caches a token
    scope = "scope"
    token_a = credential_a.get_token(scope)
    assert token_a.token == access_token_a
    assert transport_a.send.call_count == 3  # two MSAL discovery requests, one token request

    # B should get a different token for the same scope
    token_b = credential_b.get_token(scope)
    assert token_b.token == access_token_b
    assert transport_b.send.call_count == 3

    # each client's token should be cached separately, so the shared cache holds two access tokens
    assert len(cache.find(TokenCache.CredentialType.ACCESS_TOKEN)) == 2
# Example #2
# 0
def test_persistent_cache_multiple_clients(cert_path, cert_password):
    """Certificate credentials for different clients must not reuse each other's cached tokens.

    Both credentials are built while the persistent cache loader is patched to return a wrapper
    around a single TokenCache. Each one is then expected to request its own token, leaving two
    access tokens in that cache.
    """

    first_token = "token a"
    second_token = "not " + first_token

    # each credential gets its own transport, answering with that credential's token
    first_transport = msal_validating_transport(
        requests=[Request()],
        responses=[mock_response(json_payload=build_aad_response(access_token=first_token))],
    )
    second_transport = msal_validating_transport(
        requests=[Request()],
        responses=[mock_response(json_payload=build_aad_response(access_token=second_token))],
    )

    shared_cache = TokenCache()
    loader_path = "azure.identity._internal.msal_credentials._load_persistent_cache"
    with patch(loader_path) as load_cache:
        load_cache.return_value = Mock(wraps=shared_cache)

        first_credential = CertificateCredential(
            "tenant",
            "client-a",
            cert_path,
            password=cert_password,
            transport=first_transport,
            cache_persistence_options=TokenCachePersistenceOptions(),
        )
        assert load_cache.call_count == 1, "credential should load the persistent cache"

        second_credential = CertificateCredential(
            "tenant",
            "client-b",
            cert_path,
            password=cert_password,
            transport=second_transport,
            cache_persistence_options=TokenCachePersistenceOptions(),
        )
        assert load_cache.call_count == 2, "credential should load the persistent cache"

    scope = "scope"

    # the first credential acquires a token, which lands in the shared cache
    first_result = first_credential.get_token(scope)
    assert first_result.token == first_token
    # two MSAL discovery requests, one token request
    assert first_transport.send.call_count == 3

    # same scope, different client: the second credential must fetch its own token
    second_result = second_credential.get_token(scope)
    assert second_result.token == second_token
    assert second_transport.send.call_count == 3

    cached_access_tokens = shared_cache.find(TokenCache.CredentialType.ACCESS_TOKEN)
    assert len(cached_access_tokens) == 2
# Example #3
# 0
def test_client_secret_credential():
    """ClientSecretCredential.get_token returns the access token from the mocked response.

    The transport is given one expected request whose URL contains the tenant ID and whose data
    carries the client ID and secret (presumably enforced by msal_validating_transport).
    """
    tenant_id = "fake-tenant-id"
    client_id = "fake-client-id"
    secret = "fake-client-secret"
    access_token = "***"

    expected_request = Request(
        url_substring=tenant_id,
        required_data={"client_id": client_id, "client_secret": secret},
    )
    token_response = mock_response(json_payload=build_aad_response(access_token=access_token))
    transport = msal_validating_transport(
        endpoint="https://localhost/" + tenant_id,
        requests=[expected_request],
        responses=[token_response],
    )

    credential = ClientSecretCredential(tenant_id, client_id, secret, transport=transport)
    token = credential.get_token("scope")

    assert token.token == access_token
# Example #4
# 0
def test_auth_record_multiple_accounts_for_username():
    """SharedTokenCacheCredential picks the cached account matching its AuthenticationRecord.

    The cache holds two accounts with the same username and object ID that differ only in tenant
    (and refresh token). The transport expects the refresh token of the account in the record's
    tenant, and the credential must return the access token from that exchange.
    """
    tenant_id = "tenant-id"
    client_id = "client-id"
    authority = "localhost"
    object_id = "object-id"
    username = "******"
    record = AuthenticationRecord(tenant_id, client_id, authority, object_id + "." + tenant_id, username)

    expected_access_token = "****"
    expected_refresh_token = "**"

    matching_account = get_account_event(
        username,
        object_id,
        tenant_id,
        authority=authority,
        client_id=client_id,
        refresh_token=expected_refresh_token,
    )
    # identical to the matching account except for its tenant and refresh token
    other_tenant_account = get_account_event(
        username,
        object_id,
        "different-" + tenant_id,
        authority=authority,
        client_id=client_id,
        refresh_token="not-" + expected_refresh_token,
    )
    cache = populated_cache(matching_account, other_tenant_account)

    refresh_request = Request(authority=authority, required_data={"refresh_token": expected_refresh_token})
    token_response = mock_response(json_payload=build_aad_response(access_token=expected_access_token))
    transport = msal_validating_transport(
        endpoint="https://{}/{}".format(authority, tenant_id),
        requests=[refresh_request],
        responses=[token_response],
    )

    credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)

    assert credential.get_token("scope").token == expected_access_token
# Example #5
# 0
def test_user_agent():
    """Request a token through a transport whose expected request requires the USER_AGENT header.

    There is no explicit assertion here; the header check is presumably performed by
    msal_validating_transport when the credential sends its request.
    """
    expected_request = Request(required_headers={"User-Agent": USER_AGENT})
    token_response = mock_response(json_payload=build_aad_response(access_token="**"))
    transport = msal_validating_transport(requests=[expected_request], responses=[token_response])

    CertificateCredential("tenant-id", "client-id", CERT_PATH, transport=transport).get_token("scope")
# Example #6
# 0
def test_policies_configurable():
    """A policy passed via the ``policies`` keyword is invoked when the credential gets a token."""
    custom_policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock())

    token_response = mock_response(json_payload=build_aad_response(access_token="**"))
    transport = msal_validating_transport(requests=[Request()], responses=[token_response])

    pipeline_policies = [ContentDecodePolicy(), custom_policy]
    credential = CertificateCredential(
        "tenant-id",
        "client-id",
        CERT_PATH,
        policies=pipeline_policies,
        transport=transport,
    )
    credential.get_token("scope")

    # the mocked policy must have seen at least one request
    assert custom_policy.on_request.called
def test_interactive_credential(mock_open, redirect_url):
    """Walk InteractiveBrowserCredential through interactive auth, a cache hit and a token refresh.

    ``mock_open`` and ``redirect_url`` are supplied from outside this block (presumably a patch
    decorator and a parametrization -- confirm against the full file). The redirect server is
    replaced with a mock that immediately returns a well-formed authorization code response.
    """
    mock_open.side_effect = _validate_auth_request_url

    oauth_state = "state"
    client_id = "client-id"
    expected_refresh_token = "refresh-token"
    expected_token = "access-token"
    expires_in = 3600
    authority = "authority"
    tenant_id = "tenant-id"
    endpoint = "https://{}/{}".format(authority, tenant_id)

    # first exchange: the authorization code is redeemed; second: the refresh token is redeemed
    expected_requests = [
        Request(url_substring=endpoint),
        Request(
            authority=authority,
            url_substring=endpoint,
            required_data={"refresh_token": expected_refresh_token},
        ),
    ]
    auth_code_payload = build_aad_response(
        access_token=expected_token,
        expires_in=expires_in,
        refresh_token=expected_refresh_token,
        uid="uid",
        utid=tenant_id,
        id_token=build_id_token(aud=client_id, object_id="uid", tenant_id=tenant_id, iss=endpoint),
        token_type="Bearer",
    )
    refresh_payload = build_aad_response(access_token=expected_token, expires_in=expires_in, token_type="Bearer")
    transport = msal_validating_transport(
        endpoint=endpoint,
        requests=expected_requests,
        responses=[mock_response(json_payload=auth_code_payload), mock_response(json_payload=refresh_payload)],
    )

    # the mock local server fakes a successful authentication by returning this response right away
    auth_code_response = {"code": "authorization-code", "state": [oauth_state]}
    fake_server = Mock(wait_for_redirect=lambda: auth_code_response)
    server_class = Mock(return_value=fake_server)

    credential_kwargs = dict(
        authority=authority,
        tenant_id=tenant_id,
        client_id=client_id,
        transport=transport,
        _cache=TokenCache(),
        _server_class=server_class,
    )
    if redirect_url:
        # only pass redirect_uri when there is a value, never redirect_uri=None
        credential_kwargs["redirect_uri"] = redirect_url

    credential = InteractiveBrowserCredential(**credential_kwargs)

    # The auth code request carries a uuid that must come back in the redirect; pinning uuid4 to a
    # known value is less code than implementing a real mock server.
    with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: oauth_state):
        first_token = credential.get_token("scope")
    assert first_token.token == expected_token
    assert mock_open.call_count == 1
    assert server_class.call_count == 1

    if redirect_url:
        server_class.assert_called_once_with(redirect_url, timeout=ANY)

    # second call: the token is cached, so there must be no new prompt
    cached_token = credential.get_token("scope")
    assert cached_token.token == expected_token
    assert mock_open.call_count == 1
    assert server_class.call_count == 1

    # once the access token has expired, the refresh token should be used rather than a new prompt
    now = time.time()
    with patch("time.time", lambda: now + expires_in):
        refreshed_token = credential.get_token("scope")
    assert refreshed_token.token == expected_token
    assert mock_open.call_count == 1

    # ensure all expected requests were sent
    assert transport.send.call_count == 4