def _deserialize_and_cache_token(self, response, scopes, request_time):
    # type: (PipelineResponse, Iterable[str], int) -> AccessToken
    """Deserialize an access token from an AAD response, add it to the cache, and return it.

    :param response: pipeline response whose context holds the decoded body
    :param scopes: scopes the token was requested for; stored alongside the cache entry
    :param int request_time: epoch seconds at which the token request was sent
    :raises ~azure.core.exceptions.ClientAuthenticationError: the payload is empty, or lacks
        ``access_token``, or lacks both ``expires_in`` and ``expires_on``
    """
    # ContentDecodePolicy populates this key, and should already have raised if it couldn't decode the body
    payload = response.context[ContentDecodePolicy.CONTEXT_NAME]

    is_valid = payload and "access_token" in payload and ("expires_in" in payload or "expires_on" in payload)
    if not is_valid:
        if payload and "access_token" in payload:
            # redact the token before it is echoed in the exception message
            payload["access_token"] = "****"
        raise ClientAuthenticationError(message="Unexpected response '{}'".format(payload))

    token = payload["access_token"]

    # AccessToken wants expires_on as an int; derive it from expires_in when expires_on is absent/falsy
    expires_on = payload.get("expires_on")  # type: Union[str, int]
    if not expires_on:
        expires_on = int(payload["expires_in"]) + request_time

    try:
        expires_on = int(expires_on)
    except ValueError:
        # probably an App Service MSI response with a date string; convert it to epoch seconds
        try:
            expires_on = calendar.timegm(self._parse_app_service_expires_on(expires_on))  # type: ignore
        except ValueError:
            # we have a token but can't tell when it expires -> treat it as single-use
            expires_on = request_time

    # store the normalized int so the cache entry agrees with the returned AccessToken
    payload["expires_on"] = expires_on
    self._cache.add({"response": payload, "scope": scopes})
    return AccessToken(token, expires_on)
def get_token(self, *scopes, **kwargs):
    # type: (*str, **Any) -> AccessToken
    """Request an access token for `scopes`.

    .. note:: This method is called by Azure SDK clients. It isn't intended for use in application code.

    A silent (cached) acquisition is attempted first; if that requires user interaction and prompting is
    allowed, interactive authentication runs and the resulting auth record replaces the stored one.

    :param str scopes: desired scopes for the access token. This method requires at least one scope.
    :rtype: :class:`azure.core.credentials.AccessToken`
    :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks
      required data, state, or platform support
    :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
      attribute gives a reason.
    :raises AuthenticationRequiredError: user interaction is necessary to acquire a token, and the credential
      is configured not to begin this automatically. Call :func:`authenticate` to begin interactive
      authentication.
    """
    if not scopes:
        raise ValueError("'get_token' requires at least one scope")

    # private option; by default prompting is allowed unless automatic authentication was disabled
    prompt_by_default = not self._disable_automatic_authentication
    allow_prompt = kwargs.pop("_allow_prompt", prompt_by_default)

    try:
        return self._acquire_token_silent(*scopes, **kwargs)
    except AuthenticationRequiredError:
        if not allow_prompt:
            raise

        # silent authentication failed -> authenticate interactively
        requested_at = int(time.time())
        response = self._request_token(*scopes, **kwargs)
        if "access_token" not in response:
            reason = response.get("error_description") or response.get("error")
            raise ClientAuthenticationError(message="Authentication failed: {}".format(reason))

        # this may be the first authentication, or the user may have authenticated a different identity
        self._auth_record = _build_auth_record(response)
        return AccessToken(response["access_token"], requested_at + int(response["expires_in"]))
def test_cloud_shell():
    """Cloud Shell environment: only MSI_ENDPOINT set"""
    access_token = "****"
    expires_on = 42
    endpoint = "http://localhost:42/token"
    scope = "scope"

    # Cloud Shell tokens are requested with a form-encoded POST carrying the resource
    expected_request = Request(
        base_url=endpoint,
        method="POST",
        required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
        required_data={"resource": scope},
    )
    token_response = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_in": 0,
            "expires_on": expires_on,
            "not_before": int(time.time()),
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = validating_transport(requests=[expected_request], responses=[token_response])

    with mock.patch("os.environ", {EnvironmentVariables.MSI_ENDPOINT: endpoint}):
        credential = ManagedIdentityCredential(transport=transport)
        token = credential.get_token(scope)

    assert token == AccessToken(access_token, expires_on)
def get_token(self, *scopes, **kwargs):  # pylint:disable=unused-argument
    # type: (*str, **Any) -> AccessToken
    """Request an access token for `scopes`.

    A cached token for one of the user's accounts is returned when MSAL has one; otherwise a new token is
    requested with the configured username and password.

    :param str scopes: desired scopes for the token
    :rtype: :class:`azure.core.credentials.AccessToken`
    :raises: :class:`azure.core.exceptions.ClientAuthenticationError`
    """
    # MSAL requires scopes be a list
    scopes = list(scopes)  # type: ignore
    now = int(time.time())
    app = self._get_app()

    result = None
    for account in app.get_accounts(username=self._username):
        result = app.acquire_token_silent(scopes, account=account)
        if result:
            break

    if not result:
        # cache miss -> request a new token
        with self._adapter:
            result = app.acquire_token_by_username_password(
                username=self._username, password=self._password, scopes=scopes
            )

    if "access_token" not in result:
        # AAD error responses don't always include a description; fall back to the error code rather
        # than reporting "authentication failed: None"
        details = result.get("error_description") or result.get("error")
        raise ClientAuthenticationError(message="authentication failed: {}".format(details))

    return AccessToken(result["access_token"], now + int(result["expires_in"]))
def _process_response(self, response, request_time):
    # type: (PipelineResponse, int) -> AccessToken
    """Parse a token response, pass it to the optional content callback, cache it, and return the token.

    :param response: the pipeline response to the token request
    :param int request_time: epoch seconds at which the request was sent; used to derive expires_on from
        expires_in and as the cache's "now"
    :raises ~azure.core.exceptions.ClientAuthenticationError: the body is empty, lacks ``access_token``, or
        lacks both ``expires_in`` and ``expires_on``
    """
    http_response = response.http_response
    # deserialize_from_http_generics is expected to raise if the body can't be decoded
    content = ContentDecodePolicy.deserialize_from_http_generics(http_response)  # type: dict

    if not content:
        raise ClientAuthenticationError(message="No token received.", response=http_response)

    if "access_token" not in content or ("expires_in" not in content and "expires_on" not in content):
        if "access_token" in content:
            # redact the token before it appears in the exception message
            content["access_token"] = "****"
        raise ClientAuthenticationError(
            message='Unexpected response "{}"'.format(content), response=http_response
        )

    if self._content_callback:
        self._content_callback(content)

    # normalize to an int expires_on, deriving it from expires_in when expires_on is absent/falsy
    raw_expires_on = content.get("expires_on") or int(content["expires_in"]) + request_time
    content["expires_on"] = int(raw_expires_on)
    token = AccessToken(content["access_token"], content["expires_on"])

    # caching must come last because TokenCache.add mutates its "event"
    self._cache.add(event={"response": content, "scope": content["resource"]}, now=request_time)
    return token
def _acquire_token_silent(self, *scopes, **kwargs):
    # type: (*str, **Any) -> AccessToken
    """Silently acquire a token from MSAL. Requires an AuthenticationRecord.

    Only cached accounts whose ``home_account_id`` matches the auth record are tried. The optional ``claims``
    keyword argument is forwarded to MSAL as ``claims_challenge``.

    :raises CredentialUnavailableError: the credential isn't initialized, or the cache holds no matching
        account or no usable token for it
    :raises ~azure.core.exceptions.ClientAuthenticationError: a matching account was found but the token
        service returned an error when MSAL tried to use its refresh token
    """
    # get_token only calls this once initialization succeeded, but guard anyway (this also satisfies mypy)
    if self._app is None or self._auth_record is None:
        raise CredentialUnavailableError("Initialization failed")

    candidates = self._app.get_accounts(username=self._auth_record.username)
    if not candidates:
        raise CredentialUnavailableError("The cache contains no account matching the given AuthenticationRecord.")

    result = None
    matching_accounts = (
        candidate
        for candidate in candidates
        if candidate.get("home_account_id") == self._auth_record.home_account_id
    )
    for account in matching_accounts:
        now = int(time.time())
        result = self._app.acquire_token_silent_with_error(
            list(scopes), account=account, claims_challenge=kwargs.get("claims")
        )
        if result and "access_token" in result and "expires_in" in result:
            return AccessToken(result["access_token"], now + int(result["expires_in"]))

    if not result:
        # cache doesn't contain a matching refresh (or access) token
        raise CredentialUnavailableError(message=NO_TOKEN.format(self._auth_record.username))

    # the cache held a matching refresh token, but STS returned an error response when MSAL tried to use it
    message = "Token acquisition failed"
    details = result.get("error_description") or result.get("error")
    if details:
        message += ": {}".format(details)
    raise ClientAuthenticationError(message=message)
def test_imds_user_assigned_identity():
    """IMDS with a user-assigned identity: the token request should carry the client_id"""
    access_token = "****"
    expires_on = 42
    url = Endpoints.IMDS
    scope = "scope"
    client_id = "some-guid"

    # the first request is the availability probe, so match only its URL
    probe_request = Request(url)
    token_request = Request(
        url,
        method="GET",
        required_headers={"Metadata": "true"},
        required_params={"api-version": "2018-02-01", "client_id": client_id, "resource": scope},
    )

    # the probe receives an error response
    probe_response = mock_response(status_code=400, json_payload={"error": "this is an error message"})
    token_response = mock_response(
        json_payload={
            "access_token": access_token,
            "client_id": client_id,
            "expires_in": 42,
            "expires_on": expires_on,
            "ext_expires_in": 42,
            "not_before": int(time.time()),
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = validating_transport(
        requests=[probe_request, token_request], responses=[probe_response, token_response]
    )

    credential = ManagedIdentityCredential(client_id=client_id, transport=transport)
    assert credential.get_token(scope) == AccessToken(access_token, expires_on)
def test_bearer_policy_adds_header():
    """The bearer token policy should add a header containing a token from its credential"""
    # 2524608000 == 01/01/2050 @ 12:00am (UTC)
    expected_token = AccessToken("expected_token", 2524608000)
    expected_header = "Bearer {}".format(expected_token.token)

    def verify_authorization_header(request):
        assert request.http_request.headers["Authorization"] == expected_header

    fake_credential = Mock(get_token=Mock(return_value=expected_token))
    pipeline = Pipeline(
        transport=Mock(),
        policies=[BearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)],
    )

    for _ in range(2):
        pipeline.run(HttpRequest("GET", "https://spam.eggs"))
        # the token is still valid on the second run, so no new token should be requested
        assert fake_credential.get_token.call_count == 1
def _acquire_token_silent(self, *scopes, **kwargs):
    # type: (*str, **Any) -> AccessToken
    """Try to get a token from MSAL's cache for the account described by ``self._auth_record``.

    Extra keyword arguments are forwarded to MSAL's ``acquire_token_silent_with_error``.

    :raises AuthenticationRequiredError: no auth record, no matching cached account/token, or MSAL returned
        an error; in the last case the error description (or code) is attached as ``error_details``
    """
    result = None
    if self._auth_record:
        app = self._get_app()
        for account in app.get_accounts(username=self._auth_record.username):
            if account.get("home_account_id") == self._auth_record.home_account_id:
                now = int(time.time())
                result = app.acquire_token_silent_with_error(list(scopes), account=account, **kwargs)
                if result and "access_token" in result and "expires_in" in result:
                    return AccessToken(result["access_token"], now + int(result["expires_in"]))

    # here result is either None or the content of an AAD error response
    if not result:
        raise AuthenticationRequiredError(scopes)
    details = result.get("error_description") or result.get("error")
    raise AuthenticationRequiredError(scopes, error_details=details)
def test_imds():
    """IMDS: with no managed identity environment variables set, the credential should use IMDS"""
    access_token = "****"
    expires_on = 42
    scope = "scope"

    # the first request is the availability probe, so match only its URL
    probe_request = Request(url=Endpoints.IMDS)
    token_request = Request(
        base_url=Endpoints.IMDS,
        method="GET",
        required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
        required_params={"api-version": "2018-02-01", "resource": scope},
    )

    # the probe receives an error response
    probe_response = mock_response(status_code=400, json_payload={"error": "this is an error message"})
    token_response = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_in": 42,
            "expires_on": expires_on,
            "ext_expires_in": 42,
            "not_before": int(time.time()),
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = validating_transport(
        requests=[probe_request, token_request], responses=[probe_response, token_response]
    )

    # ensure e.g. $MSI_ENDPOINT isn't set, so we get ImdsCredential
    with mock.patch.dict("os.environ", clear=True):
        token = ManagedIdentityCredential(transport=transport).get_token(scope)

    assert token == AccessToken(access_token, expires_on)
async def test_imds_tenant_id():
    """Passing tenant_id to get_token should still yield the IMDS token"""
    access_token = "****"
    expires_on = 42
    scope = "scope"
    token_url = IMDS_AUTHORITY + IMDS_TOKEN_PATH

    # the first request is the availability probe, so match only its URL
    probe_request = Request(base_url=token_url)
    token_request = Request(
        base_url=token_url,
        method="GET",
        required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
        required_params={"api-version": "2018-02-01", "resource": scope},
    )

    # the probe receives an error response
    probe_response = mock_response(status_code=400, json_payload={"error": "this is an error message"})
    token_response = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_in": 42,
            "expires_on": expires_on,
            "ext_expires_in": 42,
            "not_before": int(time.time()),
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(
        requests=[probe_request, token_request], responses=[probe_response, token_response]
    )

    # ensure e.g. $MSI_ENDPOINT isn't set, so we get ImdsCredential
    with mock.patch.dict("os.environ", clear=True):
        credential = ManagedIdentityCredential(transport=transport)
        token = await credential.get_token(scope, tenant_id="tenant_id")

    assert token == AccessToken(access_token, expires_on)
def test_app_service_user_assigned_identity():
    """App Service 2017-09-01: MSI_ENDPOINT, MSI_SECRET set"""
    access_token = "****"
    expires_on = 42
    client_id = "some-guid"
    endpoint = "http://localhost:42/token"
    secret = "expected-secret"
    scope = "scope"

    token_request = Request(
        base_url=endpoint,
        method="GET",
        required_headers={"secret": secret, "User-Agent": USER_AGENT},
        required_params={"api-version": "2017-09-01", "clientid": client_id, "resource": scope},
    )
    # this response reports expires_on as a date string rather than epoch seconds
    token_response = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_on": "01/01/1970 00:00:{} +00:00".format(expires_on),
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = validating_transport(requests=[token_request], responses=[token_response])

    environment = {EnvironmentVariables.MSI_ENDPOINT: endpoint, EnvironmentVariables.MSI_SECRET: secret}
    with mock.patch("os.environ", environment):
        credential = ManagedIdentityCredential(client_id=client_id, transport=transport)
        token = credential.get_token(scope)

    assert token == AccessToken(access_token, expires_on)
def test_bearer_policy_calls_sansio_methods():
    """BearerTokenCredentialPolicy should call SansIOHttpPolicy methods as does _SansIOHTTPPolicyRunner"""

    class TestPolicy(BearerTokenCredentialPolicy):
        """Records the request/response it sees and replaces the SansIO hooks with mocks."""

        def __init__(self, *args, **kwargs):
            super(TestPolicy, self).__init__(*args, **kwargs)
            self.on_exception = Mock(return_value=False)
            self.on_request = Mock()
            self.on_response = Mock()

        def send(self, request):
            self.request = request
            self.response = super(TestPolicy, self).send(request)
            return self.response

    credential = Mock(get_token=Mock(return_value=AccessToken("***", int(time.time()) + 3600)))

    # successful request: on_request and on_response should each be called once
    policy = TestPolicy(credential, "scope")
    ok_transport = Mock(send=Mock(return_value=Mock(status_code=200)))
    Pipeline(transport=ok_transport, policies=[policy]).run(HttpRequest("GET", "https://localhost"))
    policy.on_request.assert_called_once_with(policy.request)
    policy.on_response.assert_called_once_with(policy.request, policy.response)

    # the policy should call on_exception when next.send() raises
    class TestException(Exception):
        pass

    failing_transport = Mock(send=Mock(side_effect=TestException))
    policy = TestPolicy(credential, "scope")
    failing_pipeline = Pipeline(transport=failing_transport, policies=[policy])
    with pytest.raises(TestException):
        failing_pipeline.run(HttpRequest("GET", "https://localhost"))
    policy.on_exception.assert_called_once_with(policy.request)
def test_bearer_policy_optionally_enforces_https():
    """HTTPS enforcement should be controlled by a keyword argument, and enabled by default"""

    def assert_option_popped(request, **kwargs):
        assert "enforce_https" not in kwargs, "BearerTokenCredentialPolicy didn't pop the 'enforce_https' option"

    credential = Mock(get_token=lambda *_, **__: AccessToken("***", 42))
    pipeline = Pipeline(
        transport=Mock(send=assert_option_popped), policies=[BearerTokenCredentialPolicy(credential, "scope")]
    )

    # by default and when enforce_https=True, the policy should raise when given an insecure request
    for options in ({}, {"enforce_https": True}):
        with pytest.raises(ServiceRequestError):
            pipeline.run(HttpRequest("GET", "http://not.secure"), **options)

    # when enforce_https=False, an insecure request should pass
    pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False)

    # https requests should always pass
    for options in ({"enforce_https": False}, {"enforce_https": True}, {}):
        pipeline.run(HttpRequest("GET", "https://secure"), **options)
def test_identity_config():
    """identity_config entries should be sent as query parameters of the IMDS token request"""
    param_name, param_value = "foo", "bar"
    access_token = "****"
    expires_on = 42
    scope = "scope"

    # the first request is the availability probe, so match only its URL
    probe_request = Request(base_url=IMDS_URL)
    token_request = Request(
        base_url=IMDS_URL,
        method="GET",
        required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
        required_params={"api-version": "2018-02-01", "resource": scope, param_name: param_value},
    )

    probe_response = mock_response(status_code=400, json_payload={"error": "this is an error message"})
    token_response = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_in": 42,
            "expires_on": expires_on,
            "ext_expires_in": 42,
            "not_before": int(time.time()),
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = validating_transport(
        requests=[probe_request, token_request], responses=[probe_response, token_response]
    )

    credential = ImdsCredential(identity_config={param_name: param_value}, transport=transport)
    assert credential.get_token(scope) == AccessToken(access_token, expires_on)
def __init__(self, arguments):
    """Build sync and async pipelines whose only real policy is a bearer-token policy.

    Credentials and transports are mocks, so running the pipelines exercises the policy without network I/O.
    """
    super().__init__(arguments)
    token = AccessToken("**", int(time.time() + 3600))
    self.request = HttpRequest("GET", "https://localhost")

    credential = Mock(get_token=Mock(return_value=token))
    self.pipeline = Pipeline(transport=Mock(), policies=[BearerTokenCredentialPolicy(credential=credential)])

    # The async fakes return fresh coroutines instead of pre-resolved asyncio.Future objects. Creating a
    # Future here, outside any running event loop, relies on the implicit asyncio.get_event_loop() lookup,
    # which is deprecated and raises RuntimeError on Python 3.14+ when no loop is set.
    async def get_token(*_, **__):
        return token

    response = Mock()

    async def send(*_, **__):
        return response

    # wrapping in Mock(side_effect=...) keeps call tracking while each call returns an awaitable
    async_credential = Mock(get_token=Mock(side_effect=get_token))
    async_transport = Mock(send=Mock(side_effect=send))
    self.async_pipeline = AsyncPipeline(
        async_transport, policies=[AsyncBearerTokenCredentialPolicy(credential=async_credential)]
    )
async def test_bearer_policy_adds_header():
    """The bearer token policy should add a header containing a token from its credential"""
    expected_token = AccessToken("expected_token", 0)

    async def verify_authorization_header(request):
        assert request.http_request.headers["Authorization"] == "Bearer {}".format(expected_token.token)

    get_token_calls = 0

    # Accept arbitrary arguments: a fake taking exactly one positional scope breaks as soon as the policy
    # passes several scopes or any keyword option to get_token.
    async def get_token(*_, **__):
        nonlocal get_token_calls
        get_token_calls += 1
        return expected_token

    fake_credential = Mock(get_token=get_token)
    policies = [AsyncBearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)]
    pipeline = AsyncPipeline(transport=Mock(), policies=policies)

    await pipeline.run(HttpRequest("GET", "https://spam.eggs"), context=None)
    assert get_token_calls == 1
def test_caching_when_only_expires_in_set():
    """the cache should function when auth responses don't include an explicit expires_on"""
    access_token = "token"
    now = 42
    expires_in = 1800
    expected_token = AccessToken(access_token, now + expires_in)

    mock_send = Mock(
        return_value=mock_response(
            json_payload={"access_token": access_token, "expires_in": expires_in, "token_type": "Bearer"}
        )
    )
    client = AuthnClient(endpoint="http://foo", transport=Mock(send=mock_send))

    with patch("azure.identity._authn_client.time.time") as mock_time:
        # pin the clock to `now` (not a duplicated literal) so the expected expiry stays in sync with it
        mock_time.return_value = now
        token = client.request_token(["scope"])
        assert token.token == expected_token.token
        assert token.expires_on == expected_token.expires_on

        cached_token = client.get_cached_token(["scope"])
        assert cached_token == expected_token
def get_token(self, *scopes, **kwargs):  # pylint:disable=unused-argument
    # type: (*str, **Any) -> AccessToken
    """Request an access token for `scopes`.

    .. note:: This method is called by Azure SDK clients. It isn't intended for use in application code.

    A cached token for one of the user's accounts is returned when MSAL has one; otherwise a new token is
    requested with the configured username and password.

    :param str scopes: desired scopes for the token
    :rtype: :class:`azure.core.credentials.AccessToken`
    :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
      attribute gives a reason.
    """
    # MSAL requires scopes be a list
    scopes = list(scopes)  # type: ignore
    now = int(time.time())
    app = self._get_app()

    result = None
    for account in app.get_accounts(username=self._username):
        result = app.acquire_token_silent(scopes, account=account)
        if result:
            break

    if not result:
        # cache miss -> request a new token
        with self._adapter:
            result = app.acquire_token_by_username_password(
                username=self._username, password=self._password, scopes=scopes
            )

    if "access_token" not in result:
        # AAD error responses don't always include a description; fall back to the error code rather
        # than reporting "authentication failed: None"
        details = result.get("error_description") or result.get("error")
        raise ClientAuthenticationError(message="authentication failed: {}".format(details))

    return AccessToken(result["access_token"], now + int(result["expires_in"]))
def __init__(self):
    # fixed placeholder token; expires_on=0 (epoch seconds) means it is already expired
    self.token = AccessToken(token="YOU SHALL NOT PASS", expires_on=0)
def get_token(self, *scopes, **kwargs):
    """Fetch an access token from the ``/oauth2/v2.0/token`` endpoint of ``AUTHSERVER``.

    Scopes and keyword arguments are ignored. expires_on is hard-coded to 1 (epoch seconds), so the token
    presumably is meant to look already expired — confirm with the test's intent.

    :raises requests.HTTPError: the auth server answered with an error status
    """
    response = requests.post(AUTHSERVER + "/oauth2/v2.0/token")
    # surface HTTP failures directly instead of an opaque KeyError / JSON decode error from the body below
    response.raise_for_status()
    return AccessToken(response.json()["access_token"], 1)
def get_token(self, *scopes, **kwargs):  # pylint:disable=unused-argument
    # type: (*str, **Any) -> AccessToken
    """Return an AccessToken built from this object's stored token and expiry; arguments are ignored.

    This method is automatically called when token is about to expire.
    """
    return AccessToken(token=self.token, expires_on=self.expiry)
def setUpClass(cls, credential):
    # datetime.now() returns naive *local* time; .replace(tzinfo=TZ_UTC) would label that local wall-clock
    # time as UTC without converting it. Ask for the current time in UTC directly instead.
    now_utc = datetime.now(TZ_UTC)
    # NOTE(review): AccessToken.expires_on is conventionally int epoch seconds, but a datetime is passed
    # here — confirm every consumer of this mocked token accepts a datetime.
    credential.get_token = Mock(return_value=AccessToken("some_token", now_utc))
    TestChatThreadClient.credential = credential
def get_token(self, *scopes, **kwargs):
    """Return an AccessToken built from the stored token and expiry; arguments are ignored.

    This method is automatically called when token is about to expire.
    """
    return AccessToken(token=self.token, expires_on=self.expiry)
def create_credential(self):
    """Return an EnvironmentCredential in live mode, otherwise a fake credential.

    The fake's ``get_token`` wraps an AccessToken("fake-token", 0) with ``get_completed_future``.
    """
    if self.is_live:
        return EnvironmentCredential()
    # accept keyword arguments as well, so the fake doesn't break when callers pass options to get_token
    return Mock(get_token=lambda *_, **__: get_completed_future(AccessToken("fake-token", 0)))
def create_vault_client(self, vault_uri):
    """Create a VaultClient for `vault_uri`, using a real credential only in live mode."""
    if self.is_live:
        credential = EnvironmentCredential()
    else:
        # accept any scopes/options: a fake taking exactly one positional argument fails when get_token is
        # called with several scopes or with keyword options
        credential = Mock(get_token=lambda *_, **__: AccessToken("fake-token", 0))
    return VaultClient(vault_uri, credential)
def mock_azure_identity_TokenCredential(self, mocker):
    """Return a MagicMock credential whose get_token returns AccessToken(fake_token, fake_token_expiry)."""
    credential = mocker.MagicMock()
    credential.get_token.return_value = AccessToken(fake_token, fake_token_expiry)
    return credential
def setUpClass(cls, credential):
    # NOTE(review): datetime.now() is naive local time and .replace(tzinfo=TZ_UTC) labels it as UTC without
    # converting it. Whether the resulting int is the true current epoch depends on how
    # _convert_datetime_to_utc_int interprets its input — confirm.
    stamped_now = datetime.now().replace(tzinfo=TZ_UTC)
    expires_on = _convert_datetime_to_utc_int(stamped_now)
    credential.get_token = Mock(return_value=AccessToken("some_token", expires_on))
    TestChatClient.credential = credential
def get_token(self, *scopes, **kwargs):
    """Return a fake token with expires_on=0; all arguments are ignored."""
    from azure.core.credentials import AccessToken

    return AccessToken(token="fake-token", expires_on=0)
def __init__(self):
    # placeholder token; expires_on=0 (epoch seconds), i.e. already expired
    self.token = AccessToken(token="Fake Token", expires_on=0)