def test_retry_on_429(http_request, http_response):
    """With ``retry_total=1``, a transport that always answers 429 is called twice."""

    class MockTransport(HttpTransport):
        """Transport stub: counts ``send`` calls and always returns a 429 response."""

        def __init__(self):
            self._count = 0

        def __exit__(self, exc_type, exc_val, exc_tb):
            pass

        def close(self):
            pass

        def open(self):
            pass

        def send(self, request, **kwargs):
            # type: (PipelineRequest, Any) -> PipelineResponse
            self._count += 1
            throttled = create_http_response(http_response, request, None)
            throttled.status_code = 429
            return throttled

    counting_transport = MockTransport()
    retry_policy = RetryPolicy(retry_total=1)
    pipeline = Pipeline(counting_transport, [retry_policy])
    pipeline.run(http_request('GET', 'http://localhost/'))

    # one initial attempt plus the single allowed retry
    assert counting_transport._count == 2
예제 #2
0
def test_retry_timeout():
    """A transport that keeps failing slowly must end in ServiceResponseTimeoutError.

    Each ``send`` sleeps 0.5s and then raises ServiceResponseError. The retry policy is
    given ``timeout=1``, and the transport asserts it is never invoked a third time.
    """

    class MockTransport(HttpTransport):
        """Transport stub that counts attempts and always fails after a short sleep."""

        def __init__(self):
            self.count = 0

        def __exit__(self, exc_type, exc_val, exc_tb):
            pass

        def close(self):
            pass

        def open(self):
            pass

        def send(self, request, **kwargs):
            # type: (PipelineRequest, Any) -> PipelineResponse
            self.count += 1
            # a third attempt would mean the policy ignored its timeout
            assert self.count <= 2
            time.sleep(0.5)
            raise ServiceResponseError('timeout')

    request = HttpRequest('GET', 'http://127.0.0.1/')
    request.headers = {'Content-Type': "multipart/form-data"}

    retry_policy = RetryPolicy(retry_total=10, timeout=1)
    pipeline = Pipeline(MockTransport(), [retry_policy])
    with pytest.raises(ServiceResponseTimeoutError):
        pipeline.run(request)
예제 #3
0
class BearerTokenPolicyTest(PerfStressTest):
    """Perf test that runs one request through bearer-token pipelines, sync and async."""

    def __init__(self, arguments):
        super().__init__(arguments)

        self.request = HttpRequest("GET", "https://localhost")
        access_token = AccessToken("**", int(time.time() + 3600))

        # Synchronous pipeline: mocked transport, credential returns the token directly.
        sync_credential = Mock(get_token=Mock(return_value=access_token))
        sync_policy = BearerTokenCredentialPolicy(credential=sync_credential)
        self.pipeline = Pipeline(transport=Mock(), policies=[sync_policy])

        # One already-resolved future is the awaitable returned by both async mocks.
        resolved = asyncio.Future()
        resolved.set_result(access_token)
        async_credential = Mock(get_token=Mock(return_value=resolved))

        # returning a token is okay because the policy does nothing with the transport's response
        async_transport = Mock(send=Mock(return_value=resolved))
        async_policy = AsyncBearerTokenCredentialPolicy(credential=async_credential)
        self.async_pipeline = AsyncPipeline(async_transport, policies=[async_policy])

    def run_sync(self):
        """Send the stored request through the synchronous pipeline."""
        self.pipeline.run(self.request)

    async def run_async(self):
        """Send the stored request through the asynchronous pipeline."""
        await self.async_pipeline.run(self.request)
예제 #4
0
def test_enforces_tls():
    """Running a request to a plain-http URL through ChallengeAuthPolicy raises ServiceRequestError."""
    insecure_url = "http://not.secure"
    cached_challenge = HttpChallenge(insecure_url, "Bearer authorization=_, resource=_")
    HttpChallengeCache.set_challenge_for_url(insecure_url, cached_challenge)

    policy = ChallengeAuthPolicy(Mock())
    pipeline = Pipeline(transport=Mock(), policies=[policy])

    with pytest.raises(ServiceRequestError):
        pipeline.run(HttpRequest("GET", insecure_url))
예제 #5
0
def test_retry_without_http_response():
    """An AzureError raised by the policy after RetryPolicy propagates out of ``pipeline.run``."""

    class NaughtyPolicy(HTTPPolicy):
        """Policy whose ``send`` always raises, so no HTTP response is ever produced."""

        def send(*args):
            raise AzureError('boo')

    pipeline = Pipeline(transport=None, policies=[RetryPolicy(), NaughtyPolicy()])

    with pytest.raises(AzureError):
        pipeline.run(HttpRequest('GET', url='https://foo.bar'))
def test_preserves_options_and_headers():
    """Options and headers set on a request must survive the challenge round trip.

    A mock policy placed before ChallengeAuthPolicy stamps every request with an option
    and a header; a mock policy placed after it checks that each request carrying an
    Authorization header still has both.
    """
    url = get_random_url()
    token = "**"
    key, value = "foo", "bar"

    credential = Mock(get_token=Mock(wraps=lambda *_, **__: AccessToken(token, 0)))

    expected_requests = [Request()] * 2 + [Request(required_headers={"Authorization": "Bearer " + token})]
    challenge_headers = {"WWW-Authenticate": 'Bearer authorization="{}", resource=foo'.format(url)}
    canned_responses = [mock_response(status_code=401, headers=challenge_headers)] + [mock_response()] * 2
    transport = validating_transport(requests=expected_requests, responses=canned_responses)

    def stamp(request):
        # add the expected option and header
        request.context.options[key] = value
        request.http_request.headers[key] = value

    def check(request):
        # only authorized (non-challenge) requests are expected to carry the option and header
        if not request.http_request.headers.get("Authorization"):
            return
        assert request.context.options.get(key) == value, "request option wasn't preserved across challenge"
        assert request.http_request.headers.get(key) == value, "headers wasn't preserved across challenge"

    adder = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock(wraps=stamp), on_exception=lambda _: False)
    verifier = Mock(spec=SansIOHTTPPolicy, on_request=Mock(wraps=check))

    pipeline = Pipeline(
        policies=[adder, ChallengeAuthPolicy(credential=credential), verifier],
        transport=transport,
    )
    pipeline.run(HttpRequest("GET", url))

    # ensure the mock sans I/O policies were called
    assert adder.on_request.called, "mock policy wasn't invoked"
    assert verifier.on_request.called, "mock policy wasn't invoked"
예제 #7
0
def test_azure_sas_credential_policy(sas, url, expected_url):
    """The request reaching the transport must have ``expected_url`` as its URL."""

    def assert_expected_url(request):
        assert request.url == expected_url

    sas_policy = AzureSasCredentialPolicy(credential=AzureSasCredential(sas))
    pipeline = Pipeline(transport=Mock(send=assert_expected_url), policies=[sas_policy])

    pipeline.run(HttpRequest("GET", url))
예제 #8
0
def test_policy_updates_cache():
    """
    The challenge returned for a URL can change, e.g. when a vault is moved to a new tenant.
    On a 401 the policy should update the challenge it has cached for the requested URL, if any.
    """

    url = get_random_url()
    first_scope, first_token = "https://first-scope", "first-scope-token"
    second_scope, second_token = "https://second-scope", "second-scope-token"
    challenge_fmt = 'Bearer authorization="https://login.authority.net/tenant", resource={}'

    def authorized(token):
        # expected request carrying the given bearer token (fresh objects on every call)
        return Request(url, required_headers={"Authorization": "Bearer {}".format(token)})

    def challenge(scope):
        return mock_response(status_code=401, headers={"WWW-Authenticate": challenge_fmt.format(scope)})

    def success():
        return mock_response(status_code=200)

    # mocking a tenant change:
    # 1. first request -> respond with challenge
    # 2. second request should be authorized according to the challenge
    # 3. third request should match the second (using a cached access token)
    # 4. fourth request should also match the second -> respond with a new challenge
    # 5. fifth request should be authorized according to the new challenge
    # 6. sixth request should match the fifth
    transport = validating_transport(
        requests=(
            Request(url),
            authorized(first_token),
            authorized(first_token),
            authorized(first_token),
            authorized(second_token),
            authorized(second_token),
        ),
        responses=(
            challenge(first_scope),
            success(),
            success(),
            challenge(second_scope),
            success(),
            success(),
        ),
    )

    get_token = Mock(return_value=AccessToken(first_token, time.time() + 3600))
    credential = Mock(get_token=get_token)
    pipeline = Pipeline(policies=[ChallengeAuthPolicy(credential=credential)], transport=transport)

    # policy should complete and cache the first challenge and access token
    for _ in range(2):
        pipeline.run(HttpRequest("GET", url))
        assert credential.get_token.call_count == 1

    # The next request will receive a new challenge. The policy should handle it and update caches.
    credential.get_token.return_value = AccessToken(second_token, time.time() + 3600)
    for _ in range(2):
        pipeline.run(HttpRequest("GET", url))
        assert credential.get_token.call_count == 2
def test_retry_seekable_file(http_request, http_response):
    """A seekable multipart file body is back at position 0 when the request is retried.

    The first ``send`` moves the file to its end and raises AzureError; the retried
    ``send`` asserts that the file position is 0 again and returns a 400 response.
    The temporary file is removed even when the run or an assertion fails.
    """

    class MockTransport(HttpTransport):
        """Transport stub that fails on the first call and checks the body on the next."""

        def __init__(self):
            self._first = True

        def __exit__(self, exc_type, exc_val, exc_tb):
            pass

        def close(self):
            pass

        def open(self):
            pass

        def send(self, request, **kwargs):
            # type: (PipelineRequest, Any) -> PipelineResponse
            if self._first:
                self._first = False
                for value in request.files.values():
                    name, body = value[0], value[1]
                    if name and body and hasattr(body, 'read'):
                        # leave the stream at EOF, then fail so the policy has to retry
                        body.seek(0, 2)
                        raise AzureError('fail on first')
            for value in request.files.values():
                name, body = value[0], value[1]
                if name and body and hasattr(body, 'read'):
                    position = body.tell()
                    assert not position
                    response = create_http_response(http_response, request, None)
                    response.status_code = 400
                    return response

    # delete=False: the file is closed and reopened by name below, so it must outlive this handle.
    tmp = tempfile.NamedTemporaryFile(delete=False)
    try:
        tmp.write(b'Lots of dataaaa')
        tmp.close()

        request = http_request('GET', 'http://localhost/')
        request.headers = {'Content-Type': "multipart/form-data"}
        with open(tmp.name, 'rb') as f:
            request.set_formdata_body({
                'fileContent': f,
                'fileName': f.name,
            })
            pipeline = Pipeline(MockTransport(), [RetryPolicy(retry_total=1)])
            pipeline.run(request)
    finally:
        # close() is a no-op if already closed; unlink must run on every exit path
        tmp.close()
        os.unlink(tmp.name)
def test_azure_key_credential_policy():
    """The request reaching the transport must carry the API key in the configured header."""

    header_name = "api_key"
    key_value = "test_key"

    def assert_key_header(request):
        assert request.headers[header_name] == key_value

    key_policy = AzureKeyCredentialPolicy(credential=AzureKeyCredential(key_value), name=header_name)
    pipeline = Pipeline(transport=Mock(send=assert_key_header), policies=[key_policy])

    pipeline.run(HttpRequest("GET", "https://test_key_credential"))
def test_policy_updates_cache():
    """
    The challenge returned for a URL can change, e.g. when a vault is moved to a new tenant.
    On a 401 the policy should update the challenge it has cached for the requested URL, if any.
    """

    # ensure the test starts with an empty cache
    HttpChallengeCache.clear()

    url = "https://azure.service/path"
    first_scope, first_token = "https://first-scope", "first-scope-token"
    second_scope, second_token = "https://second-scope", "second-scope-token"
    challenge_fmt = 'Bearer authorization="https://login.authority.net/tenant", resource={}'

    def authorized(token):
        # expected request carrying the given bearer token (fresh objects on every call)
        return Request(url, required_headers={"Authorization": "Bearer {}".format(token)})

    def challenge(scope):
        return mock_response(status_code=401, headers={"WWW-Authenticate": challenge_fmt.format(scope)})

    def success():
        return mock_response(status_code=200)

    # mocking a tenant change:
    # 1. first request -> respond with challenge
    # 2. second request should be authorized according to the challenge -> respond with success
    # 3. third request should match the second -> respond with a new challenge
    # 4. fourth request should be authorized according to the new challenge -> respond with success
    # 5. fifth request should match the fourth -> respond with success
    transport = validating_transport(
        requests=(
            Request(url),
            authorized(first_token),
            authorized(first_token),
            authorized(second_token),
            authorized(second_token),
        ),
        responses=(
            challenge(first_scope),
            success(),
            challenge(second_scope),
            success(),
            success(),
        ),
    )

    # tokens are handed out in this fixed order, one per get_token call
    tokens = iter([first_token] * 2 + [second_token] * 2)

    def next_token(_):
        return AccessToken(next(tokens), 0)

    credential = Mock(get_token=next_token)
    pipeline = Pipeline(policies=[ChallengeAuthPolicy(credential=credential)], transport=transport)

    # policy should complete and cache the first challenge
    pipeline.run(HttpRequest("GET", url))

    # The next request will receive a challenge. The policy should handle it and update the cache entry.
    pipeline.run(HttpRequest("GET", url))
예제 #12
0
class MockClient:
    """Traced client double whose pipeline ends in a mock policy that can assert a span is current."""

    @distributed_trace
    def __init__(self, policies=None, assert_current_span=False):
        time.sleep(0.001)
        self.assert_current_span = assert_current_span
        self.expected_response = mock.Mock(spec=PipelineResponse)
        self.request = HttpRequest("GET", "https://bing.com")

        policies = [] if policies is None else policies
        # NOTE: a caller-supplied list is extended in place with the verifying policy.
        policies.append(mock.Mock(spec=HTTPPolicy, send=self.verify_request))
        self.policies = policies

        self.transport = mock.Mock(spec=HttpTransport)
        self.pipeline = Pipeline(self.transport, policies=policies)

    def verify_request(self, request):
        """Stand-in for a policy's ``send``: optionally require a current span, then return the canned response."""
        active_span = tracing_context.current_span.get()
        if self.assert_current_span:
            assert active_span is not None
        return self.expected_response

    @distributed_trace
    def make_request(self, numb_times, **kwargs):
        """Run the pipeline ``numb_times`` times via recursion; return the outermost run's response."""
        time.sleep(0.001)
        if numb_times < 1:
            return None
        outer_response = self.pipeline.run(self.request, **kwargs)
        self.get_foo()
        # recursion (not a loop) so every remaining run happens inside this traced call
        self.make_request(numb_times - 1, **kwargs)
        return outer_response

    @distributed_trace
    def get_foo(self):
        """Sleep briefly and return 5."""
        time.sleep(0.001)
        return 5
예제 #13
0
def test_does_not_sleep_after_timeout(transport_error, expected_timeout_error):
    """The retry policy must not start a second sleep once ``timeout`` would be exceeded."""
    # With default settings policy will sleep twice before exhausting its retries: 1.6s, 3.2s.
    # It should not sleep the second time when given timeout=1
    timeout = 1

    failing_send = Mock(side_effect=transport_error("oops"))
    recorded_sleep = Mock(wraps=time.sleep)
    transport = Mock(spec=HttpTransport, send=failing_send, sleep=recorded_sleep)

    retry_policy = RetryPolicy(timeout=timeout)
    pipeline = Pipeline(transport, [retry_policy])

    with pytest.raises(expected_timeout_error):
        pipeline.run(HttpRequest("GET", "http://127.0.0.1/"))

    assert transport.sleep.call_count == 1
def test_multiple_claims_challenges():
    """ARMChallengeAuthenticationPolicy should not attempt to handle a response having multiple claims challenges"""

    insufficient_claims = 'Bearer realm="", authorization_uri="https://login.microsoftonline.com/common/oauth2/authorize", client_id="00000003-0000-0000-c000-000000000000", error="insufficient_claims", claims="eyJhY2Nlc3NfdG9rZW4iOiB7ImZvbyI6ICJiYXIifX0="'
    revoked_session = 'Bearer authorization_uri="https://login.windows-ppe.net/", error="invalid_token", error_description="User session has been revoked", claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0="'
    expected_header = ",".join((insufficient_claims, revoked_session))

    def send(request):
        # every request gets a 401 carrying both challenges in one header
        return Mock(status_code=401, headers={"WWW-Authenticate": expected_header})

    transport = Mock(send=Mock(wraps=send))
    credential = Mock()
    arm_policy = ARMChallengeAuthenticationPolicy(credential, "scope")
    pipeline = Pipeline(transport=transport, policies=[arm_policy])

    response = pipeline.run(HttpRequest("GET", "https://localhost"))

    # a single send and a single token request: the challenge did not trigger a retry
    assert transport.send.call_count == 1
    assert credential.get_token.call_count == 1

    # the policy should have returned the error response because it was unable to handle the challenge
    http_response = response.http_response
    assert http_response.status_code == 401
    assert http_response.headers["WWW-Authenticate"] == expected_header
예제 #15
0
def test_preserves_options_and_headers():
    """After a challenge, the original request should be sent with its options and headers preserved.

    If a policy mutates the options or headers of the challenge (unauthorized) request, the options of the service
    request should be present when it is sent with authorization.
    """

    url = get_random_url()
    token = "**"

    credential = Mock(get_token=Mock(wraps=lambda *_, **__: AccessToken(token, 0)))

    expected_requests = [Request()] * 2 + [Request(required_headers={"Authorization": "Bearer " + token})]
    challenge_headers = {"WWW-Authenticate": 'Bearer authorization="{}", resource=foo'.format(url)}
    canned_responses = [mock_response(status_code=401, headers=challenge_headers)] + [mock_response()] * 2
    transport = validating_transport(requests=expected_requests, responses=canned_responses)

    policies = get_policies_for_request_mutation_test(ChallengeAuthPolicy(credential=credential))
    pipeline = Pipeline(policies=policies, transport=transport)

    pipeline.run(HttpRequest("GET", url))

    # ensure the mock sans I/O policies were called
    for policy in policies:
        if not hasattr(policy, "on_request"):
            continue
        assert policy.on_request.called, "mock policy wasn't invoked"
예제 #16
0
class AuthnClient(AuthnClientBase):
    """Synchronous authentication client.

    Builds an HTTP pipeline from the given (or default) policies and transport and uses
    it to request access tokens.

    :param str auth_url: not named in this constructor's signature; any such value travels
        in ``**kwargs``, which is forwarded to ``AuthnClientBase.__init__``.
    :param config: Optional configuration for the HTTP pipeline.
    :type config: :class:`azure.core.configuration`
    :param policies: Optional policies for the HTTP pipeline.
    :type policies:
    :param transport: Optional HTTP transport.
    :type transport:
    """

    # pylint:disable=missing-client-constructor-parameter-credential
    def __init__(
        self,
        config=None,  # type: Optional[Configuration]
        policies=None,  # type: Optional[Iterable[HTTPPolicy]]
        transport=None,  # type: Optional[HttpTransport]
        **kwargs  # type: Any
    ):
        # type: (...) -> None
        if not config:
            config = self._create_config(**kwargs)
        if not policies:
            # default pipeline, used whenever the caller passes nothing (or an empty collection)
            policies = [
                ContentDecodePolicy(),
                config.proxy_policy,
                config.retry_policy,
                config.logging_policy,
                DistributedTracingPolicy(**kwargs),
                HttpLoggingPolicy(**kwargs),
            ]
        transport = transport or RequestsTransport(**kwargs)
        self._pipeline = Pipeline(transport=transport, policies=policies)
        super(AuthnClient, self).__init__(**kwargs)

    def request_token(
        self,
        scopes,  # type: Iterable[str]
        method="POST",  # type: Optional[str]
        headers=None,  # type: Optional[Mapping[str, str]]
        form_data=None,  # type: Optional[Mapping[str, str]]
        params=None,  # type: Optional[Dict[str, str]]
        **kwargs  # type: Any
    ):
        # type: (...) -> AccessToken
        """Send a token request through the pipeline and return the deserialized, cached token."""
        token_request = self._prepare_request(method, headers=headers, form_data=form_data, params=params)
        # timestamp is captured before the request is sent
        sent_at = int(time.time())
        pipeline_response = self._pipeline.run(token_request, stream=False, **kwargs)
        return self._deserialize_and_cache_token(
            response=pipeline_response, scopes=scopes, request_time=sent_at
        )

    @staticmethod
    def _create_config(**kwargs):
        # type: (Mapping[str, Any]) -> Configuration
        """Build a Configuration whose logging, retry and proxy policies are created from ``kwargs``."""
        configuration = Configuration(**kwargs)
        configuration.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
        configuration.retry_policy = RetryPolicy(**kwargs)
        configuration.proxy_policy = ProxyPolicy(**kwargs)
        return configuration
class MockClient:
    """Client double with async traced methods, used to exercise ``distributed_trace_async``.

    The pipeline it builds ends in a mock policy whose ``send`` is ``verify_request``.
    """

    @distributed_trace
    def __init__(self, http_request, policies=None, assert_current_span=False):
        """Build the request, the mock transport and the pipeline.

        :param http_request: callable invoked as ``http_request("GET", "http://localhost")``
            to create the request that is stored on the client.
        :param policies: optional list of policies; the verifying mock policy is appended
            to it in place.
        :param assert_current_span: when true, ``verify_request`` asserts that a span is current.
        """
        time.sleep(0.001)
        self.request = http_request("GET", "http://localhost")
        if policies is None:
            policies = []
        policies.append(mock.Mock(spec=HTTPPolicy, send=self.verify_request))
        self.policies = policies
        self.transport = mock.Mock(spec=HttpTransport)
        self.pipeline = Pipeline(self.transport, policies=policies)

        self.expected_response = mock.Mock(spec=PipelineResponse)
        self.assert_current_span = assert_current_span

    def verify_request(self, request):
        """Optionally assert that a span is current, then return the canned response."""
        if self.assert_current_span:
            assert execution_context.get_current_span() is not None
        return self.expected_response

    @distributed_trace_async
    async def make_request(self, numb_times, **kwargs):
        """Run the pipeline ``numb_times`` times via recursion; return the outermost run's response.

        Returns ``None`` when ``numb_times`` is below 1.
        """
        time.sleep(0.001)
        if numb_times < 1:
            return None
        # the pipeline here is the synchronous Pipeline, so run() is not awaited
        response = self.pipeline.run(self.request, **kwargs)
        # NOTE(review): get_foo declares no ``merge_span`` parameter; presumably the
        # distributed_trace_async decorator consumes that keyword - confirm.
        await self.get_foo(merge_span=True)
        # set only after the pipeline run above, so it affects just the recursive calls
        kwargs['merge_span'] = True
        await self.make_request(numb_times - 1, **kwargs)
        return response

    @distributed_trace_async
    async def merge_span_method(self):
        """Await ``get_foo`` with ``merge_span=True`` and return its result."""
        return await self.get_foo(merge_span=True)

    @distributed_trace_async
    async def no_merge_span_method(self):
        """Await ``get_foo`` without ``merge_span`` and return its result."""
        return await self.get_foo()

    @distributed_trace_async
    async def get_foo(self):
        """Sleep briefly and return 5."""
        time.sleep(0.001)
        return 5

    @distributed_trace_async(name_of_span="different name")
    async def check_name_is_different(self):
        """Traced with ``name_of_span="different name"``; only sleeps briefly."""
        time.sleep(0.001)

    @distributed_trace_async(tracing_attributes={'foo': 'bar'})
    async def tracing_attr(self):
        """Traced with ``tracing_attributes={'foo': 'bar'}``; only sleeps briefly."""
        time.sleep(0.001)

    @distributed_trace_async(kind=SpanKind.PRODUCER)
    async def kind_override(self):
        """Traced with ``kind=SpanKind.PRODUCER``; only sleeps briefly."""
        time.sleep(0.001)

    @distributed_trace_async
    async def raising_exception(self):
        """Always raise ValueError."""
        raise ValueError("Something went horribly wrong here")
def test_token_expiration():
    """policy should not use a cached token which has expired"""

    url = get_random_url()
    first_token, second_token = "*", "**"
    expires_on = time.time() + 3600

    current_token = AccessToken(first_token, expires_on)

    def get_token(*_, **__):
        # reads the enclosing variable at call time, so rebinding it below swaps the token
        return current_token

    credential = Mock(get_token=Mock(wraps=get_token))

    def bearer(value):
        # expected request carrying the given bearer token (fresh objects on every call)
        return Request(required_headers={"Authorization": "Bearer " + value})

    expected_requests = [Request(), bearer(first_token), bearer(first_token), bearer(second_token)]
    challenge_headers = {"WWW-Authenticate": 'Bearer authorization="{}", resource=foo'.format(url)}
    canned_responses = [mock_response(status_code=401, headers=challenge_headers)] + [mock_response()] * 3
    transport = validating_transport(requests=expected_requests, responses=canned_responses)

    pipeline = Pipeline(policies=[ChallengeAuthPolicy(credential=credential)], transport=transport)

    # challenge + authorized request, then a second request served from the cached token
    for _ in range(2):
        pipeline.run(HttpRequest("GET", url))
        assert credential.get_token.call_count == 1

    # swap in the replacement token, then pin the clock to the first token's expiry time
    current_token = AccessToken(second_token, time.time() + 3600)
    with patch("time.time", lambda: expires_on):
        pipeline.run(HttpRequest("GET", url))
    assert credential.get_token.call_count == 2
def test_response_streaming_error_behavior(http_response):
    """A ConnectionError raised by the raw stream surfaces when the downloader is iterated.

    Test to reproduce https://github.com/Azure/azure-sdk-for-python/issues/16723
    """
    block_size = 103
    total_response_size = 500
    req_response = requests.Response()
    req_request = requests.Request()

    class FakeStreamWithConnectionError:
        # fake object for urllib3.response.HTTPResponse
        def __init__(self):
            # remaining byte budget consumed by read(); same size that stream() starts from
            self.total_response_size = total_response_size

        def stream(self, chunk_size, decode_content=False):
            assert chunk_size == block_size
            left = total_response_size
            while left > 0:
                # fail once no more than one block remains instead of yielding it
                if left <= block_size:
                    raise requests.exceptions.ConnectionError()
                data = b"X" * min(chunk_size, left)
                left -= len(data)
                yield data

        def read(self, chunk_size, decode_content=False):
            assert chunk_size == block_size
            if self.total_response_size > 0:
                # same failure point as stream(), tracked on the instance between calls
                if self.total_response_size <= block_size:
                    raise requests.exceptions.ConnectionError()
                data = b"X" * min(chunk_size, self.total_response_size)
                self.total_response_size -= len(data)
                return data

        def close(self):
            pass

    req_response.raw = FakeStreamWithConnectionError()

    response = create_transport_response(
        http_response,
        req_request,
        req_response,
        block_size,
    )

    def mock_run(self, *args, **kwargs):
        # replacement for pipeline.run: always hands back an empty requests.Response
        return PipelineResponse(
            None,
            requests.Response(),
            None,
        )

    transport = RequestsTransport()
    pipeline = Pipeline(transport)
    pipeline.run = mock_run
    downloader = response.stream_download(pipeline, decompress=False)
    with pytest.raises(requests.exceptions.ConnectionError):
        # draining the downloader is what triggers the failure; the bytes are not needed
        b"".join(downloader)