import hashlib

import requests

from aws_request_signer import UNSIGNED_PAYLOAD, AwsRequestSigner


class AwsAuth(requests.auth.AuthBase):
    def __init__(self, region: str, access_key_id: str,
                 secret_access_key: str, service: str) -> None:
        """
        Initialize the authentication helper for requests.

        Use this with the auth argument of the requests methods, or
        assign it to a session's auth property.

        :param region: The AWS region to connect to.
        :param access_key_id: The AWS access key id to use for authentication.
        :param secret_access_key: The AWS secret access key to use for authentication.
        :param service: The service to connect to (e.g. `'s3'`).
        """
        self.request_signer = AwsRequestSigner(
            region, access_key_id, secret_access_key, service)

    def __call__(
            self, request: requests.PreparedRequest) -> requests.PreparedRequest:
        if isinstance(request.body, bytes):
            content_hash = hashlib.sha256(request.body).hexdigest()
        else:
            content_hash = UNSIGNED_PAYLOAD

        assert isinstance(request.method, str)
        assert isinstance(request.url, str)

        auth_headers = self.request_signer.sign_with_headers(
            request.method, request.url, request.headers, content_hash)
        request.headers.update(auth_headers)
        return request
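# A minimal usage sketch for the AwsAuth helper above. The endpoint and the
# credentials are placeholders, not values from the original snippet.
auth = AwsAuth("us-east-1", "EXAMPLE_KEY_ID", "EXAMPLE_SECRET", "s3")

# Per request, via the `auth` argument of the requests methods ...
response = requests.put(
    "http://127.0.0.1:9000/demo/hello.txt",
    auth=auth,
    data=b"Hello, World!\n",
)
response.raise_for_status()

# ... or for every request, by assigning it to a session's auth property.
session = requests.Session()
session.auth = auth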
    async def __get_data(self, endpoint: str, params: dict = None) -> Optional[dict]:
        """Get data from api."""
        if params is None:
            params = {}

        url = BASE_URL + endpoint
        headers = BASE_HEADERS.copy()

        # sign the request
        creds = await self.__get_credentials()
        request_signer = AwsRequestSigner(
            AWS_REGION, creds["access_key"], creds["secret_key"], AWS_SERVICE)
        headers.update(request_signer.sign_with_headers("GET", url, headers))
        headers["X-Amz-Security-Token"] = creds["session_token"]

        async with self._http_session.get(
                url, headers=headers, params=params, verify_ssl=False) as response:
            if response.status != 200:
                error_msg = await response.text()
                raise Exception(
                    "Error while retrieving data for endpoint %s: %s"
                    % (endpoint, error_msg))
            return await response.json()
class S3Client:
    def __init__(
        self,
        session: ClientSession,
        url: URL,
        secret_access_key: str = None,
        access_key_id: str = None,
        region: str = "",
    ):
        access_key_id = access_key_id or url.user
        secret_access_key = secret_access_key or url.password

        if not access_key_id:
            raise ValueError(
                "access_key id must be passed as argument "
                "or as username in the url",
            )
        if not secret_access_key:
            raise ValueError(
                "secret_access_key id must be passed as argument "
                "or as username in the url",
            )

        self._url = URL(url).with_user(None).with_password(None)
        self._session = session
        self._signer = AwsRequestSigner(
            region=region,
            service="s3",
            access_key_id=access_key_id,
            secret_access_key=secret_access_key,
        )

    @property
    def url(self):
        return self._url

    def request(
        self,
        method: str,
        path: str,
        headers: HeadersType = None,
        params: ParamsType = None,
        data: t.Optional[DataType] = None,
        data_length: t.Optional[int] = None,
        content_sha256: str = None,
        **kwargs,
    ) -> RequestContextManager:
        if isinstance(data, bytes):
            data_length = len(data)
        elif isinstance(data, str):
            data = data.encode()
            data_length = len(data)

        headers = self._prepare_headers(headers)
        if data_length:
            headers[hdrs.CONTENT_LENGTH] = str(data_length)
        elif data is not None:
            kwargs["chunked"] = True

        if data is not None and content_sha256 is None:
            content_sha256 = UNSIGNED_PAYLOAD

        url = (self._url / path.lstrip("/"))
        url = url.with_path(quote(url.path), encoded=True).with_query(params)

        headers = self._make_headers(headers)
        headers.extend(
            self._signer.sign_with_headers(
                method,
                str(url),
                headers=headers,
                content_hash=content_sha256,
            ),
        )
        return self._session.request(
            method, url, headers=headers, data=data, **kwargs,
        )

    def get(self, object_name: str, **kwargs) -> RequestContextManager:
        return self.request("GET", object_name, **kwargs)

    def head(
        self,
        object_name: str,
        content_sha256=EMPTY_STR_HASH,
        **kwargs,
    ) -> RequestContextManager:
        return self.request(
            "HEAD", object_name, content_sha256=content_sha256, **kwargs,
        )

    def delete(
        self,
        object_name: str,
        content_sha256=EMPTY_STR_HASH,
        **kwargs,
    ) -> RequestContextManager:
        return self.request(
            "DELETE", object_name, content_sha256=content_sha256, **kwargs,
        )

    @staticmethod
    def _make_headers(headers: t.Optional[HeadersType]) -> CIMultiDict:
        headers = CIMultiDict(headers or {})
        return headers

    def _prepare_headers(
        self,
        headers: t.Optional[HeadersType],
        file_path: str = "",
    ) -> CIMultiDict:
        headers = self._make_headers(headers)

        if hdrs.CONTENT_TYPE not in headers:
            content_type = guess_type(file_path)[0]
            if content_type is None:
                content_type = "application/octet-stream"
            headers[hdrs.CONTENT_TYPE] = content_type

        return headers

    def put(
        self,
        object_name: str,
        data: t.Union[bytes, str, t.AsyncIterable[bytes]],
        **kwargs,
    ) -> RequestContextManager:
        return self.request("PUT", object_name, data=data, **kwargs)

    def post(
        self,
        object_name: str,
        data: t.Union[None, bytes, str, t.AsyncIterable[bytes]] = None,
        **kwargs,
    ):
        return self.request("POST", object_name, data=data, **kwargs)

    def put_file(
        self,
        object_name: t.Union[str, Path],
        file_path: t.Union[str, Path],
        *,
        headers: HeadersType = None,
        chunk_size: int = CHUNK_SIZE,
        content_sha256: str = None,
    ) -> RequestContextManager:
        headers = self._prepare_headers(headers, str(file_path))
        return self.put(
            str(object_name),
            headers=headers,
            data=async_file_sender(
                file_path,
                chunk_size=chunk_size,
            ),
            data_length=os.stat(file_path).st_size,
            content_sha256=content_sha256,
        )

    @asyncbackoff(
        None, None, 0,
        max_tries=3,
        exceptions=(ClientError,),
    )
    async def _create_multipart_upload(
        self,
        object_name: str,
        headers: HeadersType = None,
    ) -> str:
        async with self.post(
            object_name,
            headers=headers,
            params={"uploads": 1},
            content_sha256=EMPTY_STR_HASH,
        ) as resp:
            payload = await resp.read()
            if resp.status != HTTPStatus.OK:
                raise AwsUploadError(
                    f"Wrong status code {resp.status} from s3 with message "
                    f"{payload.decode()}.",
                )
            return parse_create_multipart_upload_id(payload)

    @asyncbackoff(
        None, None, 0,
        max_tries=3,
        exceptions=(AwsUploadError, ClientError),
    )
    async def _complete_multipart_upload(
        self,
        upload_id: str,
        object_name: str,
        parts: t.List[t.Tuple[int, str]],
    ):
        complete_upload_request = create_complete_upload_request(parts)
        async with self.post(
            object_name,
            headers={"Content-Type": "text/xml"},
            params={"uploadId": upload_id},
            data=complete_upload_request,
            content_sha256=hashlib.sha256(complete_upload_request).hexdigest(),
        ) as resp:
            if resp.status != HTTPStatus.OK:
                payload = await resp.text()
                raise AwsUploadError(
                    f"Wrong status code {resp.status} from s3 with message "
                    f"{payload}.",
                )

    async def _put_part(
        self,
        upload_id: str,
        object_name: str,
        part_no: int,
        data: bytes,
        content_sha256: str,
        **kwargs,
    ) -> str:
        async with self.put(
            object_name,
            params={"partNumber": part_no, "uploadId": upload_id},
            data=data,
            content_sha256=content_sha256,
            **kwargs,
        ) as resp:
            payload = await resp.text()
            if resp.status != HTTPStatus.OK:
                raise AwsUploadError(
                    f"Wrong status code {resp.status} from s3 with message "
                    f"{payload}.",
                )
            return resp.headers["Etag"].strip('"')

    async def _part_uploader(
        self,
        upload_id: str,
        object_name: str,
        parts_queue: asyncio.Queue,
        results_queue: deque,
        part_upload_tries: int,
        **kwargs,
    ):
        backoff = asyncbackoff(
            None, None,
            max_tries=part_upload_tries,
            exceptions=(ClientError,),
        )
        while True:
            msg = await parts_queue.get()
            if msg is DONE:
                break
            part_no, part_hash, part = msg
            etag = await backoff(self._put_part)(  # type: ignore
                upload_id=upload_id,
                object_name=object_name,
                part_no=part_no,
                data=part,
                content_sha256=part_hash,
                **kwargs,
            )
            log.debug(
                "Etag for part %d of %s is %s", part_no, upload_id, etag,
            )
            results_queue.append((part_no, etag))

    async def put_file_multipart(
        self,
        object_name: t.Union[str, Path],
        file_path: t.Union[str, Path],
        *,
        headers: HeadersType = None,
        part_size: int = PART_SIZE,
        workers_count: int = 1,
        max_size: t.Optional[int] = None,
        part_upload_tries: int = 3,
        calculate_content_sha256: bool = True,
        **kwargs,
    ):
        """
        Upload data from a file with multipart upload.

        object_name: key in s3
        file_path: path to a file for upload
        headers: additional headers, such as Content-Type
        part_size: size of a chunk to send (recommended: >5Mb)
        workers_count: count of coroutines for asynchronous parts uploading
        max_size: maximum size of a queue with data to send
            (should be at least `workers_count`)
        part_upload_tries: how many times to try to put a part to s3 before failing
        calculate_content_sha256: whether to calculate the sha256 hash of a part
            for integrity purposes
        """
        log.debug(
            "Going to multipart upload %s to %s with part size %d",
            file_path,
            object_name,
            part_size,
        )
        await self.put_multipart(
            object_name,
            file_sender(
                file_path,
                chunk_size=part_size,
            ),
            headers=headers,
            workers_count=workers_count,
            max_size=max_size,
            part_upload_tries=part_upload_tries,
            calculate_content_sha256=calculate_content_sha256,
            **kwargs,
        )

    async def _parts_generator(
        self,
        gen,
        workers_count: int,
        parts_queue: asyncio.Queue,
    ) -> int:
        part_no = 1
        async with gen:
            async for part_hash, part in gen:
                log.debug(
                    "Reading part %d (%d bytes)",
                    part_no,
                    len(part),
                )
                await parts_queue.put((part_no, part_hash, part))
                part_no += 1

        for _ in range(workers_count):
            await parts_queue.put(DONE)

        return part_no

    async def put_multipart(
        self,
        object_name: t.Union[str, Path],
        data: t.Iterable[bytes],
        *,
        headers: HeadersType = None,
        workers_count: int = 1,
        max_size: t.Optional[int] = None,
        part_upload_tries: int = 3,
        calculate_content_sha256: bool = True,
        **kwargs,
    ):
        """
        Send data from an iterable with multipart upload.

        object_name: key in s3
        data: any iterable that returns chunks of bytes
        headers: additional headers, such as Content-Type
        workers_count: count of coroutines for asynchronous parts uploading
        max_size: maximum size of a queue with data to send
            (should be at least `workers_count`)
        part_upload_tries: how many times to try to put a part to s3 before failing
        calculate_content_sha256: whether to calculate the sha256 hash of a part
            for integrity purposes
        """
        if workers_count < 1:
            raise ValueError(
                f"Workers count should be > 0. Got {workers_count}",
            )
        max_size = max_size or workers_count

        upload_id = await self._create_multipart_upload(  # type: ignore
            str(object_name),
            headers=headers,
        )
        log.debug("Got upload id %s for %s", upload_id, object_name)

        parts_queue: asyncio.Queue = asyncio.Queue(maxsize=max_size)
        results_queue: deque = deque()
        workers = [
            asyncio.create_task(
                self._part_uploader(
                    upload_id,
                    str(object_name),
                    parts_queue,
                    results_queue,
                    part_upload_tries,
                    **kwargs,
                ),
            )
            for _ in range(workers_count)
        ]

        if calculate_content_sha256:
            gen = gen_with_hash(data)
        else:
            gen = gen_without_hash(data)
        parts_generator = asyncio.create_task(
            self._parts_generator(gen, workers_count, parts_queue),
        )
        try:
            part_no, *_ = await asyncio.gather(
                parts_generator,
                *workers,
            )
        except Exception:
            for task in chain([parts_generator], workers):
                if not task.done():
                    task.cancel()
            raise

        log.debug(
            "All parts (#%d) of %s are uploaded to %s",
            part_no - 1,
            upload_id,
            object_name,
        )

        # Parts should be in ascending order
        parts = sorted(results_queue, key=lambda x: x[0])
        await self._complete_multipart_upload(  # type: ignore
            upload_id, object_name, parts,
        )

    async def _download_range(
        self,
        object_name: str,
        writer: t.Callable[[bytes, int, int], t.Coroutine],
        *,
        etag: str,
        pos: int,
        range_start: int,
        req_range_start: int,
        req_range_end: int,
        buffer_size: int,
        headers: HeadersType = None,
        **kwargs,
    ):
        """
        Download the range [req_range_start:req_range_end] to `file`.
        """
        log.debug(
            "Downloading %s from %d to %d",
            object_name,
            req_range_start,
            req_range_end,
        )
        if not headers:
            headers = {}
        headers = headers.copy()
        headers["Range"] = f"bytes={req_range_start}-{req_range_end}"
        headers["If-Match"] = etag

        pos = req_range_start
        async with self.get(object_name, headers=headers, **kwargs) as resp:
            if resp.status not in (HTTPStatus.PARTIAL_CONTENT, HTTPStatus.OK):
                raise AwsDownloadError(
                    f"Got wrong status code {resp.status} on range download "
                    f"of {object_name}",
                )
            while True:
                chunk = await resp.content.read(buffer_size)
                if not chunk:
                    break
                await writer(chunk, range_start, pos)
                pos += len(chunk)

    async def _download_worker(
        self,
        object_name: str,
        writer: t.Callable[[bytes, int, int], t.Coroutine],
        *,
        etag: str,
        range_step: int,
        range_start: int,
        range_end: int,
        buffer_size: int,
        range_get_tries: int = 3,
        headers: HeadersType = None,
        **kwargs,
    ):
        """
        Downloads data in range `[range_start, range_end)` with step
        `range_step` to file `file_path`. Uses `etag` to make sure that
        the file wasn't changed in the process.
""" log.debug( "Starting download worker for range [%d:%d]", range_start, range_end, ) backoff = asyncbackoff( None, None, max_tries=range_get_tries, exceptions=(ClientError, ), ) req_range_end = range_start for req_range_start in range(range_start, range_end, range_step): req_range_end += range_step if req_range_end > range_end: req_range_end = range_end await backoff(self._download_range)( # type: ignore object_name, # type: ignore writer, etag=etag, pos=(req_range_start - range_start), range_start=range_start, req_range_start=req_range_start, req_range_end=req_range_end - 1, buffer_size=buffer_size, headers=headers, **kwargs, ) async def get_file_parallel( self, object_name: t.Union[str, Path], file_path: t.Union[str, Path], *, headers: HeadersType = None, range_step: int = PART_SIZE, workers_count: int = 1, range_get_tries: int = 3, buffer_size: int = PAGESIZE * 32, **kwargs, ): """ Download object in parallel with requests with Range. If file will change while download is in progress - error will be raised. object_name: s3 key to download file_path: target file path headers: additional headers range_step: how much data will be downloaded in single HTTP request workers_count: count of parallel workers range_get_tries: count of tries to download each range buffer_size: size of a buffer for on the fly data """ file_path = Path(file_path) async with self.head(str(object_name)) as resp: if resp.status != HTTPStatus.OK: raise AwsDownloadError( f"Got response for HEAD request for {object_name}" f"of a wrong status {resp.status}", ) etag = resp.headers["Etag"].strip('"') file_size = int(resp.headers["Content-Length"]) log.debug( "Object's %s etag is %s and size is %d", object_name, etag, file_size, ) workers = [] files = [] worker_range_size = file_size // workers_count range_end = 0 try: with file_path.open("w+b") as fp: for range_start in range(0, file_size, worker_range_size): range_end += worker_range_size if range_end > file_size: range_end = file_size if hasattr(os, "pwrite"): writer = partial(pwrite_absolute_pos, fp.fileno()) else: if range_start: file = NamedTemporaryFile(dir=file_path.parent) files.append(file) else: file = fp writer = partial(write_from_start, file) workers.append( self._download_worker( str(object_name), writer, # type: ignore etag=etag, range_step=range_step, range_start=range_start, range_end=range_end, range_get_tries=range_get_tries, buffer_size=buffer_size, **kwargs, ), ) await asyncio.gather(*workers) if files: # First part already in `file_path` log.debug("Joining %d parts to %s", len(files) + 1, file_path) await concat_files( file_path, files, buffer_size=buffer_size, ) except Exception: log.exception( "Error on file download. Removing possibly incomplete file %s", file_path, ) with suppress(FileNotFoundError): os.unlink(file_path) raise
import http.client
from urllib.parse import quote, urlparse

from aws_request_signer import UNSIGNED_PAYLOAD, AwsRequestSigner


class S3FileUpload:
    def __init__(self, awsRegion, awsAccessKey, awsSecretKey, s3EndPoint, s3Bucket):
        if not s3EndPoint.startswith('http://'):
            raise ValueError('s3EndPoint must start with http://')
        parsedURL = urlparse(s3EndPoint)

        self.requestSigner = AwsRequestSigner(
            awsRegion, awsAccessKey, awsSecretKey, 's3')
        self.conn = http.client.HTTPConnection(
            host=parsedURL.hostname, port=parsedURL.port or 80)
        self.pathBase = '/' + quote(s3Bucket)
        self.urlBase = s3EndPoint + self.pathBase
        self.readyForNewFile = True
        self.totalBytesUploaded = 0
        self.bytesUploaded = 0
        self.filesUploaded = 0
        self.fileSize = None

    def startFileSend(self, key, fileSize):
        if not self.readyForNewFile:
            self.abort()

        headers = {
            "Content-Type": "application/octet-stream",
            "Content-Length": str(fileSize)
        }
        headers.update(
            self.requestSigner.sign_with_headers(
                "PUT", self.urlBase + quote(key), headers,
                content_hash=UNSIGNED_PAYLOAD))

        self.bytesUploaded = 0
        self.fileSize = fileSize
        self.readyForNewFile = False

        # Will auto-open a socket to the s3 server if one is not open.
        self.conn.putrequest(method='PUT', url=self.pathBase + quote(key))
        for k, v in headers.items():
            self.conn.putheader(k, v)
        self.conn.endheaders()

    def sendFileData(self, data):
        if self.readyForNewFile:
            raise http.client.HTTPException(
                'Must call startFileSend before sending data')
        if data and len(data) > 0:
            self.conn.send(data)
            self.bytesUploaded += len(data)
            self.totalBytesUploaded += len(data)

    def endFileSend(self):
        # print('size =', self.fileSize, '; uploaded =', self.bytesUploaded)
        if self.bytesUploaded < self.fileSize:
            self.abort()
            raise http.client.HTTPException('File send was aborted')
        res = self.conn.getresponse()
        respBody = res.read()
        self.readyForNewFile = True
        if res.status < 200 or res.status > 299:
            raise http.client.HTTPException(
                'status: ' + str(res.status) + ': ' + str(respBody))
        self.filesUploaded += 1

    def abort(self):
        self.conn.close()
        self.readyForNewFile = True
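# A minimal sketch of driving S3FileUpload above. The endpoint, bucket and
# credentials are placeholders; the key is passed with a leading slash so it
# is appended to the bucket path exactly as startFileSend builds it.
uploader = S3FileUpload(
    'us-east-1', 'EXAMPLE_KEY_ID', 'EXAMPLE_SECRET',
    'http://127.0.0.1:9000', 'demo')

payload = b'Hello, World!\n'
uploader.startFileSend('/hello.txt', len(payload))
uploader.sendFileData(payload)  # may be called repeatedly with chunks
uploader.endFileSend()          # raises if fewer bytes were sent than announced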
class AIOHttpConnection(AsyncConnection):
    HTTP_CLIENT_META = ("ai", _client_meta_version(aiohttp.__version__))

    def __init__(
        self,
        host="localhost",
        port=None,
        http_auth=None,
        use_ssl=False,
        verify_certs=VERIFY_CERTS_DEFAULT,
        ssl_show_warn=SSL_SHOW_WARN_DEFAULT,
        ca_certs=None,
        client_cert=None,
        client_key=None,
        ssl_version=None,
        ssl_assert_fingerprint=None,
        maxsize=10,
        headers=None,
        ssl_context=None,
        http_compress=None,
        cloud_id=None,
        api_key=None,
        opaque_id=None,
        loop=None,
        **kwargs,
    ):
        """
        Default connection class for ``AsyncElasticsearch`` using the
        `aiohttp` library and the http protocol.

        :arg host: hostname of the node (default: localhost)
        :arg port: port to use (integer, default: 9200)
        :arg url_prefix: optional url prefix for elasticsearch
        :arg timeout: default timeout in seconds (float, default: 10)
        :arg http_auth: optional http auth information as either ':' separated
            string or a tuple
        :arg use_ssl: use ssl for the connection if `True`
        :arg verify_certs: whether to verify SSL certificates
        :arg ssl_show_warn: show warning when verify certs is disabled
        :arg ca_certs: optional path to CA bundle.
            See https://urllib3.readthedocs.io/en/latest/security.html#using-certifi-with-urllib3
            for instructions how to get default set
        :arg client_cert: path to the file containing the private key and the
            certificate, or cert only if using client_key
        :arg client_key: path to the file containing the private key if using
            separate cert and key files (client_cert will contain only the cert)
        :arg ssl_version: version of the SSL protocol to use. Choices are:
            SSLv23 (default) SSLv2 SSLv3 TLSv1 (see ``PROTOCOL_*`` constants in
            the ``ssl`` module for exact options for your environment).
        :arg ssl_assert_hostname: use hostname verification if not `False`
        :arg ssl_assert_fingerprint: verify the supplied certificate fingerprint
            if not `None`
        :arg maxsize: the number of connections which will be kept open to this
            host. See https://urllib3.readthedocs.io/en/1.4/pools.html#api for
            more information.
        :arg headers: any custom http headers to be added to requests
        :arg http_compress: Use gzip compression
        :arg cloud_id: The Cloud ID from ElasticCloud. Convenient way to connect
            to cloud instances. Other host connection params will be ignored.
        :arg api_key: optional API Key authentication as either base64 encoded
            string or a tuple.
        :arg opaque_id: Send this value in the 'X-Opaque-Id' HTTP header
            for tracing all requests made by this transport.
        :arg loop: asyncio Event Loop to use with aiohttp. This is set by
            default to the currently running loop.
        """
        self.headers = {}

        super().__init__(
            host=host,
            port=port,
            use_ssl=use_ssl,
            headers=headers,
            http_compress=http_compress,
            cloud_id=cloud_id,
            api_key=api_key,
            opaque_id=opaque_id,
            **kwargs,
        )

        if http_auth is not None:
            if isinstance(http_auth, (tuple, list)):
                http_auth = tuple(http_auth)
            self.session.auth = http_auth

        # if providing an SSL context, raise error if any other SSL related flag is used
        if ssl_context and (
            (verify_certs is not VERIFY_CERTS_DEFAULT)
            or (ssl_show_warn is not SSL_SHOW_WARN_DEFAULT)
            or ca_certs
            or client_cert
            or client_key
            or ssl_version
        ):
            warnings.warn(
                "When using `ssl_context`, all other SSL related kwargs are ignored"
            )

        self.ssl_assert_fingerprint = ssl_assert_fingerprint
        if self.use_ssl and ssl_context is None:
            ssl_context = ssl.SSLContext(ssl_version or ssl.PROTOCOL_TLS)

            # Convert all sentinel values to their actual default
            # values if not using an SSLContext.
            if verify_certs is VERIFY_CERTS_DEFAULT:
                verify_certs = True
            if ssl_show_warn is SSL_SHOW_WARN_DEFAULT:
                ssl_show_warn = True

            if verify_certs:
                ssl_context.verify_mode = ssl.CERT_REQUIRED
                ssl_context.check_hostname = True
            else:
                ssl_context.verify_mode = ssl.CERT_NONE
                ssl_context.check_hostname = False

            ca_certs = CA_CERTS if ca_certs is None else ca_certs
            if verify_certs:
                if not ca_certs:
                    raise ImproperlyConfigured(
                        "Root certificates are missing for certificate "
                        "validation. Either pass them in using the ca_certs parameter or "
                        "install certifi to use it automatically.")
            else:
                if ssl_show_warn:
                    warnings.warn(
                        "Connecting to %s using SSL with verify_certs=False is insecure."
                        % self.host)

            if os.path.isfile(ca_certs):
                ssl_context.load_verify_locations(cafile=ca_certs)
            elif os.path.isdir(ca_certs):
                ssl_context.load_verify_locations(capath=ca_certs)
            else:
                raise ImproperlyConfigured("ca_certs parameter is not a path")

            # Use client_cert and client_key variables for SSL certificate configuration.
            if client_cert and not os.path.isfile(client_cert):
                raise ImproperlyConfigured("client_cert is not a path to a file")
            if client_key and not os.path.isfile(client_key):
                raise ImproperlyConfigured("client_key is not a path to a file")
            if client_cert and client_key:
                ssl_context.load_cert_chain(client_cert, client_key)
            elif client_cert:
                ssl_context.load_cert_chain(client_cert)

        self.headers.setdefault("connection", "keep-alive")
        self.loop = loop
        self.session = None

        # Parameters for creating an aiohttp.ClientSession later.
        self._limit = maxsize
        self._http_auth = http_auth
        self._ssl_context = ssl_context

        access_key = kwargs.get('access_key', None)
        secret_key = kwargs.get('secret_key', None)
        service = kwargs.get('service', None)
        region = kwargs.get('region', None)
        self.request_signer = AwsRequestSigner(region, access_key, secret_key, service)

    async def perform_request(
        self, method, url, params=None, body=None, timeout=None, ignore=(), headers=None
    ):
        if self.session is None:
            await self._create_aiohttp_session()
        assert self.session is not None

        orig_body = body
        url_path = self.url_prefix + url
        if params:
            query_string = urlencode(params)
        else:
            query_string = ""

        # There is a bug in aiohttp that disables the re-use
        # of the connection in the pool when method=HEAD.
        # See: aio-libs/aiohttp#1769
        is_head = False
        if method == "HEAD":
            method = "GET"
            is_head = True

        # Top-tier tip-toeing happening here. Basically
        # because Pip's old resolver is bad and wipes out
        # strict pins in favor of non-strict pins of extras
        # our [async] extra overrides aiohttp's pin of
        # yarl. yarl released breaking changes, aiohttp pinned
        # defensively afterwards, but our users don't get
        # that nice pin that aiohttp set. :( So to play around
        # this super-defensively we try to import yarl, if we can't
        # then we pass a string into ClientSession.request() instead.
        if yarl:
            # Provide correct URL object to avoid string parsing in low-level code
            url = yarl.URL.build(
                scheme=self.scheme,
                host=self.hostname,
                port=self.port,
                path=url_path,
                query_string=query_string,
                encoded=True,
            )
        else:
            url = self.url_prefix + url
            if query_string:
                url = "%s?%s" % (url, query_string)
            url = self.host + url

        timeout = aiohttp.ClientTimeout(
            total=timeout if timeout is not None else self.timeout)

        req_headers = self.headers.copy()
        if headers:
            req_headers.update(headers)

        if self.http_compress and body:
            body = self._gzip_compress(body)
            req_headers["content-encoding"] = "gzip"

        content_hash = (
            hashlib.sha256(json.dumps(body).encode()).hexdigest()
            if isinstance(body, dict)
            else hashlib.sha256(body).hexdigest()
        )
        req_headers.update(
            self.request_signer.sign_with_headers(
                method, str(url), headers=req_headers, content_hash=content_hash))

        start = self.loop.time()
        try:
            async with self.session.request(
                method,
                url,
                data=body,
                headers=req_headers,
                timeout=timeout,
                fingerprint=self.ssl_assert_fingerprint,
            ) as response:
                if is_head:
                    # We actually called 'GET' so throw away the data.
                    await response.release()
                    raw_data = ""
                else:
                    raw_data = await response.text()
                duration = self.loop.time() - start

        # We want to reraise a cancellation.
        except asyncio.CancelledError:
            raise

        except Exception as e:
            self.log_request_fail(
                method,
                str(url),
                url_path,
                orig_body,
                self.loop.time() - start,
                exception=e,
            )
            if isinstance(e, aiohttp_exceptions.ServerFingerprintMismatch):
                raise SSLError("N/A", str(e), e)
            if isinstance(
                    e, (asyncio.TimeoutError, aiohttp_exceptions.ServerTimeoutError)):
                raise ConnectionTimeout("TIMEOUT", str(e), e)
            raise ConnectionError("N/A", str(e), e)

        # raise warnings if any from the 'Warnings' header.
        warning_headers = response.headers.getall("warning", ())
        self._raise_warnings(warning_headers)

        # raise errors based on http status codes, let the client handle those if needed
        if not (200 <= response.status < 300) and response.status not in ignore:
            self.log_request_fail(
                method,
                str(url),
                url_path,
                orig_body,
                duration,
                status_code=response.status,
                response=raw_data,
            )
            self._raise_error(response.status, raw_data)

        self.log_request_success(
            method, str(url), url_path, orig_body,
            response.status, raw_data, duration)

        return response.status, response.headers, raw_data

    async def close(self):
        """
        Explicitly closes connection
        """
        if self.session:
            await self.session.close()

    async def _create_aiohttp_session(self):
        """Creates an aiohttp.ClientSession(). This is delayed until the first
        call to perform_request() so that AsyncTransport has a chance to set
        AIOHttpConnection.loop
        """
        if self.loop is None:
            self.loop = get_running_loop()
        self.session = aiohttp.ClientSession(
            headers=self.headers,
            auto_decompress=True,
            loop=self.loop,
            cookie_jar=aiohttp.DummyCookieJar(),
            response_class=ESClientResponse,
            connector=aiohttp.TCPConnector(
                limit=self._limit, use_dns_cache=True, ssl=self._ssl_context),
        )
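# A minimal wiring sketch, assuming the `elasticsearch[async]` package (7.x):
# `connection_class` selects the connection implementation, and the extra
# signing kwargs (access_key, secret_key, service, region) reach the
# connection above through **kwargs. Host and credentials are placeholders.
from elasticsearch import AsyncElasticsearch

es = AsyncElasticsearch(
    hosts=[{"host": "search-demo.us-east-1.es.amazonaws.com", "port": 443}],
    use_ssl=True,
    connection_class=AIOHttpConnection,
    access_key="EXAMPLE_KEY_ID",
    secret_key="EXAMPLE_SECRET",
    service="es",
    region="us-east-1",
)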
class S3Client:
    def __init__(
        self,
        session: ClientSession,
        url: URL,
        secret_access_key: str = None,
        access_key_id: str = None,
        region: str = "",
        executor: ThreadPoolExecutor = None,
    ):
        access_key_id = access_key_id or url.user
        secret_access_key = secret_access_key or url.password

        if not access_key_id:
            raise ValueError(
                "access_key id must be passed as argument "
                "or as username in the url")
        if not secret_access_key:
            raise ValueError(
                "secret_access_key id must be passed as argument "
                "or as username in the url")

        self._url = url.with_user(None).with_password(None)
        self._session = session
        self._executor = executor
        self._signer = AwsRequestSigner(
            region=region,
            service="s3",
            access_key_id=access_key_id,
            secret_access_key=secret_access_key,
        )

    @property
    def url(self):
        return self._url

    def request(self, method: str, path: str,
                headers: LooseHeaders = None,
                params: ParamsType = None,
                data: t.Optional[DataType] = None,
                data_length: t.Optional[int] = None,
                content_sha256: str = None,
                **kwargs) -> RequestContextManager:
        if isinstance(data, bytes):
            data_length = len(data)
        elif isinstance(data, str):
            data = data.encode()
            data_length = len(data)

        headers = self._prepare_headers(headers)
        if data_length:
            headers[hdrs.CONTENT_LENGTH] = str(data_length)
        elif data is not None:
            headers[hdrs.CONTENT_ENCODING] = "chunked"

        if data is not None and content_sha256 is None:
            content_sha256 = UNSIGNED_PAYLOAD

        url = (self._url / path.lstrip('/')).with_query(params)
        url = url.with_path(quote(url.path), encoded=True)

        headers = self._make_headers(headers)
        headers.extend(
            self._signer.sign_with_headers(
                method, str(url), headers=headers, content_hash=content_sha256))
        return self._session.request(
            method, url, headers=headers, data=data, **kwargs)

    def get(self, object_name: str, **kwargs) -> RequestContextManager:
        return self.request("GET", object_name, **kwargs)

    def head(self, object_name: str, **kwargs) -> RequestContextManager:
        return self.request("HEAD", object_name, **kwargs)

    def delete(self, object_name: str, **kwargs) -> RequestContextManager:
        return self.request("DELETE", object_name, **kwargs)

    @staticmethod
    def _make_headers(headers: t.Optional[LooseHeaders]) -> CIMultiDict:
        headers = CIMultiDict(headers or {})
        return headers

    def _prepare_headers(
        self,
        headers: t.Optional[LooseHeaders],
        file_path: str = "",
    ) -> CIMultiDict:
        headers = self._make_headers(headers)

        if hdrs.CONTENT_TYPE not in headers:
            content_type = guess_type(file_path)[0]
            if content_type is None:
                content_type = "application/octet-stream"
            headers[hdrs.CONTENT_TYPE] = content_type

        return headers

    def put(self, object_name: str,
            data: t.Union[bytes, str, t.AsyncIterable[bytes]], **kwargs):
        return self.request("PUT", object_name, data=data, **kwargs)

    def put_file(self, object_name: t.Union[str, Path],
                 file_path: t.Union[str, Path], *,
                 headers: LooseHeaders = None,
                 chunk_size: int = CHUNK_SIZE,
                 content_sha256: str = None):
        headers = self._prepare_headers(headers, str(file_path))
        return self.put(
            str(object_name),
            headers=headers,
            data=file_sender(
                file_path,
                executor=self._executor,
                chunk_size=chunk_size,
            ),
            data_length=os.stat(file_path).st_size,
            content_sha256=content_sha256,
        )
def main() -> None:
    # Demo content for our target file.
    content = b"Hello, World!\n"
    content_hash = hashlib.sha256(content).hexdigest()

    # Create a request signer instance.
    request_signer = AwsRequestSigner(
        AWS_REGION, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, "s3")

    #
    # Use AWS request signer to generate authentication headers.
    #

    # The headers we'll provide and want to sign.
    headers = {
        "Content-Type": "text/plain",
        "Content-Length": str(len(content))
    }

    # Add the authentication headers.
    headers.update(
        request_signer.sign_with_headers("PUT", URL, headers, content_hash))

    # Make the request.
    r = requests.put(URL, headers=headers, data=content)
    r.raise_for_status()

    #
    # Use AWS request signer to generate a pre-signed URL.
    #

    # The headers we'll provide and want to sign.
    headers = {
        "Content-Type": "text/plain",
        "Content-Length": str(len(content))
    }

    # Generate the pre-signed URL that includes the authentication
    # parameters. Allow the client to determine the contents by
    # setting the content hash to UNSIGNED-PAYLOAD.
    presigned_url = request_signer.presign_url("PUT", URL, headers, UNSIGNED_PAYLOAD)

    # Perform the request.
    r = requests.put(presigned_url, headers=headers, data=content)
    r.raise_for_status()

    #
    # Use AWS request signer for requests helper to perform requests.
    #

    # Create a requests session and assign the auth handler.
    session = requests.Session()
    session.auth = AuthHandler({
        "http://127.0.0.1:9000": AwsAuth(
            AWS_REGION, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, "s3")
    })

    # Perform the request.
    r = session.put(URL, data=content)
    r.raise_for_status()

    #
    # Use AWS request signer to sign an S3 POST policy request.
    #

    # Create a policy, only restricting bucket and expiration.
    expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=1)
    policy = {
        "expiration": expiration.strftime("%Y-%m-%dT%H:%M:%SZ"),
        "conditions": [{"bucket": "demo"}],
    }

    # Get the required form fields to use the policy.
    fields = request_signer.sign_s3_post_policy(policy)

    # Post the form data to the bucket endpoint.
    # Set key (filename) to hello_world.txt.
    r = requests.post(
        URL.rsplit("/", 1)[0],
        data={
            "key": "hello_world.txt",
            "Content-Type": "text/plain",
            **fields
        },
        files={"file": content},
    )
    r.raise_for_status()
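# The constants used by main() are not shown in the original example. The
# values below are placeholders for a local S3-compatible server, chosen to
# match the "demo" bucket and hello_world.txt key used above.
AWS_REGION = "us-east-1"
AWS_ACCESS_KEY_ID = "EXAMPLE_KEY_ID"
AWS_SECRET_ACCESS_KEY = "EXAMPLE_SECRET"
URL = "http://127.0.0.1:9000/demo/hello_world.txt"

if __name__ == "__main__":
    main()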