Example #1
 def _create_pipeline(self, credential, **kwargs):
     # type: (Any, **Any) -> Tuple[Configuration, AsyncPipeline]
     self._credential_policy = None
     if hasattr(credential, 'get_token'):
         self._credential_policy = AsyncBearerTokenCredentialPolicy(
             credential, STORAGE_OAUTH_SCOPE)
     elif isinstance(credential, SharedKeyCredentialPolicy):
         self._credential_policy = credential
     elif isinstance(credential, AzureSasCredential):
         self._credential_policy = AzureSasCredentialPolicy(credential)
     elif credential is not None:
         raise TypeError("Unsupported credential: {}".format(credential))
     config = kwargs.get('_configuration') or create_configuration(**kwargs)
     if kwargs.get('_pipeline'):
         return config, kwargs['_pipeline']
     config.transport = kwargs.get('transport')  # type: ignore
     kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT)
     kwargs.setdefault("read_timeout", READ_TIMEOUT)
     if not config.transport:
         try:
             from azure.core.pipeline.transport import AioHttpTransport
         except ImportError:
             raise ImportError(
                 "Unable to create async transport. Please check aiohttp is installed."
             )
         config.transport = AioHttpTransport(**kwargs)
     policies = [
         QueueMessagePolicy(),
         config.headers_policy,
         config.proxy_policy,
         config.user_agent_policy,
         StorageContentValidation(),
         StorageRequestHook(**kwargs),
         self._credential_policy,
         ContentDecodePolicy(response_encoding="utf-8"),
         AsyncRedirectPolicy(**kwargs),
         StorageHosts(hosts=self._hosts, **kwargs),  # type: ignore
         config.retry_policy,
         config.logging_policy,
         AsyncStorageResponseHook(**kwargs),
         DistributedTracingPolicy(**kwargs),
         HttpLoggingPolicy(**kwargs),
     ]
     if kwargs.get("_additional_pipeline_policies"):
         policies = policies + kwargs.get("_additional_pipeline_policies")
     return config, AsyncPipeline(config.transport, policies=policies)
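
The general shape above (request-shaping policies first, then the credential policy, then redirect/retry, with logging and tracing closest to the transport) can be approximated with azure-core's public classes alone. The following is a minimal sketch, not the storage SDK's internal _create_pipeline; the user agent string and URL are placeholders.

import asyncio

from azure.core.pipeline import AsyncPipeline
from azure.core.pipeline.policies import (
    HeadersPolicy,
    ProxyPolicy,
    UserAgentPolicy,
    ContentDecodePolicy,
    AsyncRedirectPolicy,
    AsyncRetryPolicy,
    NetworkTraceLoggingPolicy,
    DistributedTracingPolicy,
    HttpLoggingPolicy,
)
from azure.core.pipeline.transport import AioHttpTransport, HttpRequest


async def run_minimal_pipeline():
    # Ordering mirrors Example #1: per-request policies first,
    # redirect/retry in the middle, logging and tracing nearest the transport.
    policies = [
        HeadersPolicy(),
        ProxyPolicy(),
        UserAgentPolicy("my-placeholder-agent"),
        ContentDecodePolicy(response_encoding="utf-8"),
        AsyncRedirectPolicy(),
        AsyncRetryPolicy(),
        NetworkTraceLoggingPolicy(),
        DistributedTracingPolicy(),
        HttpLoggingPolicy(),
    ]
    async with AsyncPipeline(AioHttpTransport(), policies=policies) as pipeline:
        response = await pipeline.run(HttpRequest("GET", "https://example.org"))
        return response.http_response.status_code

# asyncio.run(run_minimal_pipeline())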
Example #2
 def __init__(self,
              config: "Optional[Configuration]" = None,
              policies: "Optional[Iterable[HTTPPolicy]]" = None,
              transport: "Optional[AsyncHttpTransport]" = None,
              **kwargs: "Any") -> None:
     config = config or self._create_config(**kwargs)
     policies = policies or [
         ContentDecodePolicy(),
         config.retry_policy,
         config.logging_policy,
         DistributedTracingPolicy(**kwargs),
         HttpLoggingPolicy(**kwargs),
     ]
     if not transport:
         transport = AioHttpTransport(**kwargs)
     self._pipeline = AsyncPipeline(transport=transport, policies=policies)
     super().__init__(**kwargs)
Example #3
async def test_does_not_sleep_after_timeout(transport_error,
                                            expected_timeout_error):
    # With default settings the policy will sleep twice before exhausting its retries: 1.6s, then 3.2s.
    # It should not sleep a second time when given timeout=1.
    timeout = 1

    transport = Mock(
        spec=AsyncHttpTransport,
        send=Mock(side_effect=transport_error("oops")),
        sleep=Mock(wraps=asyncio.sleep),
    )
    pipeline = AsyncPipeline(transport, [AsyncRetryPolicy(timeout=timeout)])

    with pytest.raises(expected_timeout_error):
        await pipeline.run(HttpRequest("GET", "http://localhost/"))

    assert transport.sleep.call_count == 1
Example #4
async def test_token_expiration():
    """policy should not use a cached token which has expired"""

    url = get_random_url()

    expires_on = time.time() + 3600
    first_token = "*"
    second_token = "**"

    token = AccessToken(first_token, expires_on)

    async def get_token(*_, **__):
        return token

    credential = Mock(get_token=Mock(wraps=get_token))
    transport = async_validating_transport(
        requests=[
            Request(),
            Request(
                required_headers={"Authorization": "Bearer " + first_token}),
            Request(
                required_headers={"Authorization": "Bearer " + first_token}),
            Request(
                required_headers={"Authorization": "Bearer " + second_token}),
        ],
        responses=[
            mock_response(
                status_code=401,
                headers={
                    "WWW-Authenticate":
                    'Bearer authorization="{}", resource=foo'.format(url)
                })
        ] + [mock_response()] * 3,
    )
    pipeline = AsyncPipeline(
        policies=[AsyncChallengeAuthPolicy(credential=credential)],
        transport=transport)

    for _ in range(2):
        await pipeline.run(HttpRequest("GET", url))
        assert credential.get_token.call_count == 1

    token = AccessToken(second_token, time.time() + 3600)
    with patch("time.time", lambda: expires_on):
        await pipeline.run(HttpRequest("GET", url))
    assert credential.get_token.call_count == 2
Example #5
    def get_share_client(self, share, snapshot=None):
        # type: (Union[ShareProperties, str],Optional[Union[Dict[str, Any], str]]) -> ShareClient
        """Get a client to interact with the specified share.
        The share need not already exist.

        :param share:
            The share. This can either be the name of the share,
            or an instance of ShareProperties.
        :type share: str or ~azure.storage.fileshare.ShareProperties
        :param str snapshot:
            An optional share snapshot on which to operate. This can be the snapshot ID string
            or the response returned from :func:`create_snapshot`.
        :returns: A ShareClient.
        :rtype: ~azure.storage.fileshare.aio.ShareClient

        .. admonition:: Example:

            .. literalinclude:: ../samples/file_samples_service_async.py
                :start-after: [START get_share_client]
                :end-before: [END get_share_client]
                :language: python
                :dedent: 8
                :caption: Gets the share client.
        """
        try:
            share_name = share.name
        except AttributeError:
            share_name = share

        _pipeline = AsyncPipeline(
            transport=AsyncTransportWrapper(self._pipeline._transport),  # pylint: disable=protected-access
            policies=self._pipeline._impl_policies  # pylint: disable=protected-access
        )
        return ShareClient(self.url,
                           share_name=share_name,
                           snapshot=snapshot,
                           credential=self.credential,
                           api_version=self.api_version,
                           _hosts=self._hosts,
                           _configuration=self._config,
                           _pipeline=_pipeline,
                           _location_mode=self._location_mode,
                           loop=self._loop)
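
A typical call site for get_share_client, sketched with placeholder values; the connection string and share name are illustrative, and create_snapshot is the method referenced in the docstring above.

import asyncio

from azure.storage.fileshare.aio import ShareServiceClient


async def snapshot_share():
    # Placeholder connection string; substitute real account credentials.
    service = ShareServiceClient.from_connection_string("<connection-string>")
    async with service:
        share = service.get_share_client("my-share")
        snapshot = await share.create_snapshot()
        print(snapshot)

# asyncio.run(snapshot_share())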
Example #6
async def test_example_async_pipeline():
    # [START build_async_pipeline]
    from azure.core.pipeline import AsyncPipeline
    from azure.core.pipeline.policies import AsyncRedirectPolicy, UserAgentPolicy
    from azure.core.pipeline.transport import AioHttpTransport, HttpRequest

    # example: create request and policies
    request = HttpRequest("GET", "https://bing.com")
    policies = [UserAgentPolicy("myuseragent"), AsyncRedirectPolicy()]

    # run the pipeline
    async with AsyncPipeline(transport=AioHttpTransport(),
                             policies=policies) as pipeline:
        response = await pipeline.run(request)
    # [END build_async_pipeline]
    assert pipeline._transport.session is None
    assert response.http_response.status_code == 200
Example #7
async def test_basic_aiohttp_separate_session():

    session = aiohttp.ClientSession()
    request = HttpRequest("GET", "https://bing.com")
    policies = [
        UserAgentPolicy("myusergant"),
        AsyncRedirectPolicy()
    ]
    transport = AioHttpTransport(session=session, session_owner=False)
    async with AsyncPipeline(transport, policies=policies) as pipeline:
        response = await pipeline.run(request)

    assert transport.session
    assert isinstance(response.http_response.status_code, int)
    await transport.close()
    assert transport.session
    await transport.session.close()
Example #8
    def get_container_client(self, container):
        # type: (Union[ContainerProperties, str]) -> ContainerClient
        """Get a client to interact with the specified container.

        The container need not already exist.

        :param container:
            The container. This can either be the name of the container,
            or an instance of ContainerProperties.
        :type container: str or ~azure.storage.blob.ContainerProperties
        :returns: A ContainerClient.
        :rtype: ~azure.storage.blob.aio.ContainerClient

        .. admonition:: Example:

            .. literalinclude:: ../samples/blob_samples_service_async.py
                :start-after: [START bsc_get_container_client]
                :end-before: [END bsc_get_container_client]
                :language: python
                :dedent: 12
                :caption: Getting the container client to interact with a specific container.
        """
        try:
            container_name = container.name
        except AttributeError:
            container_name = container
        _pipeline = AsyncPipeline(
            transport=AsyncTransportWrapper(self._pipeline._transport),  # pylint: disable=protected-access
            policies=self._pipeline._impl_policies  # pylint: disable=protected-access
        )
        return ContainerClient(
            self.url,
            container_name=container_name,
            credential=self.credential,
            api_version=self.api_version,
            _configuration=self._config,
            _pipeline=_pipeline,
            _location_mode=self._location_mode,
            _hosts=self._hosts,
            require_encryption=self.require_encryption,
            key_encryption_key=self.key_encryption_key,
            key_resolver_function=self.key_resolver_function,
            loop=self._loop)
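
Likewise, a typical call site for get_container_client, sketched with placeholder values; the connection string, container and blob names are illustrative.

import asyncio

from azure.storage.blob.aio import BlobServiceClient


async def upload_greeting():
    # Placeholder connection string; substitute real account credentials.
    service = BlobServiceClient.from_connection_string("<connection-string>")
    async with service:
        container = service.get_container_client("my-container")
        await container.upload_blob("greeting.txt", b"hello", overwrite=True)

# asyncio.run(upload_greeting())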
Example #9
    def _build_pipeline(config: Configuration, transport: Optional[AsyncHttpTransport],
                        **kwargs: Any) -> AsyncPipeline:
        policies = [
            config.headers_policy,
            config.user_agent_policy,
            config.proxy_policy,
            config.redirect_policy,
            config.retry_policy,
            config.authentication_policy,
            config.logging_policy,
            DistributedTracingPolicy(),
        ]

        if transport is None:
            transport = AsyncioRequestsTransport(**kwargs)

        return AsyncPipeline(transport, policies=policies)
Example #10
 def __init__(self,
              auth_url: str,
              config: Optional[Configuration] = None,
              policies: Optional[Iterable[HTTPPolicy]] = None,
              transport: Optional[AsyncHttpTransport] = None,
              **kwargs: Mapping[str, Any]) -> None:
     config = config or self._create_config(**kwargs)
     policies = policies or [
         ContentDecodePolicy(),
         config.retry_policy,
         config.logging_policy,
         DistributedTracingPolicy(),
     ]
     if not transport:
         transport = AsyncioRequestsTransport(**kwargs)
     self._pipeline = AsyncPipeline(transport=transport, policies=policies)
     super(AsyncAuthnClient, self).__init__(auth_url, **kwargs)
Example #11
    def create_client(send_cb):
        class TestHttpTransport(AsyncHttpTransport):
            async def open(self):
                pass

            async def close(self):
                pass

            async def __aexit__(self, *args, **kwargs):
                pass

            async def send(self, request, **kwargs):
                return await send_cb(request, **kwargs)

        return AsyncPipelineClient(
            'http://example.org/',
            pipeline=AsyncPipeline(transport=TestHttpTransport()))
Example #12
async def test_multipart_send_with_one_changeset():
    transport = MockAsyncHttpTransport()
    requests = [
        HttpRequest("DELETE", "/container0/blob0"),
        HttpRequest("DELETE", "/container1/blob1")
    ]
    changeset = HttpRequest("", "")
    changeset.set_multipart_mixed(
        *requests,
        boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525"
    )

    request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch")
    request.set_multipart_mixed(
        changeset,
        boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525"
    )

    async with AsyncPipeline(transport) as pipeline:
        await pipeline.run(request)

    assert request.body == (
        b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n'
        b'Content-Type: multipart/mixed; boundary=changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n'
        b'\r\n'
        b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n'
        b'Content-Type: application/http\r\n'
        b'Content-Transfer-Encoding: binary\r\n'
        b'Content-ID: 0\r\n'
        b'\r\n'
        b'DELETE /container0/blob0 HTTP/1.1\r\n'
        b'\r\n'
        b'\r\n'
        b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n'
        b'Content-Type: application/http\r\n'
        b'Content-Transfer-Encoding: binary\r\n'
        b'Content-ID: 1\r\n'
        b'\r\n'
        b'DELETE /container1/blob1 HTTP/1.1\r\n'
        b'\r\n'
        b'\r\n'
        b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n'
        b'\r\n'
        b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n'
    )
Example #13
async def test_multipart_send():
    transport = MockAsyncHttpTransport()

    class RequestPolicy(object):
        async def on_request(self, request):
            # type: (PipelineRequest) -> None
            request.http_request.headers['x-ms-date'] = 'Thu, 14 Jun 2018 16:46:54 GMT'

    req0 = HttpRequest("DELETE", "/container0/blob0")
    req1 = HttpRequest("DELETE", "/container1/blob1")

    request = HttpRequest("POST",
                          "http://account.blob.core.windows.net/?comp=batch")
    request.set_multipart_mixed(
        req0,
        req1,
        policies=[RequestPolicy()],
        boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525"  # fixed boundary so the test is deterministic
    )

    async with AsyncPipeline(transport) as pipeline:
        await pipeline.run(request)

    assert request.body == (
        b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n'
        b'Content-Type: application/http\r\n'
        b'Content-Transfer-Encoding: binary\r\n'
        b'Content-ID: 0\r\n'
        b'\r\n'
        b'DELETE /container0/blob0 HTTP/1.1\r\n'
        b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n'
        b'\r\n'
        b'\r\n'
        b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n'
        b'Content-Type: application/http\r\n'
        b'Content-Transfer-Encoding: binary\r\n'
        b'Content-ID: 1\r\n'
        b'\r\n'
        b'DELETE /container1/blob1 HTTP/1.1\r\n'
        b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n'
        b'\r\n'
        b'\r\n'
        b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n')
Example #14
async def test_response_streaming_error_behavior():
    # Test to reproduce https://github.com/Azure/azure-sdk-for-python/issues/16723
    block_size = 103
    total_response_size = 500
    req_response = requests.Response()
    req_request = requests.Request()

    class FakeStreamWithConnectionError:
        # fake object for urllib3.response.HTTPResponse

        def stream(self, chunk_size, decode_content=False):
            assert chunk_size == block_size
            left = total_response_size
            while left > 0:
                if left <= block_size:
                    raise requests.exceptions.ConnectionError()
                data = b"X" * min(chunk_size, left)
                left -= len(data)
                yield data

        def close(self):
            pass

    req_response.raw = FakeStreamWithConnectionError()

    response = AsyncioRequestsTransportResponse(
        req_request,
        req_response,
        block_size,
    )

    async def mock_run(self, *args, **kwargs):
        return PipelineResponse(
            None,
            requests.Response(),
            None,
        )

    transport = AioHttpTransport()
    pipeline = AsyncPipeline(transport)
    pipeline.run = mock_run
    downloader = response.stream_download(pipeline)
    with pytest.raises(requests.exceptions.ConnectionError):
        while True:
            await downloader.__anext__()
Example #15
    def _build_pipeline(config: Configuration, transport: Optional[AsyncHttpTransport], **kwargs: Any) -> AsyncPipeline:
        policies = [
            config.headers_policy,
            config.user_agent_policy,
            config.proxy_policy,
            config.redirect_policy,
            config.retry_policy,
            config.authentication_policy,
            config.logging_policy,
            DistributedTracingPolicy(),
        ]

        if transport is None:
            from azure.core.pipeline.transport import AioHttpTransport

            transport = AioHttpTransport(**kwargs)

        return AsyncPipeline(transport, policies=policies)
Example #16
async def test_preserves_enforce_https_opt_out():
    """The policy should use request context to preserve an opt out from https enforcement"""
    class ContextValidator(SansIOHTTPPolicy):
        def on_request(self, request):
            assert "enforce_https" in request.context, "'enforce_https' is not in the request's context"

    get_token = get_completed_future(AccessToken("***", 42))
    credential = Mock(get_token=lambda *_, **__: get_token)
    policies = [
        AsyncBearerTokenCredentialPolicy(credential, "scope"),
        ContextValidator()
    ]
    pipeline = AsyncPipeline(
        transport=Mock(send=lambda *_, **__: get_completed_future()),
        policies=policies)

    await pipeline.run(HttpRequest("GET", "http://not.secure"),
                       enforce_https=False)
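
Outside of tests, AsyncBearerTokenCredentialPolicy is normally paired with a real async credential. A minimal sketch, assuming azure-identity is installed and ambient credentials are available; the ARM scope and URL are illustrative.

import asyncio

from azure.core.pipeline import AsyncPipeline
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
from azure.core.pipeline.transport import AioHttpTransport, HttpRequest
from azure.identity.aio import DefaultAzureCredential


async def call_arm():
    async with DefaultAzureCredential() as credential:
        auth = AsyncBearerTokenCredentialPolicy(credential, "https://management.azure.com/.default")
        async with AsyncPipeline(AioHttpTransport(), policies=[auth]) as pipeline:
            request = HttpRequest("GET", "https://management.azure.com/subscriptions?api-version=2020-01-01")
            response = await pipeline.run(request)
            return response.http_response.status_code

# asyncio.run(call_arm())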
Example #17
    def get_form_recognizer_client(self,
                                   **kwargs: Any) -> FormRecognizerClient:
        """Get an instance of a FormRecognizerClient from FormTrainingClient.

        :rtype: ~azure.ai.formrecognizer.aio.FormRecognizerClient
        :return: A FormRecognizerClient
        """
        _pipeline = AsyncPipeline(
            transport=AsyncTransportWrapper(self._client._client._pipeline._transport),
            policies=self._client._client._pipeline._impl_policies,
        )  # type: AsyncPipeline
        client = FormRecognizerClient(endpoint=self._endpoint,
                                      credential=self._credential,
                                      pipeline=_pipeline,
                                      **kwargs)
        # need to share config, but can't pass as a keyword into client
        client._client._config = self._client._client._config
        return client
Example #18
    def __init__(self,
                 config: "Optional[Configuration]" = None,
                 policies: "Optional[Iterable[AsyncHTTPPolicy]]" = None,
                 transport: "Optional[AsyncHttpTransport]" = None,
                 **kwargs: "Any") -> None:

        config = config or self._create_config(**kwargs)
        policies = policies or [
            config.retry_policy,
            config.logging_policy,
            DistributedTracingPolicy(**kwargs),
            HttpLoggingPolicy(**kwargs),
        ]
        self._transport = transport or AioHttpTransport(configuration=config)
        atexit.register(
            self._close_transport_session)  # prevent aiohttp warnings
        self._pipeline = AsyncPipeline(transport=self._transport,
                                       policies=policies)
Example #19
async def test_connection_error_416():
    class MockTransport(AsyncHttpTransport):
        def __init__(self):
            self._count = 0

        async def __aexit__(self, exc_type, exc_val, exc_tb):
            pass

        async def close(self):
            pass

        async def open(self):
            pass

        async def send(self, request, **kwargs):
            request = HttpRequest('GET', 'http://127.0.0.1/')
            response = AsyncHttpResponse(request, None)
            response.status_code = 416
            return response

    class MockContent():
        async def read(self, block_size):
            raise ConnectionError

    class MockInternalResponse():
        def __init__(self):
            self.headers = {}
            self.content = MockContent()

        async def close(self):
            pass

    class AsyncMock(mock.MagicMock):
        async def __call__(self, *args, **kwargs):
            return super(AsyncMock, self).__call__(*args, **kwargs)

    http_request = HttpRequest('GET', 'http://127.0.0.1/')
    pipeline = AsyncPipeline(MockTransport())
    http_response = AsyncHttpResponse(http_request, None)
    http_response.internal_response = MockInternalResponse()
    stream = AioHttpStreamDownloadGenerator(pipeline, http_response)
    with mock.patch('asyncio.sleep', new_callable=AsyncMock):
        with pytest.raises(ConnectionError):
            await stream.__anext__()
Example #20
    def __init__(
        self,
        config: "Optional[Configuration]" = None,
        policies: "Optional[Iterable[AsyncHTTPPolicy]]" = None,
        transport: "Optional[AsyncHttpTransport]" = None,
        **kwargs: "Any"
    ) -> None:

        config = config or self._create_config(**kwargs)
        policies = policies or [
            config.user_agent_policy,
            config.proxy_policy,
            config.retry_policy,
            config.logging_policy,
            DistributedTracingPolicy(**kwargs),
            HttpLoggingPolicy(**kwargs),
        ]
        self._transport = transport or AioHttpTransport(configuration=config)
        self._pipeline = AsyncPipeline(transport=self._transport, policies=policies)
Example #21
async def test_context_unmodified_by_default():
    """When no options for the policy accompany a request, the policy shouldn't add anything to the request context"""
    class ContextValidator(SansIOHTTPPolicy):
        def on_request(self, request):
            assert not any(
                request.context
            ), "the policy shouldn't add to the request's context"

    get_token = get_completed_future(AccessToken("***", 42))
    credential = Mock(get_token=lambda *_, **__: get_token)
    policies = [
        AsyncBearerTokenCredentialPolicy(credential, "scope"),
        ContextValidator()
    ]
    pipeline = AsyncPipeline(
        transport=Mock(send=lambda *_, **__: get_completed_future()),
        policies=policies)

    await pipeline.run(HttpRequest("GET", "https://secure"))
Example #22
    def get_directory_client(self, directory_path=None):
        # type: (Optional[str]) -> ShareDirectoryClient
        """Get a client to interact with the specified directory.
        The directory need not already exist.

        :param str directory_path:
            Path to the specified directory.
        :returns: A Directory Client.
        :rtype: ~azure.storage.fileshare.aio.ShareDirectoryClient
        """
        _pipeline = AsyncPipeline(
            transport=AsyncTransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
            policies=self._pipeline._impl_policies # pylint: disable = protected-access
        )

        return ShareDirectoryClient(
            self.url, share_name=self.share_name, directory_path=directory_path or "", snapshot=self.snapshot,
            credential=self.credential, api_version=self.api_version, _hosts=self._hosts, _configuration=self._config,
            _pipeline=_pipeline, _location_mode=self._location_mode, loop=self._loop)
Example #23
async def test_timeout_defaults():
    """When "timeout" is not set, the policy should not override the transport's timeout configuration"""
    async def send(request, **kwargs):
        for arg in ("connection_timeout", "read_timeout"):
            assert arg not in kwargs, "policy should defer to transport configuration when not given a timeout"
        response = HttpResponse(request, None)
        response.status_code = 200
        return response

    transport = Mock(
        spec_set=AsyncHttpTransport,
        send=Mock(wraps=send),
        sleep=Mock(side_effect=Exception(
            "policy should not sleep: its first send succeeded")),
    )
    pipeline = AsyncPipeline(transport, [AsyncRetryPolicy()])

    await pipeline.run(HttpRequest("GET", "http://127.0.0.1/"))
    assert transport.send.call_count == 1, "policy should not retry: its first send succeeded"
Example #24
async def test_retry_timeout():
    timeout = 1

    def send(request, **kwargs):
        assert kwargs[
            "connection_timeout"] <= timeout, "policy should set connection_timeout not to exceed timeout"
        raise ServiceResponseError("oops")

    transport = Mock(
        spec=AsyncHttpTransport,
        send=Mock(wraps=send),
        connection_config=ConnectionConfiguration(connection_timeout=timeout * 2),
        sleep=asyncio.sleep,
    )
    pipeline = AsyncPipeline(transport, [AsyncRetryPolicy(timeout=timeout)])

    with pytest.raises(ServiceResponseTimeoutError):
        await pipeline.run(HttpRequest("GET", "http://127.0.0.1/"))
Example #25
    def get_file_client(self, file_path):
        # type: (str) -> ShareFileClient
        """Get a client to interact with the specified file.
        The file need not already exist.

        :param str file_path:
            Path to the specified file.
        :returns: A File Client.
        :rtype: ~azure.storage.fileshare.aio.ShareFileClient
        """
        _pipeline = AsyncPipeline(
            transport=AsyncTransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
            policies=self._pipeline._impl_policies # pylint: disable = protected-access
        )

        return ShareFileClient(
            self.url, share_name=self.share_name, file_path=file_path, snapshot=self.snapshot,
            credential=self.credential, api_version=self.api_version, _hosts=self._hosts, _configuration=self._config,
            _pipeline=_pipeline, _location_mode=self._location_mode, loop=self._loop)
Example #26
    async def _undelete_path(self, deleted_path_name, deletion_id, **kwargs):
        # type: (str, str, **Any) -> Union[DataLakeDirectoryClient, DataLakeFileClient]
        """Restores soft-deleted path.

        Operation will only be successful if used within the specified number of days
        set in the delete retention policy.

        .. versionadded:: 12.4.0
            This operation was introduced in API version '2020-06-12'.

        :param str deleted_path_name:
            Specifies the name of the deleted container to restore.
        :param str deletion_id:
            Specifies the version of the deleted container to restore.
        :keyword int timeout:
            The timeout parameter is expressed in seconds.
        :rtype: ~azure.storage.file.datalake.aio.DataLakeDirectoryClient
                or azure.storage.file.datalake.aio.DataLakeFileClient
        """
        _, url, undelete_source = self._undelete_path_options(
            deleted_path_name, deletion_id)

        pipeline = AsyncPipeline(
            transport=AsyncTransportWrapper(self._pipeline._transport),  # pylint: disable=protected-access
            policies=self._pipeline._impl_policies  # pylint: disable=protected-access
        )
        path_client = AzureDataLakeStorageRESTAPI(
            url,
            filesystem=self.file_system_name,
            path=deleted_path_name,
            pipeline=pipeline)
        try:
            is_file = await path_client.path.undelete(
                undelete_source=undelete_source, cls=is_file_path, **kwargs)
            if is_file:
                return self.get_file_client(deleted_path_name)
            return self.get_directory_client(deleted_path_name)
        except HttpResponseError as error:
            process_storage_error(error)
Example #27
async def test_retry_seekable_file():
    class MockTransport(AsyncHttpTransport):
        def __init__(self):
            self._first = True

        async def __aexit__(self, exc_type, exc_val, exc_tb):
            pass

        async def close(self):
            pass

        async def open(self):
            pass

        async def send(self, request, **kwargs):  # type: (PipelineRequest, Any) -> PipelineResponse
            if self._first:
                self._first = False
                for value in request.files.values():
                    name, body = value[0], value[1]
                    if name and body and hasattr(body, 'read'):
                        body.seek(0,2)
                        raise AzureError('fail on first')
            for value in request.files.values():
                name, body = value[0], value[1]
                if name and body and hasattr(body, 'read'):
                    position = body.tell()
                    assert not position
                    return HttpResponse(request, None)

    file = tempfile.NamedTemporaryFile(delete=False)
    file.write(b'Lots of dataaaa')
    file.close()
    http_request = HttpRequest('GET', 'http://127.0.0.1/')
    headers = {'Content-Type': "multipart/form-data"}
    http_request.headers = headers
    with open(file.name, 'rb') as f:
        form_data_content = {
            'fileContent': f,
            'fileName': f.name,
        }
        http_request.set_formdata_body(form_data_content)
        http_retry = AsyncRetryPolicy(retry_total=1)
        pipeline = AsyncPipeline(MockTransport(), [http_retry])
        await pipeline.run(http_request)
    os.unlink(f.name)
Example #28
async def test_policy():
    # ensure the test starts with an empty cache
    HttpChallengeCache.clear()

    expected_scope = "https://challenge.resource/.default"
    expected_token = "expected_token"
    challenge = Mock(
        status_code=401,
        headers={
            "WWW-Authenticate":
            'Bearer authorization="https://login.authority.net/tenant", resource={}'
            .format(expected_scope)
        },
    )
    success = Mock(status_code=200)
    data = {"spam": "eggs"}

    responses = (r for r in (challenge, success))

    async def send(request):
        response = next(responses)
        if response is challenge:
            # this is the first request
            assert not request.body
            assert request.headers["Content-Length"] == "0"
        elif response is success:
            # this is the second request
            assert request.body == data
            assert expected_token in request.headers["Authorization"]
        return response

    async def get_token(*scopes):
        print("get token")
        assert len(scopes) == 1
        assert scopes[0] == expected_scope
        return AccessToken(expected_token, 0)

    credential = Mock(get_token=get_token)
    pipeline = AsyncPipeline(
        policies=[AsyncChallengeAuthPolicy(credential=credential)],
        transport=Mock(send=send))
    await pipeline.run(HttpRequest("POST", "https://azure.service", data=data))
Example #29
    def _build_pipeline(self,
                        config: Configuration = None,
                        policies: "Optional[List[Policy]]" = None,
                        transport: "Optional[AsyncHttpTransport]" = None,
                        **kwargs: "Any") -> AsyncPipeline:
        config = config or _create_config(**kwargs)
        policies = policies or [
            config.user_agent_policy,
            config.proxy_policy,
            config.retry_policy,
            config.logging_policy,
            DistributedTracingPolicy(**kwargs),
            HttpLoggingPolicy(**kwargs),
        ]
        if not transport:
            from azure.core.pipeline.transport import AioHttpTransport

            transport = AioHttpTransport(configuration=config)

        return AsyncPipeline(transport=transport, policies=policies)
Example #30
    def get_artifact(self, repository_name: str, tag_or_digest: str,
                     **kwargs: Dict[str, Any]) -> RegistryArtifact:
        """Get a Registry Artifact object

        :param str repository_name: Name of the repository
        :param str tag_or_digest: The tag or digest of the artifact
        :returns: :class:`~azure.containerregistry.RegistryArtifact`
        :raises: None
        """
        _pipeline = AsyncPipeline(
            transport=AsyncTransportWrapper(
                self._client._client._pipeline._transport),  # pylint: disable=protected-access
            policies=self._client._client._pipeline._impl_policies,  # pylint: disable=protected-access
        )
        return RegistryArtifact(self._endpoint,
                                repository_name,
                                tag_or_digest,
                                self._credential,
                                pipeline=_pipeline,
                                **kwargs)