Example #1
0
async def test_sans_io_exception():
    """Transport errors propagate through SansIO policies.

    First, a plain ``SansIOHTTPPolicy`` must let the transport's ``ValueError``
    escape unchanged. Second, a policy whose ``on_exception`` raises must have
    its ``NotImplementedError`` replace the original error.
    """
    class BrokenSender(AsyncHttpTransport):
        # Transport stub whose send() always fails.
        async def send(self, request, **config):
            raise ValueError("Broken")

        async def open(self):
            self.session = requests.Session()

        async def close(self):
            self.session.close()

        async def __aexit__(self, exc_type, exc_value, traceback):
            """Close the session and let any exception from the runtime context propagate."""
            # Previously this returned the un-awaited close() coroutine: close() never ran,
            # and the truthy coroutine object made `async with` swallow the exception.
            await self.close()

    pipeline = AsyncPipeline(BrokenSender(), [SansIOHTTPPolicy()])

    req = HttpRequest('GET', '/')
    with pytest.raises(ValueError):
        await pipeline.run(req)

    class SwapExec(SansIOHTTPPolicy):
        # Re-raise the in-flight exception as NotImplementedError.
        def on_exception(self, requests, **kwargs):
            exc_type, exc_value, exc_traceback = sys.exc_info()
            raise NotImplementedError(exc_value)

    pipeline = AsyncPipeline(BrokenSender(), [SwapExec()])
    with pytest.raises(NotImplementedError):
        await pipeline.run(req)
async def test_bearer_policy_token_caching():
    """The bearer policy reuses a cached token until it expires, then asks the credential again."""
    good_for_one_hour = AccessToken("token", time.time() + 3600)
    expected_token = good_for_one_hour
    get_token_calls = 0

    async def get_token(_):
        nonlocal get_token_calls
        get_token_calls += 1
        return expected_token

    async def fake_send(_):
        # Terminal stand-in for the next policy. A native coroutine replaces
        # asyncio.coroutine, which was removed in Python 3.11.
        return Mock()

    credential = Mock(get_token=get_token)
    policies = [AsyncBearerTokenCredentialPolicy(credential, "scope"), Mock(send=fake_send)]
    # Pass a Mock *instance*, matching the second pipeline below; the bare class is not a transport.
    pipeline = AsyncPipeline(transport=Mock(), policies=policies)

    await pipeline.run(HttpRequest("GET", "https://spam.eggs"))
    assert get_token_calls == 1  # policy has no token at first request -> it should call get_token

    await pipeline.run(HttpRequest("GET", "https://spam.eggs"))
    assert get_token_calls == 1  # token is good for an hour -> policy should return it from cache

    expired_token = AccessToken("token", time.time())
    get_token_calls = 0
    expected_token = expired_token
    policies = [AsyncBearerTokenCredentialPolicy(credential, "scope"), Mock(send=fake_send)]
    pipeline = AsyncPipeline(transport=Mock(), policies=policies)

    await pipeline.run(HttpRequest("GET", "https://spam.eggs"))
    assert get_token_calls == 1

    await pipeline.run(HttpRequest("GET", "https://spam.eggs"))
    assert get_token_calls == 2  # token expired -> policy should call get_token
Example #3
0
async def test_bearer_policy_calls_sansio_methods():
    """AsyncBearerTokenCredentialPolicy should call SansIOHttpPolicy methods as does _SansIOAsyncHTTPPolicyRunner"""

    class TestPolicy(AsyncBearerTokenCredentialPolicy):
        # Bearer policy with mocked SansIO hooks; remembers the request/response it handled.
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.on_exception = Mock(return_value=False)
            self.on_request = Mock()
            self.on_response = Mock()

        async def send(self, request):
            self.request = request
            self.response = await super().send(request)
            return self.response

    class TestException(Exception):
        pass

    def make_pipeline(transport, policy):
        return AsyncPipeline(transport=transport, policies=[policy])

    valid_token = AccessToken("***", int(time.time()) + 3600)
    credential = Mock(get_token=Mock(return_value=get_completed_future(valid_token)))

    # success path: on_request / on_response each fire once with what the policy saw
    policy = TestPolicy(credential, "scope")
    ok_transport = Mock(send=Mock(return_value=get_completed_future(Mock(status_code=200))))
    await make_pipeline(ok_transport, policy).run(HttpRequest("GET", "https://localhost"))
    policy.on_request.assert_called_once_with(policy.request)
    policy.on_response.assert_called_once_with(policy.request, policy.response)

    # the policy should call on_exception when next.send() raises during the first send...
    failing_transport = Mock(send=Mock(side_effect=TestException))
    policy = TestPolicy(credential, "scope")
    with pytest.raises(TestException):
        await make_pipeline(failing_transport, policy).run(HttpRequest("GET", "https://localhost"))
    policy.on_exception.assert_called_once_with(policy.request)

    # ...or during the second, after a challenge response
    send_state = {"calls": 0}

    async def fake_send(*args, **kwargs):
        if send_state["calls"] == 0:
            send_state["calls"] = 1
            return Mock(
                status_code=401,
                headers={"WWW-Authenticate": 'Basic realm="localhost"'})
        raise TestException()

    policy = TestPolicy(credential, "scope")
    policy.on_challenge = Mock(return_value=get_completed_future(True))
    challenged_transport = Mock(send=Mock(wraps=fake_send))
    with pytest.raises(TestException):
        await make_pipeline(challenged_transport, policy).run(HttpRequest("GET", "https://localhost"))
    assert challenged_transport.send.call_count == 2
    policy.on_challenge.assert_called_once()
    policy.on_exception.assert_called_once_with(policy.request)
async def test_response_streaming_error_behavior():
    """A ConnectionError raised mid-stream must propagate out of stream_download.

    Test to reproduce https://github.com/Azure/azure-sdk-for-python/issues/16723
    """
    block_size = 103
    total_response_size = 500
    req_response = requests.Response()
    req_request = requests.Request()

    class FakeStreamWithConnectionError:
        # fake object for urllib3.response.HTTPResponse; raises ConnectionError
        # once no more than one block of data remains
        def __init__(self):
            # Reuse the test-level size rather than duplicating the literal.
            self.total_response_size = total_response_size

        def stream(self, chunk_size, decode_content=False):
            assert chunk_size == block_size
            left = total_response_size
            while left > 0:
                if left <= block_size:
                    raise requests.exceptions.ConnectionError()
                data = b"X" * min(chunk_size, left)
                left -= len(data)
                yield data

        def read(self, chunk_size, decode_content=False):
            assert chunk_size == block_size
            if self.total_response_size > 0:
                if self.total_response_size <= block_size:
                    raise requests.exceptions.ConnectionError()
                data = b"X" * min(chunk_size, self.total_response_size)
                self.total_response_size -= len(data)
                return data

        def close(self):
            pass

    req_response.raw = FakeStreamWithConnectionError()

    response = AsyncioRequestsTransportResponse(
        req_request,
        req_response,
        block_size,
    )

    # Assigned on the instance below, so it is not bound: `self` receives the
    # first positional argument passed to pipeline.run.
    async def mock_run(self, *args, **kwargs):
        return PipelineResponse(
            None,
            requests.Response(),
            None,
        )

    transport = AioHttpTransport()
    pipeline = AsyncPipeline(transport)
    pipeline.run = mock_run
    downloader = response.stream_download(pipeline)
    with pytest.raises(requests.exceptions.ConnectionError):
        while True:
            await downloader.__anext__()
Example #5
0
    def _create_appconfig_pipeline(self, credential, base_url=None, aad_mode=False, **kwargs):
        """Build the async pipeline for the App Configuration client.

        :param credential: In AAD mode, an object exposing ``get_token``; otherwise the
            value handed to ``AppConfigRequestsCredentialsPolicy``.
        :param str base_url: Service endpoint. Used only in AAD mode, where the token
            scope is ``<base_url stripped of surrounding '/'>/.default``.
        :param bool aad_mode: Use a bearer-token policy instead of the App Configuration
            credential policy.
        :keyword transport: Transport to use; defaults to ``AsyncioRequestsTransport(**kwargs)``.
        :keyword policies: Policy list used as-is; only ``None`` triggers the default list.
        :returns: The configured ``AsyncPipeline``.
        :raises TypeError: In AAD mode, when ``credential`` has no ``get_token`` attribute.
        """
        transport = kwargs.get('transport')
        policies = kwargs.get('policies')

        if policies is None:  # [] is a valid policy list
            if aad_mode:
                # Validate the credential before touching base_url so a bad credential
                # reports the TypeError rather than an unrelated attribute error.
                if not hasattr(credential, "get_token"):
                    raise TypeError("Please provide an instance from azure-identity "
                                    "or a class that implements the 'get_token' protocol")
                scope = base_url.strip("/") + "/.default"
                credential_policy = AsyncBearerTokenCredentialPolicy(credential, scope)
            else:
                credential_policy = AppConfigRequestsCredentialsPolicy(credential)

            policies = [
                self._config.headers_policy,
                self._config.user_agent_policy,
                credential_policy,
                self._config.retry_policy,
                SyncTokenPolicy(),
                self._config.logging_policy,  # HTTP request/response log
                DistributedTracingPolicy(**kwargs),
                HttpLoggingPolicy(**kwargs),
            ]

        if not transport:
            transport = AsyncioRequestsTransport(**kwargs)

        return AsyncPipeline(
            transport,
            policies,
        )
    def _create_pipeline(self, credential, **kwargs):
        """Build the async pipeline for the Text Analytics client.

        :param credential: Either an object exposing ``get_token`` (bearer auth for the
            ``https://cognitiveservices.azure.com/.default`` scope) or a string key.
        :keyword transport: Transport to use; defaults to ``AioHttpTransport(**kwargs)``.
        :returns: The configured ``AsyncPipeline``.
        :raises ValueError: If ``credential`` is None.
        :raises TypeError: If ``credential`` is neither a token credential nor a string.
        :raises ImportError: If no transport is given and aiohttp cannot be imported.
        """
        if credential is None:
            raise ValueError("Parameter 'credential' must not be None.")
        if hasattr(credential, "get_token"):
            credential_policy = AsyncBearerTokenCredentialPolicy(
                credential, "https://cognitiveservices.azure.com/.default")
        elif isinstance(credential, six.string_types):
            credential_policy = CognitiveServicesCredentialPolicy(credential)
        else:
            # None was rejected above, so anything else is unsupported.
            raise TypeError("Unsupported credential: {}".format(credential))

        config = self._create_configuration(**kwargs)
        config.transport = kwargs.get("transport")  # type: ignore
        if not config.transport:
            try:
                from azure.core.pipeline.transport import AioHttpTransport
            except ImportError as err:
                # Chain the original error so the real import failure stays visible.
                raise ImportError(
                    "Unable to create async transport. Please check aiohttp is installed."
                ) from err
            config.transport = AioHttpTransport(**kwargs)

        policies = [
            config.headers_policy, config.user_agent_policy,
            RequestIdPolicy(**kwargs), config.proxy_policy,
            AsyncRedirectPolicy(**kwargs),
            AsyncRetryPolicy(**kwargs), credential_policy,
            config.logging_policy,
            AsyncTextAnalyticsResponseHook(**kwargs),
            DistributedTracingPolicy(**kwargs),
            HttpLoggingPolicy(**kwargs)
        ]
        return AsyncPipeline(config.transport, policies=policies)
async def test_bearer_policy_adds_header():
    """AsyncBearerTokenCredentialPolicy sets an Authorization header from its credential and caches the token."""
    # 2524608000 == 01/01/2050 @ 12:00am (UTC)
    expected_token = AccessToken("expected_token", 2524608000)
    token_requests = 0

    async def get_token(_):
        nonlocal token_requests
        token_requests += 1
        return expected_token

    async def verify_authorization_header(request):
        header_value = request.http_request.headers["Authorization"]
        assert header_value == "Bearer {}".format(expected_token.token)
        return Mock()

    pipeline = AsyncPipeline(
        transport=Mock(),
        policies=[
            AsyncBearerTokenCredentialPolicy(Mock(get_token=get_token), "scope"),
            Mock(send=verify_authorization_header),
        ],
    )

    for _ in range(2):
        await pipeline.run(HttpRequest("GET", "https://spam.eggs"), context=None)
        # only the first run fetches a token; the second didn't need a new one
        assert token_requests == 1
    def get_table_client(
            self, table_name, # type: str
            **kwargs # type: Optional[Any]
    ):
        # type: (...) -> TableClient
        """Get a client to interact with the specified table.

        The table need not already exist.

        :param table_name: The name of the table.
        :type table_name: str
        :returns: A :class:`~azure.data.tables.TableClient` object.
        :rtype: ~azure.data.tables.TableClient

        """
        # Child pipeline: same policies, transport wrapped in AsyncTransportWrapper.
        # NOTE(review): the wrapper presumably keeps the child from closing the shared
        # transport — confirm against AsyncTransportWrapper.
        _pipeline = AsyncPipeline(
            transport=AsyncTransportWrapper(self._pipeline._transport),  # pylint: disable = protected-access
            policies=self._pipeline._impl_policies  # pylint: disable = protected-access
        )

        # Pass the wrapped pipeline built above; previously self._pipeline was passed,
        # which left _pipeline unused.
        return TableClient(
            self.url, table_name=table_name, credential=self.credential,
            key_resolver_function=self.key_resolver_function, require_encryption=self.require_encryption,
            key_encryption_key=self.key_encryption_key, api_version=self.api_version, _pipeline=_pipeline,
            _configuration=self._config, _location_mode=self._location_mode, _hosts=self._hosts, **kwargs)
    def __init__(self,
                 config: "Optional[Configuration]" = None,
                 policies: "Optional[Iterable[AsyncHTTPPolicy]]" = None,
                 transport: "Optional[AsyncHttpTransport]" = None,
                 **kwargs: "Any") -> None:
        """Create the transport and async pipeline.

        :param config: Configuration; built with ``self._create_config(**kwargs)`` when falsy.
        :param policies: Pipeline policies; any falsy value (None or empty) falls back to
            retry, logging and distributed-tracing policies.
        :param transport: Transport; defaults to ``AioHttpTransport(configuration=config)``.
        """
        if not config:
            config = self._create_config(**kwargs)
        if not policies:
            policies = [config.retry_policy, config.logging_policy, DistributedTracingPolicy()]
        self._transport = transport if transport else AioHttpTransport(configuration=config)
        # Close the transport session at interpreter exit (prevent aiohttp warnings).
        atexit.register(self._close_transport_session)
        self._pipeline = AsyncPipeline(transport=self._transport, policies=policies)
Example #10
0
async def test_bearer_policy_optionally_enforces_https():
    """HTTPS enforcement should be controlled by a keyword argument, and enabled by default"""
    async def assert_option_popped(request, **kwargs):
        assert "enforce_https" not in kwargs, "AsyncBearerTokenCredentialPolicy didn't pop the 'enforce_https' option"

    credential = Mock(get_token=lambda *_, **__: get_completed_future(AccessToken("***", 42)))
    policy = AsyncBearerTokenCredentialPolicy(credential, "scope")
    pipeline = AsyncPipeline(transport=Mock(send=assert_option_popped), policies=[policy])

    # by default and when enforce_https=True, an insecure request must be rejected
    for options in ({}, {"enforce_https": True}):
        with pytest.raises(ServiceRequestError):
            await pipeline.run(HttpRequest("GET", "http://not.secure"), **options)

    # opting out lets an insecure request through
    await pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False)

    # https requests always pass, whatever the option says
    for options in ({"enforce_https": False}, {"enforce_https": True}, {}):
        await pipeline.run(HttpRequest("GET", "https://secure"), **options)
Example #11
0
async def test_multiple_claims_challenges():
    """ARMChallengeAuthenticationPolicy should not attempt to handle a response having multiple claims challenges"""
    challenges = [
        'Bearer realm="", authorization_uri="https://login.microsoftonline.com/common/oauth2/authorize", client_id="00000003-0000-0000-c000-000000000000", error="insufficient_claims", claims="eyJhY2Nlc3NfdG9rZW4iOiB7ImZvbyI6ICJiYXIifX0="',
        'Bearer authorization_uri="https://login.windows-ppe.net/", error="invalid_token", error_description="User session has been revoked", claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0="',
    ]
    expected_header = ",".join(challenges)

    async def send(request):
        return Mock(status_code=401, headers={"WWW-Authenticate": expected_header})

    async def get_token(*_, **__):
        return AccessToken("***", 42)

    transport = Mock(send=Mock(wraps=send))
    credential = Mock(get_token=Mock(wraps=get_token))
    pipeline = AsyncPipeline(
        transport=transport,
        policies=[AsyncARMChallengeAuthenticationPolicy(credential, "scope")],
    )

    result = await pipeline.run(HttpRequest("GET", "https://localhost"))
    http_response = result.http_response

    # one round trip and one token fetch: no attempt was made to answer the challenges
    assert transport.send.call_count == 1
    assert credential.get_token.call_count == 1

    # the policy should have returned the error response because it was unable to handle the challenge
    assert http_response.status_code == 401
    assert http_response.headers["WWW-Authenticate"] == expected_header
Example #12
0
    def get_table_client(
        self,
        table_name,  # type: str
        **kwargs  # type: Optional[Any]
    ):
        # type: (...) -> TableClient
        """Return a TableClient for ``table_name`` built on this client's transport and policies.

        The table need not already exist.

        :param str table_name: The table name
        :returns: A :class:`~azure.data.tables.aio.TableClient` object.
        :rtype: :class:`~azure.data.tables.aio.TableClient`
        """
        shared_transport = self._client._client._pipeline._transport  # pylint:disable=protected-access
        child_pipeline = AsyncPipeline(
            transport=AsyncTransportWrapper(shared_transport),
            policies=self._policies,
        )
        return TableClient(
            self.url,
            table_name=table_name,
            credential=self.credential,
            api_version=self.api_version,
            pipeline=child_pipeline,
            location_mode=self._location_mode,
            _hosts=self._hosts,
            **kwargs
        )
Example #13
0
    def get_subdirectory_client(self, directory_name, **kwargs):
        # type: (str, Any) -> ShareDirectoryClient
        """Get a client to interact with a specific subdirectory.

        The subdirectory need not already exist.

        :param str directory_name:
            The name of the subdirectory, relative to this directory.
        :returns: A Directory Client.
        :rtype: ~azure.storage.fileshare.aio.ShareDirectoryClient

        .. admonition:: Example:

            .. literalinclude:: ../samples/file_samples_directory_async.py
                :start-after: [START get_subdirectory_client]
                :end-before: [END get_subdirectory_client]
                :language: python
                :dedent: 16
                :caption: Gets the subdirectory client.
        """
        # At the share root (empty directory_path) the name is used as-is; previously
        # this produced a leading-slash path "/<name>". Mirrors get_file_client.
        directory_path = directory_name
        if self.directory_path:
            directory_path = self.directory_path.rstrip('/') + "/" + directory_name

        _pipeline = AsyncPipeline(
            transport=AsyncTransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
            policies=self._pipeline._impl_policies # pylint: disable = protected-access
        )
        return ShareDirectoryClient(
            self.url, share_name=self.share_name, directory_path=directory_path, snapshot=self.snapshot,
            credential=self.credential, api_version=self.api_version, _hosts=self._hosts, _configuration=self._config,
            _pipeline=_pipeline, _location_mode=self._location_mode, loop=self._loop, **kwargs)
Example #14
0
async def test_preserves_options_and_headers():
    """After a challenge, the original request is re-sent with its options and headers intact.

    If a policy mutates the options or headers of the challenge (unauthorized) request,
    those options must still be present when the request is re-sent with authorization.
    """
    url = get_random_url()
    token = "**"

    async def get_token(*_, **__):
        return AccessToken(token, 0)

    credential = Mock(get_token=Mock(wraps=get_token))

    expected_requests = [Request()] * 2 + [Request(required_headers={"Authorization": "Bearer " + token})]
    challenge_response = mock_response(
        status_code=401, headers={"WWW-Authenticate": 'Bearer authorization="{}", resource=foo'.format(url)}
    )
    transport = async_validating_transport(
        requests=expected_requests,
        responses=[challenge_response] + [mock_response()] * 2,
    )

    policies = get_policies_for_request_mutation_test(AsyncChallengeAuthPolicy(credential=credential))
    pipeline = AsyncPipeline(policies=policies, transport=transport)
    await pipeline.run(HttpRequest("GET", url))

    # ensure the mock sans I/O policies were used
    for sansio_policy in (p for p in policies if hasattr(p, "on_request")):
        assert sansio_policy.on_request.called, "mock policy wasn't invoked"
async def test_retry_timeout():
    """AsyncRetryPolicy is expected to raise ServiceResponseTimeoutError once ``timeout`` is exceeded.

    Every send takes ~0.5s and fails. With ``timeout=1`` the transport asserts it is
    never called more than twice, even though ``retry_total=10``.
    """
    import asyncio  # local import: this snippet's module-level imports are not visible

    class MockTransport(AsyncHttpTransport):
        def __init__(self):
            self.count = 0

        async def __aexit__(self, exc_type, exc_val, exc_tb):
            pass

        async def close(self):
            pass

        async def open(self):
            pass

        async def send(
                self, request,
                **kwargs):  # type: (PipelineRequest, Any) -> PipelineResponse
            self.count += 1
            assert self.count <= 2, "retry policy kept retrying past its timeout"
            # Non-blocking delay: time.sleep inside a coroutine would stall the event loop.
            await asyncio.sleep(0.5)
            raise ServiceResponseError('timeout')

    http_request = HttpRequest('GET', 'http://127.0.0.1/')
    headers = {'Content-Type': "multipart/form-data"}
    http_request.headers = headers
    http_retry = AsyncRetryPolicy(retry_total=10, timeout=1)
    pipeline = AsyncPipeline(MockTransport(), [http_retry])
    with pytest.raises(ServiceResponseTimeoutError):
        await pipeline.run(http_request)
Example #16
0
async def test_bearer_policy_adds_header():
    """The bearer token policy should add a header containing a token from its credential"""
    expected_token = "expected_token"
    fetches = 0

    async def get_token(_):
        nonlocal fetches
        fetches += 1
        return expected_token

    async def verify_authorization_header(request):
        auth_value = request.http_request.headers["Authorization"]
        assert auth_value == "Bearer {}".format(expected_token)

    bearer_policy = AsyncBearerTokenCredentialPolicy(credential=Mock(get_token=get_token), scopes=("", ))
    checking_policy = Mock(spec=HTTPPolicy, send=verify_authorization_header)
    pipeline = AsyncPipeline(
        transport=Mock(spec=AsyncHttpTransport),
        policies=[bearer_policy, checking_policy],
    )

    await pipeline.run(HttpRequest("GET", "https://spam.eggs"), context=None)
    assert fetches == 1
Example #17
0
async def test_retry_seekable_stream():
    """On retry, a seekable request body must be back at position 0."""
    class MockTransport(AsyncHttpTransport):
        # First send seeks the body to its end and fails; later sends check it was rewound.
        def __init__(self):
            self._first = True

        async def __aexit__(self, exc_type, exc_val, exc_tb):
            pass

        async def close(self):
            pass

        async def open(self):
            pass

        async def send(
                self, request,
                **kwargs):  # type: (PipelineRequest, Any) -> PipelineResponse
            if not self._first:
                assert request.body.tell() == 0
                reply = HttpResponse(request, None)
                reply.status_code = 400
                return reply
            self._first = False
            request.body.seek(0, 2)
            raise AzureError('fail on first')

    payload = BytesIO(b"Lots of dataaaa")
    http_request = HttpRequest('GET', 'http://127.0.0.1/')
    http_request.set_streamed_data_body(payload)
    pipeline = AsyncPipeline(MockTransport(), [AsyncRetryPolicy(retry_total=1)])
    await pipeline.run(http_request)
Example #18
0
async def test_retry_on_429():
    """A 429 is retried: with retry_total=1 the transport receives exactly two sends."""
    class MockTransport(AsyncHttpTransport):
        # Counts sends and answers each one with HTTP 429.
        def __init__(self):
            self._count = 0

        async def __aexit__(self, exc_type, exc_val, exc_tb):
            pass

        async def close(self):
            pass

        async def open(self):
            pass

        async def send(
                self, request,
                **kwargs):  # type: (PipelineRequest, Any) -> PipelineResponse
            self._count += 1
            throttled = HttpResponse(request, None)
            throttled.status_code = 429
            return throttled

    request = HttpRequest('GET', 'http://127.0.0.1/')
    transport = MockTransport()
    pipeline = AsyncPipeline(transport, [AsyncRetryPolicy(retry_total=1)])
    await pipeline.run(request)
    assert transport._count == 2
 async def do():
     """GET https://bing.com/ through a Trio requests transport with user-agent and redirect policies."""
     configuration = Configuration()
     request = HttpRequest("GET", "https://bing.com/")
     pipeline_policies = [UserAgentPolicy("myusergant"), AsyncRedirectPolicy()]
     transport = TrioRequestsTransport(configuration)
     async with AsyncPipeline(transport, policies=pipeline_policies) as pipeline:
         return await pipeline.run(request)
Example #20
0
    def get_file_client(self, file_name, **kwargs):
        # type: (str, Any) -> ShareFileClient
        """Return a ShareFileClient for ``file_name``, placed under this directory's path if it has one.

        The file need not already exist.

        :param str file_name:
            The name of the file.
        :returns: A File Client.
        :rtype: ~azure.storage.fileshare.ShareFileClient
        """
        file_path = file_name
        if self.directory_path:
            file_path = self.directory_path.rstrip('/') + "/" + file_name

        parent = self._pipeline
        _pipeline = AsyncPipeline(
            transport=AsyncTransportWrapper(parent._transport),  # pylint: disable = protected-access
            policies=parent._impl_policies,  # pylint: disable = protected-access
        )
        return ShareFileClient(
            self.url,
            file_path=file_path,
            share_name=self.share_name,
            snapshot=self.snapshot,
            credential=self.credential,
            _hosts=self._hosts,
            _configuration=self._config,
            _pipeline=_pipeline,
            _location_mode=self._location_mode,
            loop=self._loop,
            **kwargs)
async def test_no_retry_on_201(http_request):
    """A 201 carrying Retry-After is final: the transport is called exactly once."""
    class MockTransport(AsyncHttpTransport):
        # Counts sends and answers with 201 plus a Retry-After header.
        def __init__(self):
            self._count = 0

        async def __aexit__(self, exc_type, exc_val, exc_tb):
            pass

        async def close(self):
            pass

        async def open(self):
            pass

        async def send(
                self, request,
                **kwargs):  # type: (PipelineRequest, Any) -> PipelineResponse
            self._count += 1
            created = HttpResponse(request, None)
            created.status_code = 201
            created.headers = {"Retry-After": "1"}
            return created

    # `http_request` is the request factory fixture; don't shadow it with the request.
    request = http_request('GET', 'http://localhost/')
    transport = MockTransport()
    pipeline = AsyncPipeline(transport, [AsyncRetryPolicy(retry_total=1)])
    await pipeline.run(request)
    assert transport._count == 1
    def get_directory_client(self, directory_path=None):
        # type: (Optional[str]) -> ShareDirectoryClient
        """Return a ShareDirectoryClient for ``directory_path`` within this share.

        The directory need not already exist.

        :param str directory_path:
            Path to the specified directory; None means the share root ("").
        :returns: A Directory Client.
        :rtype: ~azure.storage.fileshare.aio.ShareDirectoryClient
        """
        parent = self._pipeline
        _pipeline = AsyncPipeline(
            transport=AsyncTransportWrapper(parent._transport),  # pylint: disable = protected-access
            policies=parent._impl_policies,  # pylint: disable = protected-access
        )

        return ShareDirectoryClient(
            self.url,
            share_name=self.share_name,
            directory_path=directory_path or "",
            snapshot=self.snapshot,
            credential=self.credential,
            api_version=self.api_version,
            _hosts=self._hosts,
            _configuration=self._config,
            _pipeline=_pipeline,
            _location_mode=self._location_mode,
            loop=self._loop)
Example #23
0
 async def do():
     """GET /basic/string on localhost:<port> through a Trio requests transport."""
     target = "http://localhost:{}/basic/string".format(port)
     request = HttpRequest("GET", target)
     pipeline_policies = [UserAgentPolicy("myusergant"), AsyncRedirectPolicy()]
     async with AsyncPipeline(TrioRequestsTransport(), policies=pipeline_policies) as pipeline:
         return await pipeline.run(request)
    def get_file_client(self, file_path):
        # type: (str) -> ShareFileClient
        """Return a ShareFileClient for ``file_path`` within this share.

        The file need not already exist.

        :param str file_path:
            Path to the specified file.
        :returns: A File Client.
        :rtype: ~azure.storage.fileshare.aio.ShareFileClient
        """
        parent = self._pipeline
        _pipeline = AsyncPipeline(
            transport=AsyncTransportWrapper(parent._transport),  # pylint: disable = protected-access
            policies=parent._impl_policies,  # pylint: disable = protected-access
        )

        return ShareFileClient(
            self.url,
            share_name=self.share_name,
            file_path=file_path,
            snapshot=self.snapshot,
            credential=self.credential,
            api_version=self.api_version,
            _hosts=self._hosts,
            _configuration=self._config,
            _pipeline=_pipeline,
            _location_mode=self._location_mode,
            loop=self._loop)
 def _build_pipeline(self, **kwargs):
     """Build the async pipeline for this client.

     :keyword transport: Transport to use; defaults to ``AioHttpTransport(**kwargs)``.
     :keyword policies: Policy list used as-is; only ``None`` triggers the default list.
     :returns: The configured ``AsyncPipeline``.
     """
     # (The former `pylint: disable=no-self-use` was wrong: this method reads self.)
     transport = kwargs.get('transport')
     policies = kwargs.get('policies')
     if policies is None:  # [] is a valid policy list
         # The credential policy is only needed by the default list, so build it here
         # rather than unconditionally.
         if isinstance(self._credential, ServiceBusSharedKeyCredential):
             credential_policy = AsyncServiceBusSharedKeyCredentialPolicy(
                 self._endpoint, self._credential, "Authorization")
         else:
             credential_policy = AsyncBearerTokenCredentialPolicy(self._credential, JWT_TOKEN_SCOPE)
         policies = [
             RequestIdPolicy(**kwargs),
             self._config.headers_policy,
             self._config.user_agent_policy,
             self._config.proxy_policy,
             ContentDecodePolicy(**kwargs),
             ServiceBusXMLWorkaroundPolicy(),
             self._config.redirect_policy,
             self._config.retry_policy,
             credential_policy,
             self._config.logging_policy,
             DistributedTracingPolicy(**kwargs),
             HttpLoggingPolicy(**kwargs),
         ]
     if not transport:
         transport = AioHttpTransport(**kwargs)
     return AsyncPipeline(transport, policies)
Example #26
0
    def get_repository(self, repository_name: str,
                       **kwargs: Any) -> ContainerRepository:
        """Get a repository client

        :param str repository_name: The repository to create a client for
        :returns: :class:`~azure.containerregistry.aio.ContainerRepository`
        :rtype: ~azure.containerregistry.aio.ContainerRepository

        Example

        .. code-block:: python

            import os
            from azure.containerregistry.aio import ContainerRegistryClient
            from azure.identity.aio import DefaultAzureCredential

            account_url = os.environ["CONTAINERREGISTRY_ENDPOINT"]
            client = ContainerRegistryClient(account_url, DefaultAzureCredential())
            repository_client = client.get_repository("my_repository")
        """
        # The child pipeline reuses this client's policies, with the transport wrapped in
        # AsyncTransportWrapper (presumably so the child cannot close it — confirm).
        _pipeline = AsyncPipeline(
            transport=AsyncTransportWrapper(
                self._client._client._pipeline._transport  # pylint: disable=protected-access
            ),
            policies=self._client._client._pipeline._impl_policies,  # pylint: disable=protected-access
        )
        return ContainerRepository(self._endpoint,
                                   repository_name,
                                   credential=self._credential,
                                   pipeline=_pipeline,
                                   **kwargs)
Example #27
0
    async def test_with_challenge(challenge, expected_scope):
        """POST once, get ``challenge`` back, then check the retry carries auth and the body.

        ``expected_content`` comes from the enclosing scope (not visible here).
        """
        expected_token = "expected_token"
        sent = 0

        async def send(request):
            nonlocal sent
            sent += 1
            if sent == 1:
                # first request should be unauthorized and have no content
                assert not request.body
                assert request.headers["Content-Length"] == "0"
                return challenge
            if sent == 2:
                # second request should be authorized according to challenge and have the expected content
                assert request.headers["Content-Length"]
                assert request.body == expected_content
                assert expected_token in request.headers["Authorization"]
                return Mock(status_code=200)
            raise ValueError("unexpected request")

        async def get_token(*scopes):
            assert len(scopes) == 1
            assert scopes[0] == expected_scope
            return AccessToken(expected_token, 0)

        credential = Mock(get_token=Mock(wraps=get_token))
        request = HttpRequest("POST", get_random_url())
        request.set_bytes_body(expected_content)
        pipeline = AsyncPipeline(
            policies=[AsyncChallengeAuthPolicy(credential=credential)],
            transport=Mock(send=send))
        await pipeline.run(request)

        assert credential.get_token.call_count == 1
async def test_enforces_tls():
    """Expect ServiceRequestError for a plain-HTTP URL, even when a challenge is cached for it."""
    insecure_url = "http://not.secure"
    cached_challenge = HttpChallenge(insecure_url, "Bearer authorization=_, resource=_")
    HttpChallengeCache.set_challenge_for_url(insecure_url, cached_challenge)

    policy = AsyncChallengeAuthPolicy(Mock())
    pipeline = AsyncPipeline(transport=Mock(), policies=[policy])
    with pytest.raises(ServiceRequestError):
        await pipeline.run(HttpRequest("GET", insecure_url))
async def test_basic_options_aiohttp():
    """OPTIONS via AioHttpTransport returns an int status; the session is cleared after the context exits."""
    options_request = HttpRequest("OPTIONS", "https://httpbin.org")
    async with AsyncPipeline(AioHttpTransport(), policies=[]) as pipeline:
        result = await pipeline.run(options_request)

    assert pipeline._transport.session is None
    assert isinstance(result.http_response.status_code, int)
Example #30
0
async def test_basic_options_aiohttp(port):
    """OPTIONS to the local test server returns an int status; the session is cleared afterwards."""
    target = "http://localhost:{}/basic/string".format(port)
    options_request = HttpRequest("OPTIONS", target)
    async with AsyncPipeline(AioHttpTransport(), policies=[]) as pipeline:
        result = await pipeline.run(options_request)

    assert pipeline._transport.session is None
    assert isinstance(result.http_response.status_code, int)