async def test_multitenant_authentication_not_allowed():
    """Check get_token's tenant_id handling and the multitenant-disable environment switch.

    The fake transport returns one token for the configured tenant and a different token
    for any other tenant. By default a foreign tenant_id is honored; once
    AZURE_IDENTITY_DISABLE_MULTITENANTAUTH is "true", the tenant_id argument is ignored and
    the configured tenant's token comes back.
    """
    configured_tenant = "expected-tenant"
    configured_token = "***"
    foreign_tenant = "un" + configured_tenant
    foreign_token = configured_token * 2

    async def send(request, **_):
        # the first path segment of the token request URL names the tenant
        requested_tenant = urlparse(request.url).path.split("/")[1]
        if requested_tenant == configured_tenant:
            access_token = configured_token
        else:
            access_token = foreign_token
        payload = build_aad_response(access_token=access_token, refresh_token="**")
        return mock_response(json_payload=payload)

    credential = AuthorizationCodeCredential(
        configured_tenant, "client-id", "authcode", "https://localhost", transport=Mock(send=send)
    )

    cases = (
        ({}, configured_token),  # no tenant_id -> configured tenant
        ({"tenant_id": configured_tenant}, configured_token),
        ({"tenant_id": foreign_tenant}, foreign_token),  # other tenants honored by default
    )
    for extra_kwargs, wanted in cases:
        result = await credential.get_token("scope", **extra_kwargs)
        assert result.token == wanted

    with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}):
        result = await credential.get_token("scope", tenant_id=foreign_tenant)
        assert result.token == configured_token
async def test_multitenant_authentication():
    """A tenant_id passed to get_token applies to that call only.

    The credential is configured for "first-tenant". Requesting "second-tenant" must yield
    that tenant's token, and later calls without tenant_id must fall back to the first tenant.
    """
    first_tenant, second_tenant = "first-tenant", "second-tenant"
    first_token = "***"
    second_token = first_token * 2
    token_for = {first_tenant: first_token, second_tenant: second_token}

    async def send(request, **_):
        tenant = urlparse(request.url).path.split("/")[1]
        assert tenant in token_for, 'unexpected tenant "{}"'.format(tenant)
        response_body = build_aad_response(access_token=token_for[tenant], refresh_token="**")
        return mock_response(json_payload=response_body)

    credential = AuthorizationCodeCredential(
        first_tenant,
        "client-id",
        "authcode",
        "https://localhost",
        transport=Mock(send=send),
    )

    expectations = [
        (None, first_token),
        (first_tenant, first_token),
        (second_tenant, second_token),
        (None, first_token),  # should still default to the first tenant
    ]
    for requested_tenant, wanted in expectations:
        if requested_tenant is None:
            result = await credential.get_token("scope")
        else:
            result = await credential.get_token("scope", tenant_id=requested_tenant)
        assert result.token == wanted
async def test_no_scopes():
    """get_token must raise ValueError when it is given no scopes at all."""
    no_scopes = ()
    credential = AuthorizationCodeCredential(
        "tenant-id", "client-id", "auth-code", "http://localhost"
    )

    with pytest.raises(ValueError):
        await credential.get_token(*no_scopes)
async def test_close():
    """Closing the credential should exit its transport's async context exactly once."""
    mock_transport = AsyncMockTransport()
    credential = AuthorizationCodeCredential(
        "tenant-id", "client-id", "auth-code", "http://localhost", transport=mock_transport
    )

    await credential.close()

    exit_calls = mock_transport.__aexit__.call_count
    assert exit_calls == 1
async def test_tenant_id():
    """get_token should accept a tenant_id keyword and return the token from the response.

    The request is built with Request(required_headers={"User-Agent": USER_AGENT}), so the
    validating transport presumably also checks the SDK User-Agent header (assumed from the
    helper's arguments; its implementation is not visible here).
    """
    expected_access_token = "**"
    transport = async_validating_transport(
        requests=[Request(required_headers={"User-Agent": USER_AGENT})],
        responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))],
    )

    credential = AuthorizationCodeCredential(
        "tenant-id", "client-id", "auth-code", "http://localhost", transport=transport
    )

    token = await credential.get_token("scope", tenant_id="tenant_id")

    # Previously the result was discarded, so the test only proved that no exception was
    # raised. Pin the returned token and confirm exactly one token request was sent.
    assert token.token == expected_access_token
    assert transport.send.call_count == 1
async def test_context_manager():
    """Using the credential as an async context manager enters and exits its transport once each."""
    mock_transport = AsyncMockTransport()
    credential = AuthorizationCodeCredential(
        "tenant-id", "client-id", "auth-code", "http://localhost", transport=mock_transport
    )

    async with credential:
        # entering the credential should have entered the transport
        assert mock_transport.__aenter__.call_count == 1

    # leaving must not re-enter, and must exit exactly once
    enter_and_exit_counts = (
        mock_transport.__aenter__.call_count,
        mock_transport.__aexit__.call_count,
    )
    assert enter_and_exit_counts == (1, 1)
# Example #7
# 0
async def test_custom_loop_used():
    """A loop passed to get_token should be the one whose run_in_executor is called.

    The loop is a bare Mock, so the request presumably cannot complete and get_token is
    expected to surface ClientAuthenticationError.
    """
    fake_loop = Mock()
    credential = AuthorizationCodeCredential(
        client_id="client id",
        tenant_id="tenant id",
        authorization_code="auth code",
        redirect_uri="https://foo.bar",
    )

    with pytest.raises(ClientAuthenticationError):
        await credential.get_token("scope", loop=fake_loop)

    assert fake_loop.run_in_executor.call_count == 1
async def test_policies_configurable():
    """Policies supplied via the policies keyword should run on the credential's token requests."""
    custom_policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock())

    async def send(*_, **__):
        body = build_aad_response(access_token="**")
        return mock_response(json_payload=body)

    fake_transport = Mock(send=send)
    credential = AuthorizationCodeCredential(
        "tenant-id",
        "client-id",
        "auth-code",
        "http://localhost",
        policies=[custom_policy],
        transport=fake_transport,
    )

    await credential.get_token("scope")

    assert custom_policy.on_request.called
# Example #9
# 0
async def test_auth_code_credential():
    """Check the credential's three paths to a token using a mocked client.

    1. The first get_token redeems the authorization code through the client.
    2. With the code spent, a cached access token from the client is returned.
    3. With no cached access token, a cached refresh token is redeemed instead.
    """
    client_id = "client id"
    tenant_id = "tenant"
    expected_code = "auth code"
    redirect_uri = "https://foo.bar"
    expected_token = AccessToken("token", 42)

    mock_client = Mock(spec=object)
    obtain_by_auth_code = Mock(return_value=expected_token)

    # asyncio.coroutine was deprecated in Python 3.8 and removed in 3.11. A native coroutine
    # function gives the same "awaitable returning the mock's result" behavior, while the
    # wrapped Mock still records call_count/call_args for the assertions below.
    async def obtain_token_by_authorization_code(*args, **kwargs):
        return obtain_by_auth_code(*args, **kwargs)

    mock_client.obtain_token_by_authorization_code = obtain_token_by_authorization_code

    credential = AuthorizationCodeCredential(
        client_id=client_id,
        tenant_id=tenant_id,
        authorization_code=expected_code,
        redirect_uri=redirect_uri,
        client=mock_client,
    )

    # first call should redeem the auth code
    token = await credential.get_token("scope")
    assert token is expected_token
    assert obtain_by_auth_code.call_count == 1
    _, kwargs = obtain_by_auth_code.call_args
    assert kwargs["code"] == expected_code

    # no auth code -> credential should return cached token
    mock_client.obtain_token_by_authorization_code = None  # raise if credential calls this again
    mock_client.get_cached_access_token = lambda *_: expected_token
    token = await credential.get_token("scope")
    assert token is expected_token

    # no auth code, no cached token -> credential should use refresh token
    mock_client.get_cached_access_token = lambda *_: None
    mock_client.get_cached_refresh_tokens = lambda *_: ["this is a refresh token"]

    async def obtain_token_by_refresh_token(*_, **__):
        return expected_token

    mock_client.obtain_token_by_refresh_token = obtain_token_by_refresh_token
    token = await credential.get_token("scope")
    assert token is expected_token
# Example #10
# 0
async def test_multitenant_authentication_not_allowed():
    """get_token(tenant_id=...) must raise for a foreign tenant unless legacy tenant selection is on.

    Multitenant authentication is not allowed by default, so only the configured tenant
    (implicitly or explicitly) may be used. Setting
    AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION to "true" makes the credential ignore the
    tenant_id argument instead of raising.
    """
    home_tenant = "expected-tenant"
    home_token = "***"
    foreign_tenant = "un" + home_tenant

    async def send(request, **_):
        path_segments = urlparse(request.url).path.split("/")
        if path_segments[1] == home_tenant:
            token = home_token
        else:
            token = home_token * 2
        return mock_response(json_payload=build_aad_response(access_token=token, refresh_token="**"))

    credential = AuthorizationCodeCredential(
        home_tenant, "client-id", "authcode", "https://localhost", transport=Mock(send=send)
    )

    # omitting tenant_id, or explicitly specifying the configured tenant, is okay
    for extra_kwargs in ({}, {"tenant_id": home_tenant}):
        token = await credential.get_token("scope", **extra_kwargs)
        assert token.token == home_token

    # but any other tenant should get an error
    with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"):
        await credential.get_token("scope", tenant_id=foreign_tenant)

    # ...unless the compat switch is enabled
    legacy_env = {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}
    with patch.dict("os.environ", legacy_env):
        token = await credential.get_token("scope", tenant_id=foreign_tenant)
    assert token.token == home_token, "credential should ignore tenant_id kwarg when the compat switch is enabled"
async def test_auth_code_credential():
    """Redeem the auth code first, then serve the cached token, then redeem the refresh token.

    The validating transport is given exactly two expected requests, in order: an
    authorization_code grant followed by a refresh_token grant. The msal TokenCache is
    manipulated directly to force the refresh path.
    """
    client_id = "client id"
    tenant_id = "tenant"
    auth_code = "auth code"
    redirect_uri = "https://localhost"
    wanted_access_token = "access"
    wanted_refresh_token = "refresh"
    scope = "scope"

    # form fields both token requests share
    common_form = {"client_id": client_id, "scope": scope}
    # first get_token call should redeem the auth code
    redeem_code_request = Request(
        url_substring=tenant_id,
        required_data=dict(
            common_form,
            code=auth_code,
            grant_type="authorization_code",
            redirect_uri=redirect_uri,
        ),
    )
    # third get_token call should redeem the refresh token
    redeem_refresh_request = Request(
        url_substring=tenant_id,
        required_data=dict(common_form, grant_type="refresh_token", refresh_token=wanted_refresh_token),
    )

    aad_response = build_aad_response(access_token=wanted_access_token, refresh_token=wanted_refresh_token)
    transport = async_validating_transport(
        requests=[redeem_code_request, redeem_refresh_request],
        responses=[mock_response(json_payload=aad_response)] * 2,
    )
    token_cache = msal.TokenCache()

    credential = AuthorizationCodeCredential(
        client_id=client_id,
        tenant_id=tenant_id,
        authorization_code=auth_code,
        redirect_uri=redirect_uri,
        transport=transport,
        cache=token_cache,
    )

    # 1st call: the auth code is redeemed with one request
    token = await credential.get_token(scope)
    assert token.token == wanted_access_token
    assert transport.send.call_count == 1

    # 2nd call: the code is spent, so the cached access token is returned without a request
    token = await credential.get_token(scope)
    assert token.token == wanted_access_token
    assert transport.send.call_count == 1

    # 3rd call: evict the cached access token so the credential must redeem the refresh token
    cached_entry = token_cache.find(token_cache.CredentialType.ACCESS_TOKEN)[0]
    token_cache.remove_at(cached_entry)
    token = await credential.get_token(scope)
    assert token.token == wanted_access_token
    assert transport.send.call_count == 2