def _create_appconfig_pipeline(self, credential, base_url=None, aad_mode=False, **kwargs):
    """Build the HTTP pipeline for this client.

    :param credential: Credential wrapped by the authentication policy.
    :param base_url: Service endpoint. Only read in AAD mode, where the token
        scope is ``base_url`` with surrounding slashes stripped plus ``/.default``.
    :param bool aad_mode: When true, authenticate with
        ``BearerTokenCredentialPolicy``; otherwise with
        ``AppConfigRequestsCredentialsPolicy``.
    :keyword transport: Optional transport. A ``RequestsTransport`` built from
        ``kwargs`` is used when this is missing or falsy.
    :keyword policies: Optional policy list, used as given when it is not ``None``.
    :raises TypeError: In AAD mode, when ``credential`` has no ``get_token`` attribute.
    :return: A ``Pipeline`` over the selected transport and policies.
    """
    pipeline_transport = kwargs.get("transport")
    pipeline_policies = kwargs.get("policies")

    # Only None selects the defaults below: [] is a valid policy list.
    if pipeline_policies is None:
        if not aad_mode:
            auth_policy = AppConfigRequestsCredentialsPolicy(credential)
        else:
            token_scope = base_url.strip("/") + "/.default"
            if not hasattr(credential, "get_token"):
                raise TypeError(
                    "Please provide an instance from azure-identity "
                    "or a class that implement the 'get_token protocol"
                )
            auth_policy = BearerTokenCredentialPolicy(credential, token_scope)

        pipeline_policies = [
            self._config.headers_policy,
            self._config.user_agent_policy,
            self._config.retry_policy,
            self._sync_token_policy,
            auth_policy,
            self._config.logging_policy,  # HTTP request/response log
            DistributedTracingPolicy(**kwargs),
            HttpLoggingPolicy(**kwargs),
            ContentDecodePolicy(**kwargs),
        ]

    if not pipeline_transport:
        pipeline_transport = RequestsTransport(**kwargs)

    return Pipeline(pipeline_transport, pipeline_policies)
def _build_pipeline(self, **kwargs):  # pylint: disable=no-self-use
    """Build the async pipeline for this client.

    :keyword transport: Optional transport. An ``AioHttpTransport`` built from
        ``kwargs`` is used when this is missing or falsy.
    :keyword policies: Optional policy list, used as given when it is not ``None``.
    :return: An ``AsyncPipeline`` over the selected transport and policies.
    """
    pipeline_transport = kwargs.get('transport')
    pipeline_policies = kwargs.get('policies')

    # The credential policy is constructed up front, even if the caller supplied policies.
    if isinstance(self._credential, ServiceBusSharedKeyCredential):
        auth_policy = AsyncServiceBusSharedKeyCredentialPolicy(
            self._endpoint, self._credential, "Authorization"
        )
    else:
        auth_policy = AsyncBearerTokenCredentialPolicy(self._credential, JWT_TOKEN_SCOPE)

    # Only None selects the defaults: [] is a valid policy list.
    if pipeline_policies is None:
        pipeline_policies = [
            RequestIdPolicy(**kwargs),
            self._config.headers_policy,
            self._config.user_agent_policy,
            self._config.proxy_policy,
            ContentDecodePolicy(**kwargs),
            ServiceBusXMLWorkaroundPolicy(),
            self._config.redirect_policy,
            self._config.retry_policy,
            auth_policy,
            self._config.logging_policy,
            DistributedTracingPolicy(**kwargs),
            HttpLoggingPolicy(**kwargs),
        ]

    if not pipeline_transport:
        pipeline_transport = AioHttpTransport(**kwargs)

    return AsyncPipeline(pipeline_transport, pipeline_policies)
def _create_pipeline(self, credential, **kwargs):
    # type: (Any, **Any) -> Tuple[Configuration, Pipeline]
    """Select a credential policy and build the request pipeline.

    :param credential: An object with ``get_token``, a
        ``SharedKeyCredentialPolicy`` instance, or ``None``.
    :keyword _configuration: Optional configuration; otherwise one is made by
        ``create_configuration(**kwargs)``.
    :keyword _pipeline: Optional ready-made pipeline, returned as-is.
    :keyword transport: Optional transport; ``RequestsTransport`` is the fallback.
    :keyword _additional_pipeline_policies: Optional policies appended to the
        default list.
    :raises TypeError: For any other non-``None`` credential.
    :return: The configuration and the pipeline.
    """
    # Reset first so the attribute is None even if credential validation raises.
    self._credential_policy = None
    if hasattr(credential, "get_token"):
        self._credential_policy = BearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE)
    elif isinstance(credential, SharedKeyCredentialPolicy):
        self._credential_policy = credential
    elif credential is not None:
        raise TypeError("Unsupported credential: {}".format(credential))

    config = kwargs.get("_configuration") or create_configuration(**kwargs)

    existing_pipeline = kwargs.get("_pipeline")
    if existing_pipeline:
        return config, existing_pipeline

    config.transport = kwargs.get("transport")  # type: ignore
    # Timeout defaults are applied before the fallback transport is constructed.
    kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT)
    kwargs.setdefault("read_timeout", READ_TIMEOUT)
    if not config.transport:
        config.transport = RequestsTransport(**kwargs)

    policies = [
        QueueMessagePolicy(),
        config.headers_policy,
        config.proxy_policy,
        config.user_agent_policy,
        StorageContentValidation(),
        StorageRequestHook(**kwargs),
        self._credential_policy,
        ContentDecodePolicy(response_encoding="utf-8"),
        RedirectPolicy(**kwargs),
        StorageHosts(hosts=self._hosts, **kwargs),
        config.retry_policy,
        config.logging_policy,
        StorageResponseHook(**kwargs),
        DistributedTracingPolicy(**kwargs),
        HttpLoggingPolicy(**kwargs),
    ]

    extra_policies = kwargs.get("_additional_pipeline_policies")
    if extra_policies:
        policies = policies + extra_policies

    return config, Pipeline(config.transport, policies=policies)
def __init__(self, endpoint, client_cls, config=None, client_id=None, **kwargs):
    # type: (str, Type, Optional[Configuration], Optional[str], **Any) -> None
    """Record identity settings and create the underlying client.

    :param endpoint: Passed to ``client_cls`` as ``endpoint``.
    :param client_cls: Callable that builds the client stored on ``self._client``.
    :param config: Optional configuration; otherwise ``self._create_config(**kwargs)``.
    :param client_id: Optional client ID added to the identity config unless the
        relevant key is already present.
    :keyword identity_config: Optional mapping, removed from ``kwargs``.
    """
    self._identity_config = kwargs.pop("identity_config", None) or {}

    if client_id:
        on_app_service = os.environ.get(EnvironmentVariables.MSI_ENDPOINT) and os.environ.get(
            EnvironmentVariables.MSI_SECRET
        )
        # App Service: version 2017-09-1 accepts client ID as parameter "clientid"
        id_key = "clientid" if on_app_service else "client_id"
        if id_key not in self._identity_config:
            self._identity_config[id_key] = client_id

    config = config or self._create_config(**kwargs)
    default_policies = [
        ContentDecodePolicy(),
        config.headers_policy,
        config.user_agent_policy,
        config.retry_policy,
        config.logging_policy,
        DistributedTracingPolicy(**kwargs),
        HttpLoggingPolicy(**kwargs),
    ]
    self._client = client_cls(endpoint=endpoint, config=config, policies=default_policies, **kwargs)
def create_pipeline(credential, **kwargs):
    # type: (Any, **Any) -> Tuple[Configuration, Pipeline]
    """Select a credential policy and build the request pipeline.

    :param credential: An object with ``get_token``, a
        ``SharedKeyCredentialPolicy`` instance, or ``None``.
    :keyword _configuration: Optional configuration; otherwise one is made by
        ``create_configuration(**kwargs)``.
    :keyword _pipeline: Optional ready-made pipeline, returned as-is.
    :keyword transport: Optional transport; ``RequestsTransport`` is the fallback.
    :raises TypeError: For any other non-``None`` credential.
    :return: The configuration and the pipeline.
    """
    auth_policy = None
    if hasattr(credential, 'get_token'):
        auth_policy = BearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE)
    elif isinstance(credential, SharedKeyCredentialPolicy):
        auth_policy = credential
    elif credential is not None:
        raise TypeError("Unsupported credential: {}".format(credential))

    config = kwargs.get('_configuration') or create_configuration(**kwargs)

    existing_pipeline = kwargs.get('_pipeline')
    if existing_pipeline:
        return config, existing_pipeline

    # The default timeout is in place before the fallback transport is built.
    kwargs.setdefault('connection_timeout', DEFAULT_SOCKET_TIMEOUT)
    transport = kwargs.get('transport')  # type: HttpTransport
    if not transport:
        transport = RequestsTransport(**kwargs)

    policies = [
        QueueMessagePolicy(),
        config.headers_policy,
        config.user_agent_policy,
        StorageContentValidation(),
        StorageRequestHook(**kwargs),
        auth_policy,
        ContentDecodePolicy(),
        config.redirect_policy,
        StorageHosts(**kwargs),
        config.retry_policy,
        config.logging_policy,
        StorageResponseHook(**kwargs),
    ]
    return config, Pipeline(transport, policies=policies)
def process_batch_error(error):
    """Re-raise an HTTP error with the message and code from its response body.

    :param error: Error exposing ``status_code``, ``message`` and ``response``.
    :raises ClientAuthenticationError: When ``error.status_code`` is 401.
    :raises HttpResponseError: Otherwise, and also (with a generic message)
        when the decoded body lacks an expected key.
    """
    if error.status_code == 401:
        error_cls = ClientAuthenticationError
    else:
        error_cls = HttpResponseError

    message = error.message
    code = error.status_code

    # An undecodable body is tolerated: the error's own message and status are kept.
    try:
        body = ContentDecodePolicy.deserialize_from_http_generics(error.response)
    except DecodeError:
        body = None

    try:
        if body is not None:
            details = body["error"]
            if "innerError" in details:
                details = details["innerError"]
            message = details["message"]
            code = details["code"]
            message += "\nErrorCode:{}".format(code)
    except KeyError:
        raise HttpResponseError(message="There was an unknown error with the request.")

    detailed_error = error_cls(message=message, response=error.response)
    detailed_error.error_code = code
    raise detailed_error
def _create_pipeline(self, credential, **kwargs):
    # type: (Any, **Any) -> Tuple[Configuration, Pipeline]
    """Select a credential policy and build the request pipeline.

    :param credential: An object with ``get_token``, a
        ``SharedKeyCredentialPolicy`` instance, or ``None``.
    :keyword _configuration: Optional configuration; otherwise one is made by
        ``create_configuration(**kwargs)``.
    :keyword _pipeline: Optional ready-made pipeline, returned as-is.
    :keyword transport: Optional transport; ``RequestsTransport`` is the fallback.
    :raises TypeError: For any other non-``None`` credential.
    :return: The configuration and the pipeline.
    """
    # Reset first so the attribute is None even if credential validation raises.
    self._credential_policy = None
    if hasattr(credential, "get_token"):
        self._credential_policy = BearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE)
    elif isinstance(credential, SharedKeyCredentialPolicy):
        self._credential_policy = credential
    elif credential is not None:
        raise TypeError("Unsupported credential: {}".format(credential))

    config = kwargs.get("_configuration") or create_configuration(**kwargs)

    existing_pipeline = kwargs.get("_pipeline")
    if existing_pipeline:
        return config, existing_pipeline

    config.transport = kwargs.get("transport")  # type: ignore
    # The default timeout is in place before the fallback transport is built.
    kwargs.setdefault("connection_timeout", DEFAULT_SOCKET_TIMEOUT)
    if not config.transport:
        config.transport = RequestsTransport(**kwargs)

    policies = [
        QueueMessagePolicy(),
        config.headers_policy,
        config.user_agent_policy,
        StorageContentValidation(),
        StorageRequestHook(**kwargs),
        self._credential_policy,
        ContentDecodePolicy(),
        RedirectPolicy(**kwargs),
        StorageHosts(hosts=self._hosts, **kwargs),
        config.retry_policy,
        config.logging_policy,
        StorageResponseHook(**kwargs),
        DistributedTracingPolicy(),
    ]
    return config, Pipeline(config.transport, policies=policies)
def _process_response(self, response, request_time):
    # type: (PipelineResponse, int) -> AccessToken
    """Turn a token response into an ``AccessToken`` and cache it.

    :param response: Pipeline response whose ``http_response`` holds the payload.
    :param request_time: Added to ``expires_in`` when the payload has no
        ``expires_on``; also passed to the cache as ``now``.
    :raises ClientAuthenticationError: When the payload is empty, or lacks
        ``access_token`` or both expiry fields.
    :return: The access token.
    """
    # ContentDecodePolicy sets this, and should have raised if it couldn't deserialize the response
    content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)  # type: dict

    if not content:
        raise ClientAuthenticationError(message="No token received.", response=response.http_response)

    has_token = "access_token" in content
    has_expiry = "expires_in" in content or "expires_on" in content
    if not (has_token and has_expiry):
        if has_token:
            # Mask the token so it does not appear in the exception message.
            content["access_token"] = "****"
        raise ClientAuthenticationError(
            message='Unexpected response "{}"'.format(content), response=response.http_response
        )

    if self._content_callback:
        self._content_callback(content)

    # A falsy "expires_on" falls back to expires_in + request_time.
    expiry = content.get("expires_on")
    if not expiry:
        expiry = int(content["expires_in"]) + request_time
    content["expires_on"] = int(expiry)

    token = AccessToken(content["access_token"], content["expires_on"])

    # caching is the final step because TokenCache.add mutates its "event"
    cache_event = {"response": content, "scope": [content["resource"]]}
    self._cache.add(event=cache_event, now=request_time)

    return token
def _process_response(self, response, request_time):
    # type: (PipelineResponse, int) -> AccessToken
    """Turn a token response into an ``AccessToken``, keeping the cache current.

    For refresh-token requests, an ``invalid_grant`` error evicts the request's
    refresh token from the cache, and a new refresh token in the response
    replaces the cached one when exactly one cache entry matches.

    :param response: Pipeline response for the token request.
    :param request_time: Added to ``expires_in`` when the payload has no
        ``expires_on``; also passed to the cache as ``now``.
    :raises ClientAuthenticationError: When the payload has neither
        ``expires_on`` nor ``expires_in``.
    :return: The access token.
    """
    content = response.context.get(
        ContentDecodePolicy.CONTEXT_NAME
    ) or ContentDecodePolicy.deserialize_from_http_generics(response.http_response)

    request_body = response.http_request.body

    def find_entries_for_request_token():
        # Cache entries whose secret is the refresh token sent in this request.
        return self._cache.find(
            TokenCache.CredentialType.REFRESH_TOKEN,
            query={"secret": request_body["refresh_token"]},
        )

    if request_body.get("grant_type") == "refresh_token":
        if content.get("error") == "invalid_grant":
            # the request's refresh token is invalid -> evict it from the cache
            for stale_entry in find_entries_for_request_token():
                self._cache.remove_rt(stale_entry)
        if "refresh_token" in content:
            # AAD returned a new refresh token -> update the cache entry
            matches = find_entries_for_request_token()
            # If the old token is in multiple cache entries, the cache is in a state we don't
            # expect or know how to reason about, so we update nothing.
            if len(matches) == 1:
                self._cache.update_rt(matches[0], content["refresh_token"])
                del content["refresh_token"]  # prevent caching a redundant entry

    _raise_for_error(response, content)

    if "expires_on" in content:
        expires_on = int(content["expires_on"])
    elif "expires_in" in content:
        expires_on = request_time + int(content["expires_in"])
    else:
        _scrub_secrets(content)
        raise ClientAuthenticationError(
            message="Unexpected response from Azure Active Directory: {}".format(content)
        )

    token = AccessToken(content["access_token"], expires_on)

    # caching is the final step because 'add' mutates 'content'
    cache_event = {
        "client_id": self._client_id,
        "response": content,
        "scope": request_body["scope"].split(),
        "token_endpoint": response.http_request.url,
    }
    self._cache.add(event=cache_event, now=request_time)

    return token
def _configure_policies(self, **kwargs):
    # type: (**Any) -> None
    """Populate ``self._policies`` with the default async policy chain.

    :keyword transport: Optional transport; when missing or falsy an
        ``AioHttpTransport`` is constructed from ``kwargs``.
    :raises ImportError: When the transport import or construction raises ``ImportError``.
    """
    try:
        from azure.core.pipeline.transport import AioHttpTransport

        if not kwargs.get("transport"):
            # NOTE(review): setdefault only inserts when the key is absent, so an
            # explicit falsy "transport" (e.g. None) is left in place.
            kwargs.setdefault("transport", AioHttpTransport(**kwargs))
    except ImportError:
        raise ImportError(
            "Unable to create async transport. Please check aiohttp is installed."
        )

    # These defaults are applied after the transport above has been constructed.
    kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT)
    kwargs.setdefault("read_timeout", READ_TIMEOUT)

    policy_chain = [
        StorageHeadersPolicy(**kwargs),
        ProxyPolicy(**kwargs),
        UserAgentPolicy(sdk_moniker=SDK_MONIKER, **kwargs),
        StorageContentValidation(),
        StorageRequestHook(**kwargs),
        self._credential_policy,
        ContentDecodePolicy(response_encoding="utf-8"),
        AsyncRedirectPolicy(**kwargs),
        StorageHosts(hosts=self._hosts, **kwargs),
        AsyncTablesRetryPolicy(**kwargs),
        StorageLoggingPolicy(**kwargs),
        AsyncStorageResponseHook(**kwargs),
        DistributedTracingPolicy(**kwargs),
        HttpLoggingPolicy(**kwargs),
    ]
    self._policies = policy_chain
async def record_wrap(*args, **kwargs):
    """Run the wrapped test with ``AioHttpTransport.send`` patched for recording.

    In live-and-not-recording mode the test is awaited directly. Otherwise a
    recording or playback session is started, ``AioHttpTransport.send`` is
    replaced for the duration of the test, and both the original ``send`` and
    the session are cleaned up in a ``finally`` block.

    :raises ResourceNotFoundError: Re-raised with the message taken from the
        response body when the test raises it.
    :return: Whatever the test function returns.
    """

    def transform_args(*call_args, **call_kwargs):
        # The request is expected at positional index 1 of the send() call.
        positional = list(call_args)
        transform_request(positional[1], recording_id)
        return tuple(positional), call_kwargs

    trimmed_kwargs = dict(kwargs)
    trim_kwargs_from_test_function(test_func, trimmed_kwargs)

    if is_live_and_not_recording():
        return await test_func(*args, **trimmed_kwargs)

    test_id = get_test_id()
    recording_id, variables = start_record_or_playback(test_id)
    original_transport_func = AioHttpTransport.send

    async def combined_call(*call_args, **call_kwargs):
        adjusted_args, adjusted_kwargs = transform_args(*call_args, **call_kwargs)
        result = await original_transport_func(*adjusted_args, **adjusted_kwargs)

        # make the x-recording-upstream-base-uri the URL of the request
        # this makes the request look like it was made to the original endpoint instead of to the proxy
        # without this, things like LROPollers can get broken by polling the wrong endpoint
        request_url = url_parse.urlparse(result.request.url)
        upstream = url_parse.urlparse(result.request.headers["x-recording-upstream-base-uri"])
        result.request.url = request_url._replace(
            scheme=upstream.scheme, netloc=upstream.netloc
        ).geturl()
        return result

    AioHttpTransport.send = combined_call

    # call the modified function
    # we define test_output before invoking the test so the variable is defined in case of an exception
    test_output = None
    try:
        try:
            test_output = await test_func(*args, variables=variables, **trimmed_kwargs)
        except TypeError:
            # Retry without `variables` on any TypeError from the first call.
            logging.getLogger().info(
                "This test can't accept variables as input. The test method should accept `**kwargs` and/or a "
                "`variables` parameter to make use of recorded test variables."
            )
            test_output = await test_func(*args, **trimmed_kwargs)
    except ResourceNotFoundError as error:
        body = ContentDecodePolicy.deserialize_from_http_generics(error.response)
        detail = body.get("message") or body.get("Message")
        raise ResourceNotFoundError(message=detail, response=error.response) from error
    finally:
        AioHttpTransport.send = original_transport_func
        stop_record_or_playback(test_id, recording_id, test_output)

    return test_output
def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
    # type: (Sequence[str], str, **Any) -> AccessToken
    """Send a refresh-token request and process the decoded response.

    :param scopes: Scopes for the token request.
    :param refresh_token: Refresh token used to build the request.
    :return: The result of ``self._process_response``.
    """
    token_request = self._get_refresh_token_request(scopes, refresh_token)
    # Timestamp is taken before the request is sent.
    request_time = int(time.time())
    pipeline_response = self._pipeline.run(token_request, stream=False, **kwargs)
    payload = ContentDecodePolicy.deserialize_from_http_generics(pipeline_response.http_response)
    return self._process_response(response=payload, scopes=scopes, now=request_time)
def __init__(self, config=None, **kwargs):
    # type: (Optional[Configuration], Dict[str, Any]) -> None
    """Create the ``AuthnClient`` for ``Endpoints.IMDS``.

    :param config: Optional configuration; otherwise ``self.create_config(**kwargs)``.
    """
    config = config or self.create_config(**kwargs)
    imds_policies = [
        config.header_policy,
        ContentDecodePolicy(),
        config.logging_policy,
        config.retry_policy,
    ]
    self._client = AuthnClient(Endpoints.IMDS, config, imds_policies, **kwargs)
async def obtain_token_by_refresh_token(
    self, scopes: "Sequence[str]", refresh_token: str, **kwargs: "Any"
) -> "AccessToken":
    """Send a refresh-token request and process the decoded response.

    :param scopes: Scopes for the token request.
    :param refresh_token: Refresh token used to build the request.
    :return: The result of ``self._process_response``.
    """
    token_request = self._get_refresh_token_request(scopes, refresh_token)
    # Timestamp is taken before the request is sent.
    request_time = int(time.time())
    pipeline_response = await self._pipeline.run(token_request, **kwargs)
    payload = ContentDecodePolicy.deserialize_from_http_generics(pipeline_response.http_response)
    return self._process_response(response=payload, scopes=scopes, now=request_time)
async def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs):
    # type: (Sequence[str], AadClientCertificate, **Any) -> AccessToken
    """Send a client-certificate token request and process the decoded response.

    :param scopes: Scopes for the token request.
    :param certificate: Certificate used to build the request.
    :return: The result of ``self._process_response``.
    """
    token_request = self._get_client_certificate_request(scopes, certificate)
    # Timestamp is taken before the request is sent.
    request_time = int(time.time())
    pipeline_response = await self._pipeline.run(token_request, stream=False, **kwargs)
    payload = ContentDecodePolicy.deserialize_from_http_generics(pipeline_response.http_response)
    return self._process_response(response=payload, scopes=scopes, now=request_time)
def _create_pipeline(account, credential, **kwargs):
    # type: (Any, Any, **Any) -> Pipeline
    """Build a minimal shared-key pipeline.

    The type comment now matches the code: three parameters, and a single
    ``Pipeline`` return value (not a ``(Configuration, Pipeline)`` tuple).

    :param account: Object whose ``name`` attribute is the account name.
    :param credential: Account key passed to ``SharedKeyCredentialPolicy``.
    :keyword kwargs: Forwarded to ``RequestsTransport``.
    :return: A ``Pipeline`` with headers, credential and content-decode policies.
    """
    credential_policy = SharedKeyCredentialPolicy(account_name=account.name, account_key=credential)
    transport = RequestsTransport(**kwargs)
    policies = [
        HeadersPolicy(),
        credential_policy,
        ContentDecodePolicy(response_encoding="utf-8"),
    ]
    return Pipeline(transport, policies=policies)
def client(cookie_policy):
    """Yield an AutoRestHttpInfrastructureTestService client for the local test server.

    The client is used as a context manager, so it is closed once the
    generator is resumed after the yield.
    """
    pipeline_policies = [
        HeadersPolicy(),
        ContentDecodePolicy(),
        RedirectPolicy(),
        RetryPolicy(),
        cookie_policy,
    ]
    service = AutoRestHttpInfrastructureTestService(
        base_url="http://localhost:3000", policies=pipeline_policies
    )
    with service as service_client:
        yield service_client
def _build_pipeline(self, config=None, policies=None, transport=None, **kwargs):
    """Build a ``Pipeline`` from the given or default parts.

    :param config: Optional configuration; otherwise ``self._create_config(**kwargs)``.
    :param policies: Optional policies; any falsy value (including ``[]``)
        selects the default list.
    :param transport: Optional transport; ``RequestsTransport(**kwargs)`` is the fallback.
    :return: The pipeline.
    """
    config = config or self._create_config(**kwargs)

    if not policies:
        policies = [
            ContentDecodePolicy(),
            config.retry_policy,
            config.logging_policy,
            DistributedTracingPolicy(),
        ]

    if not transport:
        transport = RequestsTransport(**kwargs)

    return Pipeline(transport=transport, policies=policies)
def __init__(self, config=None, **kwargs):
    # type: (Optional[Configuration], Dict[str, Any]) -> None
    """Create the ``AuthnClient`` for the endpoint named by ``MSI_ENDPOINT``.

    :param config: Optional configuration; otherwise ``self.create_config(**kwargs)``.
    :raises ValueError: When the ``MSI_ENDPOINT`` environment variable is unset or empty.
    """
    config = config or self.create_config(**kwargs)
    msi_policies = [ContentDecodePolicy(), config.retry_policy, config.logging_policy]

    msi_endpoint = os.environ.get(MSI_ENDPOINT)
    if not msi_endpoint:
        raise ValueError("expected environment variable {} has no value".format(MSI_ENDPOINT))

    self._client = AuthnClient(msi_endpoint, config, msi_policies, **kwargs)
def _process_response(self, response, request_time):
    # type: (PipelineResponse, int) -> AccessToken
    """Turn a token response into an ``AccessToken`` and cache it.

    Uses the content stored in the response context under
    ``ContentDecodePolicy.CONTEXT_NAME`` when present and truthy; otherwise the
    response text is deserialized as JSON.

    :param response: Pipeline response for the token request.
    :param request_time: Added to ``expires_in`` when the payload has no
        ``expires_on``; also passed to the cache as ``now``.
    :raises ClientAuthenticationError: When the body cannot be deserialized, is
        empty, or lacks ``access_token`` or both expiry fields.
    :return: The access token.
    """
    content = response.context.get(ContentDecodePolicy.CONTEXT_NAME)
    if not content:
        try:
            content = ContentDecodePolicy.deserialize_from_text(
                response.http_response.text(), mime_type="application/json"
            )
        except DecodeError as ex:
            content_type = response.http_response.content_type
            if content_type.startswith("application/json"):
                message = "Failed to deserialize JSON from response"
            else:
                message = 'Unexpected content type "{}"'.format(content_type)
            six.raise_from(
                ClientAuthenticationError(message=message, response=response.http_response), ex
            )

    if not content:
        raise ClientAuthenticationError(message="No token received.", response=response.http_response)

    has_token = "access_token" in content
    has_expiry = "expires_in" in content or "expires_on" in content
    if not (has_token and has_expiry):
        if has_token:
            # Mask the token so it does not appear in the exception message.
            content["access_token"] = "****"
        raise ClientAuthenticationError(
            message='Unexpected response "{}"'.format(content), response=response.http_response
        )

    if self._content_callback:
        self._content_callback(content)

    # A falsy "expires_on" falls back to expires_in + request_time.
    expiry = content.get("expires_on")
    if not expiry:
        expiry = int(content["expires_in"]) + request_time
    content["expires_on"] = int(expiry)

    token = AccessToken(content["access_token"], content["expires_on"])

    # caching is the final step because TokenCache.add mutates its "event"
    cache_event = {"response": content, "scope": [content["resource"]]}
    self._cache.add(event=cache_event, now=request_time)

    return token
def get_exception_for_key_vault_error(cls, response):
    # type: (Type[AzureError], HttpResponse) -> AzureError
    """Construct ``cls`` with the error code and message from the response body.

    Message extraction is best-effort. ``TypeError`` is handled alongside
    ``DecodeError`` and ``KeyError`` so a body of an unexpected shape (for
    example ``None``, or ``"error": null``) falls back to the default message
    instead of raising from the error-construction path.

    :param cls: Exception class to instantiate.
    :param response: HTTP response carrying the error body.
    :return: ``cls(message=..., response=response)``; ``message`` is ``None``
        when the body could not be interpreted.
    """
    try:
        body = ContentDecodePolicy.deserialize_from_http_generics(response)
        message = "({}) {}".format(body["error"]["code"], body["error"]["message"])
    except (DecodeError, KeyError, TypeError):
        # Key Vault error response bodies have the expected shape and should be deserializable.
        # If we somehow land here, we'll take HttpResponse's default message.
        message = None

    return cls(message=message, response=response)
def on_response(self, request, response):
    """Attach batch statistics, model version and raw data, then call the callback.

    Does nothing when ``self._response_callback`` is falsy.
    """
    if not self._response_callback:
        return

    data = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)
    raw_statistics = data.get("statistics", None)
    model_version = data.get("modelVersion", None)

    response.statistics = TextDocumentBatchStatistics._from_generated(  # pylint: disable=protected-access
        raw_statistics
    )
    response.model_version = model_version
    response.raw_response = data
    self._response_callback(response)
def __init__(self, configuration=None, **kwargs):
    """Create the client's pipeline.

    The fallback ``RequestsTransport`` is only constructed when no usable
    ``transport`` was supplied. ``kwargs.get('transport', RequestsTransport(...))``
    would build a throwaway transport on every call and would let an explicit
    ``transport=None`` reach the pipeline.

    :param configuration: Optional configuration; otherwise
        ``FooServiceClient.create_config(**kwargs)``.
    :keyword transport: Optional transport for the pipeline.
    """
    config = configuration or FooServiceClient.create_config(**kwargs)
    transport = kwargs.get('transport') or RequestsTransport(**kwargs)
    policies = [
        config.user_agent_policy,
        config.headers_policy,
        config.authentication_policy,
        ContentDecodePolicy(),
        config.redirect_policy,
        config.retry_policy,
        config.logging_policy,
    ]
    self._pipeline = Pipeline(transport, policies=policies)
async def test_policies_configurable():
    """A policy passed via ``policies`` has ``on_request`` called during ``get_token``."""
    custom_policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock())

    async def send(*_, **__):
        return mock_response(json_payload=build_aad_response(access_token="**"))

    credential = CertificateCredential(
        "tenant-id",
        "client-id",
        PEM_CERT_PATH,
        policies=[ContentDecodePolicy(), custom_policy],
        transport=Mock(send=send),
    )

    await credential.get_token("scope")

    assert custom_policy.on_request.called
def __init__(self, auth_url, config=None, policies=None, transport=None, **kwargs):
    # type: (str, Optional[Configuration], Optional[Iterable[HTTPPolicy]], Optional[HttpTransport], Mapping[str, Any]) -> None
    """Build the pipeline, then initialize the base class with ``auth_url``.

    :param auth_url: Forwarded to the base class initializer.
    :param config: Optional configuration; otherwise ``self._create_config(**kwargs)``.
    :param policies: Optional policies; any falsy value selects the default list.
    :param transport: Optional transport; ``RequestsTransport(**kwargs)`` is the fallback.
    """
    config = config or self._create_config(**kwargs)

    if not policies:
        policies = [
            ContentDecodePolicy(),
            config.retry_policy,
            config.logging_policy,
            DistributedTracingPolicy(),
        ]

    if not transport:
        transport = RequestsTransport(**kwargs)

    self._pipeline = Pipeline(transport=transport, policies=policies)
    super(AuthnClient, self).__init__(auth_url, **kwargs)
def __init__(self, endpoint, client_cls, config=None, client_id=None, **kwargs):
    # type: (str, Type, Optional[Configuration], Optional[str], Any) -> None
    """Store the client ID and create the underlying client.

    :param endpoint: First positional argument to ``client_cls``.
    :param client_cls: Callable that builds the client stored on ``self._client``.
    :param config: Optional configuration; otherwise ``self._create_config(**kwargs)``.
    :param client_id: Stored on ``self._client_id``.
    """
    self._client_id = client_id
    config = config or self._create_config(**kwargs)
    default_policies = [
        ContentDecodePolicy(),
        config.headers_policy,
        config.retry_policy,
        config.logging_policy,
    ]
    self._client = client_cls(endpoint, config, default_policies, **kwargs)
def __init__(
    self,
    auth_url: str,
    config: Optional[Configuration] = None,
    policies: Optional[Iterable[HTTPPolicy]] = None,
    transport: Optional[AsyncHttpTransport] = None,
    **kwargs: Mapping[str, Any]
) -> None:
    """Build the async pipeline, then initialize the base class with ``auth_url``.

    :param auth_url: Forwarded to the base class initializer.
    :param config: Optional configuration; otherwise ``self.create_config(**kwargs)``.
    :param policies: Optional policies; any falsy value selects the default list.
    :param transport: Optional transport; the fallback is
        ``AsyncioRequestsTransport(configuration=config)``.
    """
    config = config or self.create_config(**kwargs)

    if not policies:
        policies = [ContentDecodePolicy(), config.logging_policy, config.retry_policy]

    if not transport:
        transport = AsyncioRequestsTransport(configuration=config)

    self._pipeline = AsyncPipeline(transport=transport, policies=policies)
    super(AsyncAuthnClient, self).__init__(auth_url, **kwargs)
def test_policies_configurable():
    """A policy passed via ``policies`` has ``on_request`` called during ``get_token``."""
    custom_policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock())

    aad_response = mock_response(json_payload=build_aad_response(access_token="**"))
    validating_transport = msal_validating_transport(requests=[Request()], responses=[aad_response])

    credential = CertificateCredential(
        "tenant-id",
        "client-id",
        CERT_PATH,
        policies=[ContentDecodePolicy(), custom_policy],
        transport=validating_transport,
    )

    credential.get_token("scope")

    assert custom_policy.on_request.called
def _get_exception_for_key_vault_error(cls, response):
    # type: (Type[HttpResponseError], HttpResponse) -> HttpResponseError
    """Construct cls (HttpResponseError or subclass thereof) with Key Vault's error message.

    Message extraction is best-effort. ``TypeError`` is handled alongside
    ``DecodeError`` and ``KeyError`` so a body of an unexpected shape (for
    example ``None``, or ``"error": null``) falls back to the default message
    instead of raising from the error-construction path.

    :param cls: Exception class to instantiate.
    :param response: HTTP response carrying the error body.
    :return: ``cls(message=..., response=response)``; ``message`` is ``None``
        when the body could not be interpreted.
    """
    try:
        body = ContentDecodePolicy.deserialize_from_http_generics(response)
        message = "({}) {}".format(
            body["error"]["code"], body["error"]["message"]
        )  # type: Optional[str]
    except (DecodeError, KeyError, TypeError):
        # Key Vault error response bodies should have the expected shape and be deserializable.
        # If we somehow land here, we'll take HttpResponse's default message.
        message = None

    return cls(message=message, response=response)
def send(self, request):
    """Send the request onward and, when a hook applies, enrich the response for it.

    The hook is the ``response_hook`` request option (removed from the
    options) or, if that option is absent, ``self._response_callback``.

    :return: The response from the next policy.
    """
    hook = request.context.options.pop("response_hook", self._response_callback)
    response = self.next.send(request)
    if not hook:
        return response

    data = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)
    raw_statistics = data.get("statistics", None)
    model_version = data.get("modelVersion", None)

    response.statistics = TextDocumentBatchStatistics._from_generated(  # pylint: disable=protected-access
        raw_statistics
    )
    response.model_version = model_version
    response.raw_response = data
    hook(response)
    return response