def _prepare_request(self, method, url, *, data=None, params=None, headers=None, service=NEPTUNE_SERVICE_NAME):
    """Build a prepared ``requests`` request, SigV4-signing it when a boto session exists.

    The outgoing request always carries ``self._auth``. When ``self._session``
    is set, an equivalent ``AWSRequest`` is signed for ``service`` in
    ``self.region`` and its signed headers replace the outgoing request's headers.

    :return: the prepared ``requests`` request.
    """
    self._ensure_http_session()
    outgoing = requests.Request(
        method=method,
        url=url,
        data=data,
        params=params,
        headers=headers,
        auth=self._auth,
    )
    boto_session = self._session
    if boto_session is None:
        return outgoing.prepare()

    frozen = boto_session.get_credentials().get_frozen_credentials()
    to_sign = AWSRequest(method=method, url=url, data=data, params=params, headers=headers)
    SigV4Auth(frozen, service, self.region).add_auth(to_sign)
    # Copy the signed headers produced by SigV4Auth onto the outgoing request.
    outgoing.headers = dict(to_sign.prepare().headers)
    return outgoing.prepare()
def test_auth_header_preserved_from_s3_redirects(self):
    """A 307 redirect from S3 must re-send the original Authorization header."""
    original = AWSRequest()
    original.url = 'https://bucket.s3.amazonaws.com/'
    original.method = 'GET'
    original.headers['Authorization'] = 'original auth header'
    prepared = original.prepare()

    redirect = Mock()
    redirect.headers = {'location': 'https://bucket.s3-us-west-2.amazonaws.com'}
    redirect.url = original.url
    redirect.status_code = 307
    redirect.is_permanent_redirect = False
    # A None _original_response disables the cookie handling code in requests.
    redirect.raw._original_response = None

    ok = Mock()
    ok.raw._original_response = None
    ok.is_redirect = False
    ok.status_code = 200

    session = BotocoreHTTPSession()
    session.send = Mock(return_value=ok)
    # resolve_redirects is lazy; exhaust it so the redirected request is sent.
    list(session.resolve_redirects(redirect, prepared, stream=False))

    resent = session.send.call_args[0][0]
    # The redirected request should still carry the original Authorization header.
    self.assertEqual(resent.headers['Authorization'], 'original auth header')
def send_response(event, context, response_status, response_data):
    '''Send a resource manipulation status response to CloudFormation.

    Builds the custom-resource response document and PUTs it to
    ``event['ResponseURL']``.

    :param event: custom resource event; StackId, RequestId,
        LogicalResourceId and ResponseURL are read from it.
    :param context: Lambda context; its ``log_stream_name`` is used for the
        Reason text and as the PhysicalResourceId.
    :param response_status: value for the "Status" field.
    :param response_data: JSON-serializable value for the "Data" field.
    '''
    response_body = json.dumps({
        "Status": response_status,
        "Reason": "See the details in CloudWatch Log Stream: " + context.log_stream_name,
        "PhysicalResourceId": context.log_stream_name,
        "StackId": event['StackId'],
        "RequestId": event['RequestId'],
        "LogicalResourceId": event['LogicalResourceId'],
        "Data": response_data
    })
    # Content-Length must describe the bytes actually sent (the serialized
    # body), not the caller's data object. AWSRequest.prepare() keeps an
    # explicit Content-Length, so a wrong value would reach the wire.
    headers = {
        'Content-Type': '',
        'Content-Length': str(len(response_body.encode('utf-8'))),
    }
    print('[INFO] - sending request to %s' % event['ResponseURL'])
    request = AWSRequest(method="PUT", url=event['ResponseURL'], data=response_body, headers=headers)
    session = BotocoreHTTPSession()
    r = session.send(request.prepare())
    print('[INFO] - got status_code=%s' % r.status_code)
def test_destination_region_always_changed(self):
    """A user-supplied DestinationRegion is overwritten with the signer's region."""
    actual_region = 'us-west-1'

    def add_auth(request):
        request.url += '?PRESIGNED_STUFF'

    v4query_auth = mock.Mock()
    v4query_auth.add_auth = add_auth

    request_signer = mock.Mock()
    request_signer._region_name = actual_region
    request_signer.get_auth.return_value = v4query_auth

    unprepared = AWSRequest()
    unprepared.method = 'POST'
    unprepared.url = 'https://ec2.us-east-1.amazonaws.com'
    endpoint = mock.Mock()
    endpoint.create_request.return_value = unprepared.prepare()

    # The caller asks for us-east-1, but the handler must replace it with the
    # request signer's region ('us-west-1' here).
    params = {'SourceRegion': 'us-west-2', 'DestinationRegion': 'us-east-1'}
    handlers.copy_snapshot_encrypted({'body': params}, request_signer, endpoint)

    # DestinationRegion always ends up as the signer's region, whatever was passed in.
    self.assertEqual(params['DestinationRegion'], actual_region)
def post_data_to_es(payload, region, creds, host, path, method='POST', proto='https://'):
    """Send ``payload`` to an Elasticsearch endpoint using a SigV4-signed request.

    The URL is ``proto + host + path`` (path is used as-is, not quoted).
    Debug information about the request and response is printed.

    :return: the raw response body for a 2xx status.
    :raises ES_Exception: for any other status.
    """
    url = proto + host + path
    print("URL:{}".format(url))
    signed = AWSRequest(
        method=method,
        url=url,
        data=payload,
        headers={'Host': host, 'Content-Type': 'application/json'},
    )
    SigV4Auth(creds, 'es', region).add_auth(signed)
    res = BotocoreHTTPSession().send(signed.prepare())
    print("STATUS_CODE:{}".format(res.status_code))
    print("CONTENT:{}".format(res._content))
    print("ALL:{}".format(res))
    if not 200 <= res.status_code <= 299:
        raise ES_Exception(res.status_code, res._content)
    return res._content
def _get_aws_request(self, method, url, *, data=None, params=None, headers=None, service=NEPTUNE_SERVICE_NAME):
    """Build an ``AWSRequest``; when IAM is enabled, sign and prepare it.

    Returns the unprepared ``AWSRequest`` when IAM is disabled, or when
    frozen credentials cannot be obtained (a help message is printed in that
    case). Otherwise returns the SigV4-signed, prepared request.
    """
    req = AWSRequest(method=method, url=url, data=data, params=params, headers=headers)
    if not self.iam_enabled:
        return req

    credentials = self._session.get_credentials()
    try:
        frozen_creds = credentials.get_frozen_credentials()
    except AttributeError:
        # NOTE(review): presumably `credentials` is None here because no
        # provider in the chain found credentials — confirm.
        help_lines = (
            "Could not find valid IAM credentials in any the following locations:\n",
            "env, assume-role, assume-role-with-web-identity, sso, shared-credential-file, custom-process, "
            "config-file, ec2-credentials-file, boto-config, container-role, iam-role\n",
            "Go to https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html for more "
            "details on configuring your IAM credentials.",
        )
        for line in help_lines:
            print(line)
        return req

    SigV4Auth(frozen_creds, service, self.region).add_auth(req)
    return req.prepare()
def test_copy_snapshot_encrypted(self):
    """The handler builds a presigned URL for the source region and sets DestinationRegion."""
    def add_auth(request):
        request.url += '?PRESIGNED_STUFF'

    v4query_auth = mock.Mock()
    v4query_auth.add_auth = add_auth

    request_signer = mock.Mock()
    request_signer._region_name = 'us-east-1'
    request_signer.get_auth.return_value = v4query_auth

    params = {'SourceRegion': 'us-west-2'}

    unprepared = AWSRequest()
    unprepared.method = 'POST'
    unprepared.url = 'https://ec2.us-east-1.amazonaws.com'
    endpoint = mock.Mock()
    endpoint.create_request.return_value = unprepared.prepare()

    handlers.copy_snapshot_encrypted({'body': params}, request_signer, endpoint)

    expected_url = 'https://ec2.us-west-2.amazonaws.com?PRESIGNED_STUFF'
    self.assertEqual(params['PresignedUrl'], expected_url)
    # DestinationRegion is filled in from the request signer's region.
    self.assertEqual(params['DestinationRegion'], 'us-east-1')
def request(url, method, credentials, service_name, region=None, headers=None, data=None):
    """Sign a request with SigV4 and send it through a ``PreserveAuthSession``.

    :param region: signing region; falls back to the ``AWS_REGION``
        environment variable when falsy.
    :return: whatever ``PreserveAuthSession().send`` returns.
    """
    region = region or os.environ["AWS_REGION"]
    signed = AWSRequest(url=url, method=method, headers=headers, data=data)
    SigV4Auth(credentials, service_name, region).add_auth(signed)
    session = PreserveAuthSession()
    return session.send(signed.prepare())
def test_auth_header_preserved_from_s3_redirects(self):
    """PreserveAuthSession keeps the Authorization header across an S3 307 redirect."""
    original = AWSRequest()
    original.url = 'https://bucket.s3.amazonaws.com/'
    original.method = 'GET'
    original.headers['Authorization'] = 'original auth header'
    prepared = original.prepare()

    redirect = Mock()
    redirect.headers = {'location': 'https://bucket.s3-us-west-2.amazonaws.com'}
    redirect.url = original.url
    redirect.status_code = 307
    redirect.is_permanent_redirect = False
    # A None _original_response disables the cookie handling code in requests.
    redirect.raw._original_response = None

    ok = Mock()
    ok.raw._original_response = None
    ok.is_redirect = False
    ok.status_code = 200

    session = PreserveAuthSession()
    session.send = Mock(return_value=ok)
    # resolve_redirects is lazy; exhaust it so the redirected request is sent.
    list(session.resolve_redirects(redirect, prepared, stream=False))

    resent = session.send.call_args[0][0]
    # The redirected request should still carry the original Authorization header.
    self.assertEqual(resent.headers['Authorization'], 'original auth header')
def invoke(self, runtime_name, runtime_memory, payload, max_retries=10):
    """
    Invoke lambda function asynchronously

    @param runtime_name: name of the runtime
    @param runtime_memory: memory of the runtime in MB
    @param payload: invoke dict payload
    @param max_retries: number of times to retry when sending the request
        raises, before re-raising the last error
    @return: invocation ID
    """
    import time

    function_name = self._format_function_name(runtime_name, runtime_memory)
    headers = {'Host': self.host, 'X-Amz-Invocation-Type': 'Event', 'User-Agent': self.user_agent}
    url = f'https://{self.host}/2015-03-31/functions/{function_name}/invocations'
    request = AWSRequest(method="POST", url=url, data=json.dumps(payload, default=str), headers=headers)
    SigV4Auth(self.credentials, "lambda", self.region_name).add_auth(request)

    # Retry transport-level failures, but never forever: the previous
    # unbounded, silent loop could spin indefinitely on a persistent error.
    attempt = 0
    while True:
        try:
            r = self.session.send(request.prepare())
            break
        except Exception as e:
            attempt += 1
            if attempt > max_retries:
                raise
            logger.debug('Invocation of %s failed (attempt %d/%d): %s',
                         function_name, attempt, max_retries, e)
            # Back off briefly instead of hammering the endpoint in a tight loop.
            time.sleep(min(0.1 * 2 ** (attempt - 1), 5))

    if r.status_code == 202:
        return r.headers['x-amzn-RequestId']
    elif r.status_code == 401:
        logger.debug(r.text)
        raise Exception('Unauthorized - Invalid API Key')
    elif r.status_code == 404:
        logger.debug(r.text)
        raise Exception('Lithops Runtime: {} not deployed'.format(runtime_name))
    else:
        logger.debug(r.text)
        raise Exception('Error {}: {}'.format(r.status_code, r.text))
def post_data_to_es(payload, region, creds, host, path, method='POST', proto='https://'):
    '''Post data to ES endpoint with SigV4 signed http headers.

    Low-level request to Amazon Elasticsearch Service using a SigV4 signed request.

    :param payload: request body.
    :param region: region used for SigV4 signing.
    :param creds: credentials passed to SigV4Auth.
    :param host: endpoint host name (also sent as the Host header).
    :param path: request path; URL-quoted before use.
    :param method: HTTP method, 'POST' by default.
    :param proto: URL scheme prefix.
    :return: the raw response body for a 2xx status.
    :raises ESException: for any other status.
    '''
    url = proto + host + urllib.quote(path)
    req = AWSRequest(method=method, url=url, data=payload, headers={'Host': host})
    SigV4Auth(creds, 'es', region).add_auth(req)
    res = BotocoreHTTPSession().send(req.prepare())
    if not 200 <= res.status_code <= 299:
        raise ESException(res.status_code, res._content)
    return res._content
def test_can_prepare_url_params_with_existing_query(self):
    """New params are appended after a query string already present in the URL."""
    expected = 'http://example.com/?bar=foo&foo=bar'
    prepared = AWSRequest(url='http://example.com/?bar=foo', params={'foo': 'bar'}).prepare()
    self.assertEqual(prepared.url, expected)
def get_enabled_rules(self) -> List[Dict[str, Any]]:
    """Gets information for all enabled rules.

    Signs a GET to ``self.url + '/enabled'`` with query ``type=RULE`` using
    ``self.signer`` and returns the ``policies`` entry of the JSON response.

    Raises:
        requests.HTTPError: when the response has an error status.
    """
    unsigned = AWSRequest(method='GET', url=self.url + '/enabled', params={'type': 'RULE'})
    self.signer.add_auth(unsigned)
    signed = unsigned.prepare()
    # NOTE(review): no timeout is passed, so this call can block indefinitely.
    response = requests.get(signed.url, headers=signed.headers)
    response.raise_for_status()
    body = response.json()
    return body['policies']
def send_to_es(self, path, method="GET", payload={}):
    """Low-level request to Amazon Elasticsearch Service using a SigV4 signed request.

    5xx responses are retried with exponential backoff, up to
    ``cfg["es_max_retry"]`` attempts. When every attempt fails, the last
    error is raised instead of silently returning ``None``.

    Args:
        path (str): path to send to ES (a leading "/" is added if missing)
        method (str, optional): HTTP method default:GET
        payload (dict, optional): additional payload used during POST or PUT;
            JSON-encoded into the request body

    Returns:
        dict: json answer converted in dict (``None`` only when
        ``es_max_retry`` allows no attempt at all)

    Raises:
        ES_Exception: for a non-2xx, non-5xx response, or when the last
            retry of a 5xx response also fails
    """
    if not path.startswith("/"):
        path = "/" + path

    es_region = self.cfg["es_endpoint"].split(".")[1]

    headers = {
        "Host": self.cfg["es_endpoint"],
        "Content-Type": "application/json"
    }

    last_error = None
    # send to ES with exponential backoff
    retries = 0
    while retries < int(self.cfg["es_max_retry"]):
        if retries > 0:
            seconds = (2 ** retries) * .1
            time.sleep(seconds)

        req = AWSRequest(
            method=method,
            url="https://{}{}".format(self.cfg["es_endpoint"], quote(path)),
            data=json.dumps(payload),
            params={"format": "json"},
            headers=headers)
        credential_resolver = create_credential_resolver(get_session())
        credentials = credential_resolver.load_credentials()
        SigV4Auth(credentials, 'es', es_region).add_auth(req)

        preq = req.prepare()
        session = Session()
        try:
            res = session.send(preq)
        finally:
            # Release the session's connections every attempt, even if send() raises.
            session.close()

        try:
            if 200 <= res.status_code <= 299:
                return json.loads(res.content)
            raise ES_Exception(res.status_code, res._content)
        except ES_Exception as e:
            if not 500 <= e.status_code <= 599:
                raise  # Stop retrying, re-raise exception
            last_error = e  # Candidate for retry
            retries += 1

    # Retry budget exhausted on server errors: surface the failure to the caller.
    if last_error is not None:
        raise last_error
def post_data_to_es(payload, region, creds, host, path, method='POST', proto='https://'):
    '''Post data to ES endpoint with SigV4 signed http headers.

    The path is URL-quoted and the body is sent as application/json.

    :return: the raw response body for a 2xx status.
    :raises ES_Exception: for any other status.
    '''
    url = proto + host + urllib.quote(path)
    headers = {'Host': host, 'Content-Type': 'application/json'}
    req = AWSRequest(method=method, url=url, data=payload, headers=headers)
    SigV4Auth(creds, 'es', region).add_auth(req)
    res = BotocoreHTTPSession().send(req.prepare())
    if not 200 <= res.status_code <= 299:
        raise ES_Exception(res.status_code, res._content)
    return res._content
def text(self, input_text, session_attributes=None):
    """Input text will be passed to your lex bot.

    POSTs a JSON body with ``inputText`` and ``sessionAttributes`` to
    ``self.url + 'text'``, signed for the 'lex' service.

    :return: the decoded JSON response.
    """
    body = json.dumps({
        "inputText": input_text,
        "sessionAttributes": session_attributes
    })
    signed = AWSRequest(method="POST", url=self.url + 'text', data=body)
    SigV4Auth(self.creds, 'lex', self.region).add_auth(signed)
    response = self.session.send(signed.prepare())
    return response.json()
def post_data_to_opensearch(payload, region, creds, host, path, method='POST', proto='https://'):
    '''Post data to OpenSearch endpoint with SigV4 signed http headers.

    The path is URL-quoted and the body is sent as application/json.

    :return: the raw response body for a 2xx status.
    :raises Searchable_Exception: for any other status.
    '''
    url = proto + host + quote(path)
    headers = {'Host': host, 'Content-Type': 'application/json'}
    req = AWSRequest(method=method, url=url, data=payload, headers=headers)
    # NOTE(review): managed Amazon OpenSearch Service domains still use the
    # 'es' signing name; OpenSearch Serverless would need 'aoss'. Confirm target.
    SigV4Auth(creds, 'es', region).add_auth(req)
    res = BotocoreHTTPSession().send(req.prepare())
    if not 200 <= res.status_code <= 299:
        raise Searchable_Exception(res.status_code, res._content)
    return res._content
def test_request_params_not_duplicated_in_prepare(self):
    """
    add_auth moves params into the query string; request.prepare() must not
    append them to the URL a second time.
    """
    req = AWSRequest(
        method='GET',
        url='https://ec2.us-east-1.amazonaws.com/',
        params={'Action': 'MyOperation'},
    )
    self.auth.add_auth(req)
    self.assertIn('?Action=MyOperation&X-Amz', req.url)
    prepared_url = req.prepare().url
    assert not prepared_url.endswith('Action=MyOperation')
class TestAWSRequest(unittest.TestCase):
    """Tests for resetting a prepared request's body stream on redirect."""

    def setUp(self):
        self.tempdir = tempfile.mkdtemp()
        self.request = AWSRequest(url='http://example.com')
        self.prepared_request = self.request.prepare()
        self.filename = os.path.join(self.tempdir, 'foo')

    def tearDown(self):
        shutil.rmtree(self.tempdir)

    def _write_fixture(self):
        """Write a small payload to the per-test temp file."""
        with open(self.filename, 'wb') as f:
            f.write(b'foobarbaz')

    @staticmethod
    def _redirect_response():
        """Return a fake response that signals a 307 redirect."""
        response = Mock()
        response.status_code = 307
        return response

    def test_should_reset_stream(self):
        self._write_fixture()
        with open(self.filename, 'rb') as body:
            self.prepared_request.body = body
            # Simulate the body having been sent once already.
            body.read()
            # requests invokes this hook when it follows a redirect.
            self.prepared_request.reset_stream_on_redirect(self._redirect_response())
            # The stream must be rewound to the start.
            self.assertEqual(body.tell(), 0)

    def test_cannot_reset_stream_raises_error(self):
        self._write_fixture()
        with open(self.filename, 'rb') as body:
            self.prepared_request.body = Unseekable(body)
            # Simulate the body having been sent once already.
            body.read()
            redirect = self._redirect_response()
            # A non-seekable body cannot be rewound for the redirect.
            with self.assertRaises(UnseekableStreamError):
                self.prepared_request.reset_stream_on_redirect(redirect)
def es_request(payload, path, method, region="us-east-1"):
    """Send a SigV4-signed request to the Elasticsearch domain at ``HOST``.

    :param payload: JSON-serializable request body.
    :param path: path appended to ``https://{HOST}/``.
    :param method: HTTP method to send with (it was previously ignored and
        every request went out as POST).
    :param region: region used for SigV4 signing; defaults to the value that
        was previously hard-coded.
    :return: the response body decoded from JSON.
    """
    body = json.dumps(payload)
    headers = {"Host": HOST, "Content-Type": "application/json"}
    request = AWSRequest(method=method, url=f"https://{HOST}/{path}", data=body, headers=headers)
    # Credentials come from the default boto3 credential chain.
    credentials = boto3.Session().get_credentials()
    SigV4Auth(credentials, "es", region).add_auth(request)
    r = URLLib3Session().send(request.prepare())
    return json.loads(r.text)
def get_request():
    """Send a SigV4-signed GET to the API Gateway endpoint and print the response body."""
    # Only the access key pair is used. Temporary credentials would also need
    # os.environ['AWS_SESSION_TOKEN'] passed as a third argument.
    credentials = Credentials(
        os.environ['AWS_ACCESS_KEY_ID'],
        os.environ['AWS_SECRET_ACCESS_KEY'],
    )
    signer = SigV4Auth(credentials, 'execute-api', 'us-east-1')
    endpoint = 'https://g4gdlwz33m.execute-api.us-east-1.amazonaws.com/prod'
    unsigned = AWSRequest(method='GET', url=endpoint)
    signer.add_auth(unsigned)
    signed = unsigned.prepare()
    response = requests.get(signed.url, headers=signed.headers)
    print("GET Request: {}".format(response.text))
def write_http_request(self, path: str, headers) -> None:
    """Add SigV4 auth headers to the websocket opening GET, then delegate.

    This intercepts the handshake request after all of its regular headers
    have been built, signs an equivalent GET for ``natpunch_server``, copies
    the Authorization, X-Amz-Date and security-token headers into
    ``headers``, and then runs the parent implementation.
    """
    creds = Credentials(
        os.environ['AWS_ACCESS_KEY_ID'],
        os.environ['AWS_SECRET_ACCESS_KEY'],
        os.environ['AWS_SESSION_TOKEN'],
    )
    signer = SigV4Auth(creds, 'execute-api', os.environ['AWS_REGION'])
    to_sign = AWSRequest(method='GET', url='https://' + natpunch_server)
    signer.add_auth(to_sign)
    signed_headers = to_sign.prepare().headers
    for name in ('Authorization', 'X-Amz-Date', 'x-amz-security-token'):
        headers[name] = signed_headers[name]
    # Continue with the original handshake, now carrying the SigV4 headers.
    super().write_http_request(path, headers)
def content(self, data, ctype, accept, session_attributes=None):
    """This will post any content to your lex bot.

    Valid values for ctype and accept are found here:
    http://docs.aws.amazon.com/lex/latest/dg/API_PostContent.html

    :param data: raw request body.
    :param ctype: value for the content-type header.
    :param accept: value for the accept header.
    :param session_attributes: optional JSON-serializable value, sent
        base64-encoded in the x-amz-lex-session-attributes header.
    :return: the response from ``self.session.send``.
    """
    url = self.url + 'content'
    request = AWSRequest(method="POST", url=url, data=data)
    request.headers["accept"] = accept
    request.headers["content-type"] = ctype
    if session_attributes:
        # b64encode takes and returns bytes: encode the JSON text first, then
        # decode the result so the header value is a plain string.
        encoded = base64.b64encode(json.dumps(session_attributes).encode('utf-8')).decode('ascii')
        request.headers.add_header("x-amz-lex-session-attributes", encoded)
    LexContentSigV4Auth(self.creds, 'lex', self.region).add_auth(request)
    prepared = request.prepare()
    # Put the caller's original body object back; prepare() may have
    # converted it (e.g. str -> bytes).
    prepared.body = data
    return self.session.send(prepared)
def head_index_from_es(index_name, method='HEAD', proto='https://'):
    '''Check whether an index exists, using a SigV4-signed HEAD request.

    :param index_name: index name; URL-quoted into the request path.
    :param method: HTTP method, 'HEAD' by default.
    :param proto: URL scheme prefix.
    :return: True for a 2xx response, False for any other status.
    '''
    es_url = urlparse.urlparse(ES_ENDPOINT)
    # ES_ENDPOINT may be a full URL (host in .netloc) or a bare host (parsed into .path).
    es_endpoint = es_url.netloc or es_url.path
    req = AWSRequest(method=method,
                     url=proto + es_endpoint + '/' + urllib.quote(index_name),
                     headers={'Host': es_endpoint})
    # Sign for ES_REGION when it is configured, otherwise for AWS_REGION.
    es_region = ES_REGION or os.environ['AWS_REGION']
    session = Session()
    SigV4Auth(get_credentials(session), 'es', es_region).add_auth(req)
    http_session = URLLib3Session()
    res = http_session.send(req.prepare())
    # NOTE(review): any non-2xx status (including 403/5xx) is treated as
    # "index missing" — confirm that is intended.
    if 200 <= res.status_code <= 299:
        logger.info('Index %s do exists, continue update setting', index_name)
        return True
    else:
        logger.warning('Index %s does not exists, need to create.', index_name)
        return False
def lambda_handler(event, context):
    """Call the API with a SigV4-signed GET and return its status code and body."""
    logger.info("initialising API request to %s (host %s)", API_URL, API_HOST)
    # Set the Host header explicitly so the request targets the API Gateway
    # host even when API_URL is a VPC endpoint DNS name.
    unsigned = AWSRequest(method="GET", url=API_URL, headers={'host': API_HOST})

    signer = SigV4Auth(get_api_credentials(), 'execute-api', AWS_REGION)
    signer.add_auth(unsigned)
    signed = unsigned.prepare()

    logger.info("making request to url %s", signed.url)
    response = requests.get(signed.url, headers=signed.headers, timeout=API_TIMEOUT)
    logger.info("response code: %d", response.status_code)
    logger.info("response text: %s", response.text)

    return {'statusCode': response.status_code, 'body': response.text}
def post_request():
    """POST a JSON body to the API Gateway endpoint with SigV4 signing and print the response."""
    import json

    # Only the access key pair is used. Temporary credentials would also need
    # os.environ['AWS_SESSION_TOKEN'] passed as a third argument.
    credentials = Credentials(
        os.environ['AWS_ACCESS_KEY_ID'],
        os.environ['AWS_SECRET_ACCESS_KEY'],
    )
    sigv4 = SigV4Auth(credentials, 'execute-api', 'us-east-1')
    endpoint = 'https://g4gdlwz33m.execute-api.us-east-1.amazonaws.com/prod'
    # Serialize to JSON so the body matches the declared Content-Type; a dict
    # would be form-encoded ("My=body").
    data = json.dumps({"My": "body"})
    headers = {'Content-Type': 'application/json'}
    request = AWSRequest(method='POST', url=endpoint, data=data, headers=headers)
    sigv4.add_auth(request)
    prepped = request.prepare()
    # Send exactly the bytes that were signed.
    response = requests.post(prepped.url, headers=prepped.headers, data=prepped.body)
    print("POST Request: {}".format(response.text))
def put_data_to_es(payload, path, method='PUT', proto='https://'):
    '''Send data to the ES endpoint with SigV4 signed http headers.

    :param payload: request body, sent as application/json.
    :param path: path below the endpoint root; URL-quoted before use.
    :param method: HTTP method, 'PUT' by default.
    :param proto: URL scheme prefix.
    :return: the raw response body for a 2xx status.
    :raises ES_Exception: for any other status.
    '''
    es_url = urlparse.urlparse(ES_ENDPOINT)
    # ES_ENDPOINT may be a full URL (host in .netloc) or a bare host (parsed into .path).
    es_endpoint = es_url.netloc or es_url.path
    req = AWSRequest(method=method,
                     url=proto + es_endpoint + '/' + urllib.quote(path),
                     data=payload,
                     headers={
                         'Host': es_endpoint,
                         'Content-Type': 'application/json'
                     })
    # Sign for ES_REGION when it is configured, otherwise for AWS_REGION.
    es_region = ES_REGION or os.environ['AWS_REGION']
    session = Session()
    SigV4Auth(get_credentials(session), 'es', es_region).add_auth(req)
    http_session = URLLib3Session()
    res = http_session.send(req.prepare())
    if 200 <= res.status_code <= 299:
        return res._content
    raise ES_Exception(res.status_code, res._content)
def request(option):
    """Query the Elasticsearch endpoint and return the ``_source`` of every hit.

    ``option`` is JSON-encoded into the body of a GET to ``ES_ENDPOINT_URL``.
    When ``AWS_ACCESS_KEY_ID`` is set, the request is SigV4-signed for 'es'
    in ``AWS_REGION``.

    :return: list of ``_source`` values, or ``[]`` when the response has no
        ``hits.hits`` structure.
    """
    logger.debug('option:{}'.format(option))
    es_request = AWSRequest(method="GET", url=os.environ["ES_ENDPOINT_URL"], data=json.dumps(option))
    if "AWS_ACCESS_KEY_ID" in os.environ:
        # Long-lived IAM keys have no session token, so don't require it.
        credentials = Credentials(os.environ["AWS_ACCESS_KEY_ID"],
                                  os.environ["AWS_SECRET_ACCESS_KEY"],
                                  os.environ.get("AWS_SESSION_TOKEN"))
        SigV4Auth(credentials, "es", os.environ["AWS_REGION"]).add_auth(es_request)
    response = BotocoreHTTPSession().send(es_request.prepare())
    result = response.json()
    logger.debug('result:{}'.format(result))
    if "hits" in result and "hits" in result["hits"]:
        return [hit["_source"] for hit in result["hits"]["hits"]]
    return []
def __call__(self, request):
    """Return a SigV4-signed prepared copy of a ``requests`` request.

    The query string is split out into params so it is part of the
    signature. Original headers the signer did not set are copied over
    after signing.
    """
    parsed = urlparse(request.url)
    base_url = f'{parsed.scheme}://{parsed.netloc}{parsed.path}'
    query_params = dict(parse_qsl(parsed.query))
    to_sign = AWSRequest(
        method=request.method,
        url=base_url,
        data=request.body,
        params=query_params,
    )
    self.sigv4.add_auth(to_sign)
    # Checked per header so nothing already present is set a second time.
    for name, value in request.headers.items():
        if name in to_sign.headers:
            continue
        to_sign.headers[name] = value
    return to_sign.prepare()
def send_to_es(self, path, method="GET", payload={}):
    """Low-level request to Amazon Elasticsearch Service using a SigV4 signed request.

    Args:
        path (str): path to send to ES (a leading "/" is added if missing)
        method (str, optional): HTTP method default:GET
        payload (dict, optional): additional payload used during POST or PUT

    Returns:
        dict: json answer converted in dict

    Raises:
        ESException: on a non-2xx response
    """
    if not path.startswith("/"):
        path = "/" + path

    es_region = self.cfg["es_endpoint"].split(".")[1]

    req = AWSRequest(method=method,
                     url="https://%s%s?pretty&format=json" % (self.cfg["es_endpoint"], quote(path)),
                     data=payload,
                     headers={'Host': self.cfg["es_endpoint"]})
    credential_resolver = create_credential_resolver(get_session())
    credentials = credential_resolver.load_credentials()
    SigV4Auth(credentials, 'es', es_region).add_auth(req)

    preq = req.prepare()
    session = Session()
    try:
        res = session.send(preq)
    finally:
        # Close the session even if send() raises, so it is never leaked.
        session.close()

    if 200 <= res.status_code <= 299:
        return json.loads(res.content)
    raise ESException(res.status_code, res._content)
class TestAWSRequest(unittest.TestCase):
    """Tests for AWSRequest preparation and AWSPreparedRequest stream resetting."""

    def setUp(self):
        self.tempdir = tempfile.mkdtemp()
        self.filename = os.path.join(self.tempdir, 'foo')
        self.request = AWSRequest(method='GET', url='http://example.com')
        self.prepared_request = self.request.prepare()

    def tearDown(self):
        shutil.rmtree(self.tempdir)

    def _write_file(self, contents):
        """Write ``contents`` to the per-test temp file."""
        with open(self.filename, 'wb') as f:
            f.write(contents)

    @staticmethod
    def _prepare(**kwargs):
        """Prepare a fresh AWSRequest for http://example.com/ built from ``kwargs``."""
        return AWSRequest(url='http://example.com/', **kwargs).prepare()

    def _assert_no_content_length_for(self, method):
        """A body on a ``method`` request must not produce a Content-Length header."""
        self.request.method = method
        self.request.data = b'thisshouldntbehere'
        prepared = self.request.prepare()
        self.assertEqual(prepared.headers.get('Content-Length'), None)

    def _assert_reset_keeps_body(self, contents):
        """reset_stream() must leave a non-stream body unchanged."""
        self.prepared_request.body = contents
        self.prepared_request.reset_stream()
        self.assertEqual(self.prepared_request.body, contents)

    def test_prepared_request_repr(self):
        expected_repr = (
            '<AWSPreparedRequest stream_output=False, method=GET, '
            'url=http://example.com, headers={}>'
        )
        self.assertEqual(repr(self.prepared_request), expected_repr)

    def test_can_prepare_url_params(self):
        prepared = self._prepare(params={'foo': 'bar'})
        self.assertEqual(prepared.url, 'http://example.com/?foo=bar')

    def test_can_prepare_dict_body(self):
        prepared = self._prepare(data={'dead': 'beef'})
        self.assertEqual(prepared.body, 'dead=beef')

    def test_can_prepare_dict_body_unicode_values(self):
        prepared = self._prepare(data={'Text': u'\u30c6\u30b9\u30c8 string'})
        self.assertEqual(prepared.body, 'Text=%E3%83%86%E3%82%B9%E3%83%88+string')

    def test_can_prepare_dict_body_unicode_keys(self):
        prepared = self._prepare(data={u'\u30c6\u30b9\u30c8': 'string'})
        self.assertEqual(prepared.body, '%E3%83%86%E3%82%B9%E3%83%88=string')

    def test_can_prepare_empty_body(self):
        prepared = self._prepare(data=b'')
        self.assertEqual(prepared.body, None)
        self.assertEqual(prepared.headers.get('content-length'), '0')

    def test_request_body_is_prepared(self):
        unprepared = AWSRequest(url='http://example.com/', data='body')
        self.assertEqual(unprepared.body, b'body')

    def test_prepare_body_content_adds_content_length(self):
        content = b'foobarbaz'
        self._write_file(content)
        with open(self.filename, 'rb') as f:
            self.request.data = Seekable(f)
            self.request.method = 'POST'
            prepared = self.request.prepare()
        self.assertEqual(prepared.headers['Content-Length'], str(len(content)))

    def test_prepare_body_doesnt_override_content_length(self):
        self.request.method = 'PUT'
        self.request.headers['Content-Length'] = '20'
        self.request.data = b'asdf'
        prepared = self.request.prepare()
        self.assertEqual(prepared.headers['Content-Length'], '20')

    def test_prepare_body_doesnt_set_content_length_head(self):
        self._assert_no_content_length_for('HEAD')

    def test_prepare_body_doesnt_set_content_length_get(self):
        self._assert_no_content_length_for('GET')

    def test_prepare_body_doesnt_set_content_length_options(self):
        self._assert_no_content_length_for('OPTIONS')

    def test_can_reset_stream_handles_binary(self):
        self._assert_reset_keeps_body(b'notastream')

    def test_can_reset_stream_handles_bytearray(self):
        self._assert_reset_keeps_body(bytearray(b'notastream'))

    def test_can_reset_stream(self):
        self._write_file(b'foobarbaz')
        with open(self.filename, 'rb') as body:
            self.prepared_request.body = body
            # Simulate a partially sent body.
            body.read()
            self.assertNotEqual(body.tell(), 0)
            self.prepared_request.reset_stream()
            # The stream must be rewound.
            self.assertEqual(body.tell(), 0)

    def test_cannot_reset_stream_raises_error(self):
        self._write_file(b'foobarbaz')
        with open(self.filename, 'rb') as body:
            self.prepared_request.body = Unseekable(body)
            # Simulate a partially sent body.
            body.read()
            self.assertNotEqual(body.tell(), 0)
            # Rewinding a non-seekable body must fail loudly.
            with self.assertRaises(UnseekableStreamError):
                self.prepared_request.reset_stream()

    def test_duck_type_for_file_check(self):
        # Whether a body is rewindable is decided by duck typing (read/seek),
        # not by an isinstance check against file types.
        class LooksLikeFile(object):
            def __init__(self):
                self.seek_called = False

            def read(self, amount=None):
                pass

            def seek(self, where):
                self.seek_called = True

        fake_file = LooksLikeFile()
        self.prepared_request.body = fake_file
        self.prepared_request.reset_stream()
        # reset_stream should have used the object's seek().
        self.assertTrue(fake_file.seek_called)
class TestAWSRequest(unittest.TestCase):
    """Tests for AWSPreparedRequest.reset_stream_on_redirect."""

    def setUp(self):
        self.tempdir = tempfile.mkdtemp()
        self.request = AWSRequest(url='http://example.com')
        self.prepared_request = self.request.prepare()
        self.filename = os.path.join(self.tempdir, 'foo')

    def tearDown(self):
        shutil.rmtree(self.tempdir)

    def _write_fixture(self):
        """Write a small payload to the per-test temp file."""
        with open(self.filename, 'wb') as f:
            f.write(b'foobarbaz')

    @staticmethod
    def _redirect_response():
        """Return a fake response that signals a 307 redirect."""
        response = Mock()
        response.status_code = 307
        return response

    def test_should_reset_stream(self):
        self._write_fixture()
        with open(self.filename, 'rb') as body:
            self.prepared_request.body = body
            # Simulate the body having been sent once already.
            body.read()
            # requests invokes this hook when following a redirect.
            self.prepared_request.reset_stream_on_redirect(self._redirect_response())
            # The stream must be rewound to the start.
            self.assertEqual(body.tell(), 0)

    def test_cannot_reset_stream_raises_error(self):
        self._write_fixture()
        with open(self.filename, 'rb') as body:
            self.prepared_request.body = Unseekable(body)
            # Simulate the body having been sent once already.
            body.read()
            redirect = self._redirect_response()
            # A non-seekable body cannot be rewound for the redirect.
            with self.assertRaises(UnseekableStreamError):
                self.prepared_request.reset_stream_on_redirect(redirect)

    def test_duck_type_for_file_check(self):
        # Whether a body is rewindable is decided by duck typing (read/seek),
        # not by an isinstance check against file types.
        class LooksLikeFile(object):
            def __init__(self):
                self.seek_called = False

            def read(self, amount=None):
                pass

            def seek(self, where):
                self.seek_called = True

        fake_file = LooksLikeFile()
        self.prepared_request.body = fake_file
        self.prepared_request.reset_stream_on_redirect(self._redirect_response())
        # The hook should have rewound the object through its seek().
        self.assertTrue(fake_file.seek_called)
def test_can_prepare_url_params(self):
    """Dict params are encoded into the prepared URL's query string."""
    prepared = AWSRequest(url='http://example.com/', params={'foo': 'bar'}).prepare()
    self.assertEqual(prepared.url, 'http://example.com/?foo=bar')
def test_can_prepare_dict_body(self):
    """A dict body is form-encoded on prepare."""
    prepared = AWSRequest(url='http://example.com/', data={'dead': 'beef'}).prepare()
    self.assertEqual(prepared.body, 'dead=beef')
def test_can_prepare_dict_body_unicode_keys(self):
    """Non-ASCII dict keys are UTF-8 percent-encoded in the form body."""
    expected = '%E3%83%86%E3%82%B9%E3%83%88=string'
    prepared = AWSRequest(url='http://example.com/', data={u'\u30c6\u30b9\u30c8': 'string'}).prepare()
    self.assertEqual(prepared.body, expected)
def test_can_prepare_empty_body(self):
    """An empty bytes body prepares to None with a Content-Length of '0'."""
    prepared = AWSRequest(url='http://example.com/', data=b'').prepare()
    self.assertEqual(prepared.body, None)
    self.assertEqual(prepared.headers.get('content-length'), '0')
class TestURLLib3Session(unittest.TestCase):
    """Tests for ``URLLib3Session`` against mocked urllib3 managers.

    ``botocore.httpsession.PoolManager`` and ``proxy_from_url`` are patched
    so both return ``self.pool_manager``; its connection's ``urlopen``
    returns ``self.response``.
    """

    def setUp(self):
        self.request = AWSRequest(
            method='GET',
            url='http://example.com/',
            headers={},
            data=b'',
        )

        self.response = Mock()
        self.response.headers = {}
        self.response.stream.return_value = b''

        self.pool_manager = Mock()
        self.connection = Mock()
        self.connection.urlopen.return_value = self.response
        self.pool_manager.connection_from_url.return_value = self.connection

        self.pool_patch = patch('botocore.httpsession.PoolManager')
        self.proxy_patch = patch('botocore.httpsession.proxy_from_url')
        # Register each stop() as soon as its patch is active: unittest skips
        # tearDown when setUp raises, but still runs registered cleanups, so
        # a later failure here cannot leak an active patch into other tests.
        self.pool_manager_cls = self.pool_patch.start()
        self.addCleanup(self.pool_patch.stop)
        self.proxy_manager_fun = self.proxy_patch.start()
        self.addCleanup(self.proxy_patch.stop)
        self.pool_manager_cls.return_value = self.pool_manager
        self.proxy_manager_fun.return_value = self.pool_manager

    def assert_request_sent(self, headers=None, body=None, url='/'):
        """Assert ``urlopen`` was called exactly once with these arguments."""
        if headers is None:
            headers = {}

        self.connection.urlopen.assert_called_once_with(
            method=self.request.method,
            url=url,
            body=body,
            headers=headers,
            retries=False,
            assert_same_host=False,
            preload_content=False,
            decode_content=False,
        )

    def _assert_manager_call(self, manager, *assert_args, **assert_kwargs):
        """Assert the last call to ``manager`` used the default manager
        kwargs, with any entries in ``assert_kwargs`` overriding them."""
        call_kwargs = {
            'strict': True,
            'maxsize': ANY,
            'timeout': ANY,
            'ssl_context': ANY,
            'socket_options': [],
            'cert_file': None,
            'key_file': None,
        }
        call_kwargs.update(assert_kwargs)
        manager.assert_called_with(*assert_args, **call_kwargs)

    def assert_pool_manager_call(self, *args, **kwargs):
        self._assert_manager_call(self.pool_manager_cls, *args, **kwargs)

    def assert_proxy_manager_call(self, *args, **kwargs):
        self._assert_manager_call(self.proxy_manager_fun, *args, **kwargs)

    def test_forwards_max_pool_size(self):
        URLLib3Session(max_pool_connections=22)
        self.assert_pool_manager_call(maxsize=22)

    def test_forwards_client_cert(self):
        URLLib3Session(client_cert='/some/cert')
        self.assert_pool_manager_call(cert_file='/some/cert', key_file=None)

    def test_forwards_client_cert_and_key_tuple(self):
        cert = ('/some/cert', '/some/key')
        URLLib3Session(client_cert=cert)
        self.assert_pool_manager_call(cert_file=cert[0], key_file=cert[1])

    def test_basic_https_proxy_with_client_cert(self):
        proxies = {'https': 'http://proxy.com'}
        session = URLLib3Session(proxies=proxies, client_cert='/some/cert')
        self.request.url = 'https://example.com/'
        session.send(self.request.prepare())
        self.assert_proxy_manager_call(
            proxies['https'],
            proxy_headers={},
            cert_file='/some/cert',
            key_file=None,
        )
        self.assert_request_sent()

    def test_basic_https_proxy_with_client_cert_and_key(self):
        cert = ('/some/cert', '/some/key')
        proxies = {'https': 'http://proxy.com'}
        session = URLLib3Session(proxies=proxies, client_cert=cert)
        self.request.url = 'https://example.com/'
        session.send(self.request.prepare())
        self.assert_proxy_manager_call(
            proxies['https'],
            proxy_headers={},
            cert_file=cert[0],
            key_file=cert[1],
        )
        self.assert_request_sent()

    def test_basic_request(self):
        session = URLLib3Session()
        session.send(self.request.prepare())
        self.assert_request_sent()
        self.response.stream.assert_called_once_with()

    def test_basic_streaming_request(self):
        session = URLLib3Session()
        self.request.stream_output = True
        session.send(self.request.prepare())
        self.assert_request_sent()
        self.response.stream.assert_not_called()

    def test_basic_https_request(self):
        session = URLLib3Session()
        self.request.url = 'https://example.com/'
        session.send(self.request.prepare())
        self.assert_request_sent()

    def test_basic_https_proxy_request(self):
        proxies = {'https': 'http://proxy.com'}
        session = URLLib3Session(proxies=proxies)
        self.request.url = 'https://example.com/'
        session.send(self.request.prepare())
        self.assert_proxy_manager_call(proxies['https'], proxy_headers={})
        self.assert_request_sent()

    def test_basic_proxy_request_caches_manager(self):
        proxies = {'https': 'http://proxy.com'}
        session = URLLib3Session(proxies=proxies)
        self.request.url = 'https://example.com/'
        session.send(self.request.prepare())
        # assert we created the proxy manager
        self.assert_proxy_manager_call(proxies['https'], proxy_headers={})
        session.send(self.request.prepare())
        # assert that we did not create another proxy manager
        self.assertEqual(self.proxy_manager_fun.call_count, 1)

    def test_basic_http_proxy_request(self):
        proxies = {'http': 'http://proxy.com'}
        session = URLLib3Session(proxies=proxies)
        session.send(self.request.prepare())
        self.assert_proxy_manager_call(proxies['http'], proxy_headers={})
        # Plain-HTTP requests through a proxy are sent with the absolute URL.
        self.assert_request_sent(url=self.request.url)

    def test_ssl_context_is_explicit(self):
        session = URLLib3Session()
        session.send(self.request.prepare())
        _, manager_kwargs = self.pool_manager_cls.call_args
        self.assertIsNotNone(manager_kwargs.get('ssl_context'))

    def test_proxy_request_ssl_context_is_explicit(self):
        proxies = {'http': 'http://proxy.com'}
        session = URLLib3Session(proxies=proxies)
        session.send(self.request.prepare())
        _, proxy_kwargs = self.proxy_manager_fun.call_args
        self.assertIsNotNone(proxy_kwargs.get('ssl_context'))

    def test_session_forwards_socket_options_to_pool_manager(self):
        socket_options = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)]
        URLLib3Session(socket_options=socket_options)
        self.assert_pool_manager_call(socket_options=socket_options)

    def test_session_forwards_socket_options_to_proxy_manager(self):
        proxies = {'http': 'http://proxy.com'}
        socket_options = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)]
        session = URLLib3Session(
            proxies=proxies,
            socket_options=socket_options,
        )
        session.send(self.request.prepare())
        self.assert_proxy_manager_call(
            proxies['http'],
            proxy_headers={},
            socket_options=socket_options,
        )

    def make_request_with_error(self, error):
        """Send ``self.request`` through a session whose ``urlopen``
        raises ``error``."""
        self.connection.urlopen.side_effect = error
        session = URLLib3Session()
        session.send(self.request.prepare())

    def test_catches_new_connection_error(self):
        error = NewConnectionError(None, None)
        # Use unittest's native assertion rather than nose's @raises
        # decorator; it also scopes the expectation to the failing call.
        with self.assertRaises(EndpointConnectionError):
            self.make_request_with_error(error)

    def test_catches_bad_status_line(self):
        error = ProtocolError(None)
        with self.assertRaises(ConnectionClosedError):
            self.make_request_with_error(error)

    def test_aws_connection_classes_are_used(self):
        # Only construction matters here; the session object itself is not
        # needed, we inspect what it configured on the mocked pool manager.
        URLLib3Session()
        # ensure the pool manager is using the correct classes
        http_class = self.pool_manager.pool_classes_by_scheme.get('http')
        self.assertIs(http_class, AWSHTTPConnectionPool)
        https_class = self.pool_manager.pool_classes_by_scheme.get('https')
        self.assertIs(https_class, AWSHTTPSConnectionPool)