def test_evicts_invalid_refresh_token():
    """when AAD rejects a refresh token, the client should evict that token from its cache"""

    tenant_id, client_id = "tenant-id", "client-id"
    invalid_token = "invalid-refresh-token"
    rt_type = TokenCache.CredentialType.REFRESH_TOKEN

    # seed the cache with two accounts' refresh tokens; only the first is the one AAD will reject
    cache = TokenCache()
    for uid, utid, refresh_token in (("id1", "tid1", invalid_token), ("id2", "tid2", "...")):
        cache.add({"response": build_aad_response(uid=uid, utid=utid, access_token="*", refresh_token=refresh_token)})
    assert len(cache.find(rt_type)) == 2
    assert len(cache.find(rt_type, query={"secret": invalid_token})) == 1

    def reject_refresh_token(request, **_):
        # the client must redeem exactly the token it was asked to redeem
        assert request.data["refresh_token"] == invalid_token
        return mock_response(json_payload={"error": "invalid_grant"}, status_code=400)

    transport = Mock(send=Mock(wraps=reject_refresh_token))
    client = AadClient(tenant_id, client_id, transport=transport, cache=cache)

    with pytest.raises(ClientAuthenticationError):
        client.obtain_token_by_refresh_token(scopes=("scope",), refresh_token=invalid_token)

    # a single request was sent, and only the rejected token left the cache
    assert transport.send.call_count == 1
    assert len(cache.find(rt_type)) == 1
    assert len(cache.find(rt_type, query={"secret": invalid_token})) == 0
Ejemplo n.º 2
0
class AuthnClientBase(object):
    """Sans I/O authentication client methods"""

    def __init__(self, auth_url, **kwargs):
        # type: (str, Mapping[str, Any]) -> None
        """
        :param str auth_url: URL of the OAuth endpoint requests are sent to
        :raises ValueError: when ``auth_url`` is empty
        """
        if not auth_url:
            raise ValueError("auth_url should be the URL of an OAuth endpoint")
        super(AuthnClientBase, self).__init__()
        self._auth_url = auth_url
        self._cache = TokenCache()

    def get_cached_token(self, scopes):
        # type: (Iterable[str]) -> Optional[str]
        """Return the secret of a cached access token covering ``scopes``, or None.

        A token qualifies when every requested scope is one of the scopes in its target and it
        expires more than 300 seconds from now.
        """
        # materialize once: a one-shot iterator would otherwise be exhausted by the find() argument,
        # leaving the all() check below with nothing to test (and so vacuously true)
        scopes = list(scopes)
        tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, scopes)
        now = int(time())
        for token in tokens:
            # the cached target is a space-delimited scope string; compare whole scopes, not substrings
            granted = set(token["target"].split())
            if all(scope in granted for scope in scopes):
                if int(token["expires_on"]) - 300 > now:
                    return token["secret"]
        return None

    def _deserialize_and_cache_token(self, response, scopes):
        # type: (PipelineResponse, Iterable[str]) -> str
        """Extract the access token from ``response``, add the response to the cache, and return the token.

        :raises AuthenticationError: when the response can't be parsed or lacks an access token
        """
        payload = None  # bound up front so the KeyError handler can always inspect it
        try:
            if "deserialized_data" in response.context:
                payload = response.context["deserialized_data"]
            else:
                payload = response.http_response.text()
            token = payload["access_token"]

            # these values are strings in IMDS responses but msal.TokenCache requires they be integers
            # https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/55
            if payload.get("expires_in"):
                payload["expires_in"] = int(payload["expires_in"])
            if payload.get("ext_expires_in"):
                payload["ext_expires_in"] = int(payload["ext_expires_in"])

            self._cache.add({"response": payload, "scope": scopes})
            return token
        except KeyError:
            # never echo the secret in an error message; scrub a copy so the response context isn't altered
            if isinstance(payload, dict) and "access_token" in payload:
                payload = dict(payload, access_token="****")
            raise AuthenticationError(
                "Unexpected authentication response: {}".format(payload))
        except Exception as ex:
            raise AuthenticationError("Authentication failed: {}".format(
                str(ex)))

    def _prepare_request(self,
                         method="POST",
                         headers=None,
                         form_data=None,
                         params=None):
        # type: (Optional[str], Optional[Mapping[str, str]], Optional[Mapping[str, str]], Optional[Dict[str, str]]) -> HttpRequest
        """Build a request to ``self._auth_url``.

        ``form_data``, when given, is set as a form-urlencoded body; ``params`` are applied with
        ``request.format_parameters``.
        """
        request = HttpRequest(method, self._auth_url, headers=headers)
        if form_data:
            request.headers[
                "Content-Type"] = "application/x-www-form-urlencoded"
            request.set_formdata_body(form_data)
        if params:
            request.format_parameters(params)
        return request
Ejemplo n.º 3
0
def test_persistent_cache_multiple_clients(cert_path, cert_password):
    """the credential shouldn't use tokens issued to other service principals"""

    access_token_a = "token a"
    access_token_b = "not " + access_token_a

    def transport_returning(access_token):
        # expects one token request and answers it with the given access token
        return msal_validating_transport(
            requests=[Request()],
            responses=[mock_response(json_payload=build_aad_response(access_token=access_token))],
        )

    transport_a = transport_returning(access_token_a)
    transport_b = transport_returning(access_token_b)

    cache = TokenCache()
    loader_path = "azure.identity._internal.msal_credentials._load_persistent_cache"
    with patch(loader_path) as mock_cache_loader:
        # both credentials "load" the same in-memory cache
        mock_cache_loader.return_value = Mock(wraps=cache)

        def make_credential(client_id, transport):
            return CertificateCredential(
                "tenant",
                client_id,
                cert_path,
                password=cert_password,
                transport=transport,
                cache_persistence_options=TokenCachePersistenceOptions(),
            )

        credential_a = make_credential("client-a", transport_a)
        assert mock_cache_loader.call_count == 1, "credential should load the persistent cache"

        credential_b = make_credential("client-b", transport_b)
        assert mock_cache_loader.call_count == 2, "credential should load the persistent cache"

    scope = "scope"

    # A caches a token: two MSAL discovery requests plus one token request
    token_a = credential_a.get_token(scope)
    assert token_a.token == access_token_a
    assert transport_a.send.call_count == 3

    # B must not reuse A's token for the same scope
    token_b = credential_b.get_token(scope)
    assert token_b.token == access_token_b
    assert transport_b.send.call_count == 3

    # each client's token is cached separately
    assert len(cache.find(TokenCache.CredentialType.ACCESS_TOKEN)) == 2
Ejemplo n.º 4
0
def test_writes_to_cache():
    """the credential should write tokens it acquires to the cache"""

    scope = "scope"
    expected_access_token = "access token"
    first_refresh_token, second_refresh_token = "first refresh token", "second refresh token"
    username, uid, utid = "******", "uid", "utid"

    cache = TokenCache()
    cache.add(get_account_event(username=username, uid=uid, utid=utid, refresh_token=first_refresh_token))

    # the credential redeems the first refresh token; AAD answers with an access token and a new refresh token
    id_token = build_id_token(aud=DEVELOPER_SIGN_ON_CLIENT_ID, object_id=uid, tenant_id=utid, username=username)
    aad_response = build_aad_response(
        uid=uid,
        utid=utid,
        access_token=expected_access_token,
        refresh_token=second_refresh_token,
        id_token=id_token,
    )
    transport = validating_transport(
        requests=[Request(required_data={"refresh_token": first_refresh_token})],
        responses=[mock_response(json_payload=aad_response)],
    )
    first_credential = SharedTokenCacheCredential(_cache=cache, transport=transport)
    assert first_credential.get_token(scope).token == expected_access_token

    # the access token is now cached, so a new instance must return it without sending anything
    offline_transport = Mock(send=Mock(side_effect=Exception("the credential should return a cached token")))
    cached_credential = SharedTokenCacheCredential(_cache=cache, transport=offline_transport)
    assert cached_credential.get_token(scope).token == expected_access_token

    # a different scope forces a refresh, which must redeem the *second* refresh token
    second_access_token = "second access token"
    transport = validating_transport(
        requests=[Request(required_data={"refresh_token": second_refresh_token})],
        responses=[mock_response(json_payload=build_aad_response(access_token=second_access_token))],
    )
    refreshing_credential = SharedTokenCacheCredential(_cache=cache, transport=transport)
    assert refreshing_credential.get_token("some other " + scope).token == second_access_token

    # the refresh token entry was replaced rather than duplicated
    assert len(cache.find(TokenCache.CredentialType.REFRESH_TOKEN)) == 1
Ejemplo n.º 5
0
# NOTE(review): assumes pytest-asyncio runs this suite. Without an asyncio marker (in its default strict
# mode) pytest does not await coroutine tests, so this test would never execute.
@pytest.mark.asyncio
async def test_cache_multiple_clients():
    """the credential shouldn't use tokens issued to other service principals"""

    access_token_a = "token a"
    access_token_b = "not " + access_token_a
    transport_a = async_validating_transport(
        requests=[Request()],
        responses=[
            mock_response(json_payload=build_aad_response(
                access_token=access_token_a))
        ])
    transport_b = async_validating_transport(
        requests=[Request()],
        responses=[
            mock_response(json_payload=build_aad_response(
                access_token=access_token_b))
        ])

    # both credentials load the same (patched) persistent cache, so cross-client reuse would be observable
    cache = TokenCache()
    with patch(ClientSecretCredential.__module__ +
               "._load_persistent_cache") as mock_cache_loader:
        mock_cache_loader.return_value = Mock(wraps=cache)
        credential_a = ClientSecretCredential(
            "tenant",
            "client-a",
            "secret",
            transport=transport_a,
            cache_persistence_options=TokenCachePersistenceOptions(),
        )
        assert mock_cache_loader.call_count == 1, "credential should load the persistent cache"

        credential_b = ClientSecretCredential(
            "tenant",
            "client-b",
            "secret",
            transport=transport_b,
            cache_persistence_options=TokenCachePersistenceOptions(),
        )
        assert mock_cache_loader.call_count == 2, "credential should load the persistent cache"

    # A caches a token
    scope = "scope"
    token_a = await credential_a.get_token(scope)
    assert token_a.token == access_token_a
    assert transport_a.send.call_count == 1

    # B should get a different token for the same scope
    token_b = await credential_b.get_token(scope)
    assert token_b.token == access_token_b
    assert transport_b.send.call_count == 1

    # each client's token is cached separately
    assert len(cache.find(TokenCache.CredentialType.ACCESS_TOKEN)) == 2
Ejemplo n.º 6
0
class AuthnClientBase(object):
    """Sans I/O authentication client methods"""

    def __init__(self, auth_url, **kwargs):
        # type: (str, Mapping[str, Any]) -> None
        """
        :param str auth_url: URL of the OAuth endpoint requests are sent to
        :raises ValueError: when ``auth_url`` is empty
        """
        if not auth_url:
            raise ValueError("auth_url should be the URL of an OAuth endpoint")
        super(AuthnClientBase, self).__init__()
        self._auth_url = auth_url
        self._cache = TokenCache()

    def get_cached_token(self, scopes):
        # type: (Iterable[str]) -> Optional[AccessToken]
        """Return a cached AccessToken covering ``scopes``, or None.

        A token qualifies when every requested scope is one of the scopes in its target and it
        expires more than 300 seconds from now.
        """
        # materialize once: a one-shot iterator would otherwise be exhausted by the find() argument,
        # leaving the all() check below with nothing to test (and so vacuously true)
        scopes = list(scopes)
        tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, scopes)
        now = int(time.time())
        for token in tokens:
            # the cached target is a space-delimited scope string; compare whole scopes, not substrings
            granted = set(token["target"].split())
            if all(scope in granted for scope in scopes):
                expires_on = int(token["expires_on"])
                if expires_on - 300 > now:
                    return AccessToken(token["secret"], expires_on)
        return None

    def _deserialize_and_cache_token(self, response, scopes, request_time):
        # type: (PipelineResponse, Iterable[str], int) -> AccessToken
        """Build an AccessToken from ``response``, add the response to the cache, and return the token.

        :param int request_time: epoch seconds at which the request was sent; the base for ``expires_in``
        :raises ClientAuthenticationError: when the payload lacks an access token or any expiry field
        """
        # ContentDecodePolicy sets this, and should have raised if it couldn't deserialize the response
        payload = response.context[ContentDecodePolicy.CONTEXT_NAME]

        if not payload or "access_token" not in payload or not (
                "expires_in" in payload or "expires_on" in payload):
            if payload and "access_token" in payload:
                payload["access_token"] = "****"
            raise ClientAuthenticationError(
                message="Unexpected response '{}'".format(payload))

        token = payload["access_token"]

        # AccessToken wants expires_on as an int. Prefer expires_on; fall back to expires_in only when it's
        # present (previously a falsy expires_on without expires_in raised a bare KeyError here)
        expires_on = payload.get("expires_on")
        if not expires_on and "expires_in" in payload:
            expires_on = int(payload["expires_in"]) + request_time
        try:
            expires_on = int(expires_on)
        except (TypeError, ValueError):
            # probably an App Service MSI response, convert it to epoch seconds.
            # str() guards against a null expires_on, which would otherwise raise AttributeError
            try:
                t = self._parse_app_service_expires_on(str(expires_on))
                expires_on = calendar.timegm(t)
            except ValueError:
                # have a token but don't know when it expires -> treat it as single-use
                expires_on = request_time

        # now we have an int expires_on, ensure the cache entry gets it
        payload["expires_on"] = expires_on

        self._cache.add({"response": payload, "scope": scopes})

        return AccessToken(token, expires_on)

    @staticmethod
    def _parse_app_service_expires_on(expires_on):
        # type: (str) -> struct_time
        """
        Parse expires_on from an App Service MSI response (e.g. "06/19/2019 23:42:01 +00:00") to struct_time.
        Expects the time is given in UTC (i.e. has offset +00:00).

        :raises ValueError: when the offset isn't " +00:00" or the rest doesn't match "%m/%d/%Y %H:%M:%S"
        """
        if not expires_on.endswith(" +00:00"):
            raise ValueError(
                "'{}' doesn't match expected format".format(expires_on))

        # parse the string minus the timezone offset
        return time.strptime(expires_on[:-len(" +00:00")], "%m/%d/%Y %H:%M:%S")

    # TODO: public, factor out of request_token
    def _prepare_request(self,
                         method="POST",
                         headers=None,
                         form_data=None,
                         params=None):
        # type: (Optional[str], Optional[Mapping[str, str]], Optional[Mapping[str, str]], Optional[Dict[str, str]]) -> HttpRequest
        """Build a request to ``self._auth_url``.

        ``form_data``, when given, is set as a form-urlencoded body; ``params`` are applied with
        ``request.format_parameters``.
        """
        request = HttpRequest(method, self._auth_url, headers=headers)
        if form_data:
            request.headers[
                "Content-Type"] = "application/x-www-form-urlencoded"
            request.set_formdata_body(form_data)
        if params:
            request.format_parameters(params)
        return request
class AadClientBase(ABC):
    """Sans I/O methods for AAD clients wrapping MSAL's OAuth client"""

    def __init__(self, tenant_id, client_id, **kwargs):
        # type: (str, str, **Any) -> None
        """
        :param str tenant_id: tenant segment of the token endpoint URL
        :param str client_id: client ID given to MSAL's OAuth client
        :keyword str authority: authority host without scheme ("https://" is prepended). When omitted,
            None or empty, KnownAuthorities.AZURE_PUBLIC_CLOUD is used.

        Remaining keyword arguments are passed to :meth:`_get_client_session`.
        """
        # an explicit None/"" is treated like an omitted authority instead of crashing on authority[-1]
        authority = kwargs.pop("authority", None) or KnownAuthorities.AZURE_PUBLIC_CLOUD
        # tolerate any number of trailing slashes
        authority = authority.rstrip("/")
        token_endpoint = "https://" + "/".join(
            (authority, tenant_id, "oauth2/v2.0/token"))
        config = {"token_endpoint": token_endpoint}

        self._client = Client(server_configuration=config, client_id=client_id)
        # close the session the client created and use the one supplied by the subclass instead
        self._client.session.close()
        self._client.session = self._get_client_session(**kwargs)
        self._cache = TokenCache()

    def get_cached_access_token(self, scopes):
        # type: (Iterable[str]) -> Optional[AccessToken]
        """Return a cached AccessToken for ``scopes`` expiring more than 300 seconds from now, or None."""
        tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN,
                                  target=list(scopes))
        for token in tokens:
            expires_on = int(token["expires_on"])
            if expires_on - 300 > int(time.time()):
                return AccessToken(token["secret"], expires_on)
        return None

    def get_cached_refresh_tokens(self, scopes):
        """Return cached refresh tokens whose target matches ``scopes``.

        Assumes all cached refresh tokens belong to the same user.
        """
        return self._cache.find(TokenCache.CredentialType.REFRESH_TOKEN,
                                target=list(scopes))

    def obtain_token_by_authorization_code(self, code, redirect_uri, scopes,
                                           **kwargs):
        # type: (str, str, Iterable[str], **Any) -> AccessToken
        """Redeem an authorization code. ``kwargs`` are forwarded both to MSAL's request and to _obtain_token."""
        fn = functools.partial(self._client.obtain_token_by_authorization_code,
                               code=code,
                               redirect_uri=redirect_uri,
                               **kwargs)
        return self._obtain_token(scopes, fn, **kwargs)

    def obtain_token_by_refresh_token(self, refresh_token, scopes, **kwargs):
        # type: (str, Iterable[str], **Any) -> AccessToken
        """Redeem a refresh token. ``refresh_token`` is a cache entry whose "secret" is the token itself."""
        fn = functools.partial(self._client.obtain_token_by_refresh_token,
                               token_item=refresh_token,
                               scope=scopes,
                               rt_getter=lambda token: token["secret"],
                               **kwargs)
        return self._obtain_token(scopes, fn, **kwargs)

    def _process_response(self, response, scopes, now):
        # type: (dict, Iterable[str], int) -> AccessToken
        """Validate a token response, cache it, and return an AccessToken.

        :param int now: epoch seconds used as the base for ``expires_in`` and passed to the cache
        :raises ClientAuthenticationError: when the response is an error, or lacks an access token or expiry
        """
        _raise_for_error(response)

        if "expires_on" in response:
            expires_on = int(response["expires_on"])
        elif "expires_in" in response:
            expires_on = now + int(response["expires_in"])
        else:
            expires_on = None

        # validate before caching so a malformed response never lands in the cache
        if expires_on is None or "access_token" not in response:
            _scrub_secrets(response)
            raise ClientAuthenticationError(
                message="Unexpected response from Azure Active Directory: {}".
                format(response))

        self._cache.add(event={"response": response, "scope": scopes}, now=now)
        return AccessToken(response["access_token"], expires_on)

    @abc.abstractmethod
    def _get_client_session(self, **kwargs):
        """Return the session MSAL's client should use; receives the constructor's remaining kwargs."""
        pass

    @abc.abstractmethod
    def _obtain_token(self, scopes, fn, **kwargs):
        # type: (Iterable[str], Callable, **Any) -> AccessToken
        """Run ``fn`` (a partially-applied MSAL request) and return the resulting AccessToken."""
        pass
Ejemplo n.º 8
0
class AuthnClientBase(object):
    """Sans I/O authentication client methods"""

    def __init__(self, auth_url, **kwargs):
        # type: (str, Mapping[str, Any]) -> None
        """
        :param str auth_url: URL of the OAuth endpoint requests are sent to
        :raises ValueError: when ``auth_url`` is empty
        """
        if not auth_url:
            raise ValueError("auth_url should be the URL of an OAuth endpoint")
        super(AuthnClientBase, self).__init__()
        self._auth_url = auth_url
        self._cache = TokenCache()

    def get_cached_token(self, scopes):
        # type: (Iterable[str]) -> Optional[AccessToken]
        """Return a cached AccessToken covering ``scopes``, or None.

        A token qualifies when every requested scope is one of the scopes in its target and it
        expires more than 300 seconds from now.
        """
        # materialize once: a one-shot iterator would otherwise be exhausted by the find() argument,
        # leaving the all() check below with nothing to test (and so vacuously true)
        scopes = list(scopes)
        tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, scopes)
        now = int(time.time())
        for token in tokens:
            # the cached target is a space-delimited scope string; compare whole scopes, not substrings
            granted = set(token["target"].split())
            if all(scope in granted for scope in scopes):
                expires_on = int(token["expires_on"])
                if expires_on - 300 > now:
                    return AccessToken(token["secret"], expires_on)
        return None

    def _deserialize_and_cache_token(self, response, scopes, request_time):
        # type: (PipelineResponse, Iterable[str], int) -> AccessToken
        """Build an AccessToken from ``response``, add the response to the cache, and return the token.

        :param int request_time: epoch seconds at which the request was sent; the base for ``expires_in``
        :raises AuthenticationError: when the response can't be parsed or lacks an access token
        """
        payload = None  # bound up front so the KeyError handler can always inspect it
        try:
            if "deserialized_data" in response.context:
                payload = response.context["deserialized_data"]
            else:
                payload = response.http_response.text()
            token = payload["access_token"]

            # these values are strings in IMDS responses but msal.TokenCache requires they be integers
            # https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/55
            expires_in = int(payload.get("expires_in", 0))
            if expires_in != 0:
                payload["expires_in"] = expires_in
            if "ext_expires_in" in payload:
                payload["ext_expires_in"] = int(payload["ext_expires_in"])

            self._cache.add({"response": payload, "scope": scopes})

            # AccessToken contains the token's expires_on time. There are four cases for setting it:
            # 1. response has expires_on -> AccessToken uses it
            # 2. response has expires_on and expires_in -> AccessToken uses expires_on
            # 3. response has only expires_in -> AccessToken uses expires_in + time of request
            # 4. response has neither expires_on or expires_in -> AccessToken sets expires_on = request_time
            #    (not expecting this case; if it occurs, the token is effectively single-use)
            # expires_on may arrive as a string (as expires_in does in IMDS responses); AccessToken needs an
            # int. A value that isn't an integer is treated as absent.
            try:
                expires_on = int(payload.get("expires_on", 0))
            except (TypeError, ValueError):
                expires_on = 0
            return AccessToken(token, expires_on or expires_in + request_time)
        except KeyError:
            # never echo the secret in an error message; scrub a copy so the response context isn't altered
            if isinstance(payload, dict) and "access_token" in payload:
                payload = dict(payload, access_token="****")
            raise AuthenticationError(
                "Unexpected authentication response: {}".format(payload))
        except Exception as ex:
            raise AuthenticationError("Authentication failed: {}".format(
                str(ex)))

    # TODO: public, factor out of request_token
    def _prepare_request(self,
                         method="POST",
                         headers=None,
                         form_data=None,
                         params=None):
        # type: (Optional[str], Optional[Mapping[str, str]], Optional[Mapping[str, str]], Optional[Dict[str, str]]) -> HttpRequest
        """Build a request to ``self._auth_url``.

        ``form_data``, when given, is set as a form-urlencoded body; ``params`` are applied with
        ``request.format_parameters``.
        """
        request = HttpRequest(method, self._auth_url, headers=headers)
        if form_data:
            request.headers[
                "Content-Type"] = "application/x-www-form-urlencoded"
            request.set_formdata_body(form_data)
        if params:
            request.format_parameters(params)
        return request