def test_request_url(authority):
    """Token requests should go over https to the configured authority, under the tenant's path.

    The authority is supplied first as a constructor argument, then through the
    AZURE_AUTHORITY_HOST environment variable.
    """
    tenant_id = "expected-tenant"
    # a bare host such as "localhost" parses to netloc "" and path "localhost",
    # so fall back to the raw authority string in that case
    expected_netloc = urlparse(authority).netloc or authority

    def send(request, **_):
        url = urlparse(request.url)
        assert url.scheme == "https"
        assert url.netloc == expected_netloc
        assert url.path.startswith("/" + tenant_id)
        payload = {"token_type": "Bearer", "expires_in": 42, "access_token": "***"}
        return mock_response(json_payload=payload)

    def request_tokens(aad_client):
        # each call routes through send(), which validates the request URL
        aad_client.obtain_token_by_authorization_code("scope", "code", "uri")
        aad_client.obtain_token_by_refresh_token("scope", "refresh token")

    explicit_client = AadClient(tenant_id, "client id", transport=Mock(send=send), authority=authority)
    request_tokens(explicit_client)

    # authority can be configured via environment variable; the client is built while the
    # variable is set, and the requests below run after the patch is undone
    environment = {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}
    with patch.dict("os.environ", environment, clear=True):
        env_client = AadClient(tenant_id=tenant_id, client_id="client id", transport=Mock(send=send))
    request_tokens(env_client)


def test_retries_token_requests():
    """The client should retry token requests.

    Every call to the transport raises ServiceRequestError. Each grant type must
    therefore surface that error, and only after the transport was called more than once.
    """

    message = "can't connect"
    transport = Mock(send=Mock(side_effect=ServiceRequestError(message)))
    client = AadClient("tenant-id", "client-id", transport=transport)

    # Read the certificate inside a context manager so the file handle is closed promptly.
    # A bare open(...).read() leaves it open until garbage collection, and CPython may
    # emit a ResourceWarning.
    with open(CERT_PATH, "rb") as cert_file:
        certificate = AadClientCertificate(cert_file.read())

    token_requests = (
        ("authorization code", lambda: client.obtain_token_by_authorization_code("", "", "")),
        ("client certificate", lambda: client.obtain_token_by_client_certificate("", certificate)),
        ("client secret", lambda: client.obtain_token_by_client_secret("", "")),
        ("refresh token", lambda: client.obtain_token_by_refresh_token("", "")),
    )
    for grant, request_token in token_requests:
        with pytest.raises(ServiceRequestError, match=message):
            request_token()
        # more than one send means the client retried the failed request
        assert transport.send.call_count > 1, "%s request was not retried" % grant
        # reset_mock keeps side_effect by default, so later sends still fail
        transport.send.reset_mock()


def test_refresh_token():
    """Redeeming a refresh token should send one refresh_token grant and return the access token."""
    tenant_id, client_id = "tenant-id", "client-id"
    scope = "scope"
    refresh_token = "refresh-token"
    access_token = "***"

    # form fields the token request must carry, checked in this order
    expected_fields = (
        ("client_id", client_id),
        ("grant_type", "refresh_token"),
        ("refresh_token", refresh_token),
        ("scope", scope),
    )

    def send(request, **_):
        for field, expected in expected_fields:
            assert request.data[field] == expected
        return mock_response(json_payload={"access_token": access_token, "expires_in": 42})

    transport = Mock(send=Mock(wraps=send))
    client = AadClient(tenant_id, client_id, transport=transport)

    token = client.obtain_token_by_refresh_token(scopes=(scope,), refresh_token=refresh_token)

    assert token.token == access_token
    assert transport.send.call_count == 1


def test_evicts_invalid_refresh_token():
    """The client should remove a refresh token from its cache when AAD rejects it.

    Two refresh tokens are cached. AAD answers the redemption of one of them with
    400 invalid_grant. Afterwards only that token should be gone.
    """

    tenant_id = "tenant-id"
    client_id = "client-id"
    invalid_token = "invalid-refresh-token"

    cache = TokenCache()

    def refresh_token_count(secret=None):
        # count every cached refresh token, or only those whose secret matches
        if secret is None:
            return len(cache.find(TokenCache.CredentialType.REFRESH_TOKEN))
        return len(cache.find(TokenCache.CredentialType.REFRESH_TOKEN, query={"secret": secret}))

    seed_accounts = (("id1", "tid1", invalid_token), ("id2", "tid2", "..."))
    for uid, utid, secret in seed_accounts:
        response = build_aad_response(uid=uid, utid=utid, access_token="*", refresh_token=secret)
        cache.add({"response": response})

    assert refresh_token_count() == 2
    assert refresh_token_count(invalid_token) == 1

    def send(request, **_):
        assert request.data["refresh_token"] == invalid_token
        return mock_response(json_payload={"error": "invalid_grant"}, status_code=400)

    transport = Mock(send=Mock(wraps=send))
    client = AadClient(tenant_id, client_id, transport=transport, cache=cache)

    with pytest.raises(ClientAuthenticationError):
        client.obtain_token_by_refresh_token(scopes=("scope",), refresh_token=invalid_token)

    assert transport.send.call_count == 1
    assert refresh_token_count() == 1
    assert refresh_token_count(invalid_token) == 0