def test_get_canonical_host_with_non_default_port(self):
    """
    A port that is not the scheme default is kept in the canonical host,
    rendered as C{host:port}.
    """
    endpoint = AWSServiceEndpoint(uri="http://my.service:99/endpoint")
    canonical = endpoint.get_canonical_host()
    self.assertEqual("my.service:99", canonical)
def test_get_canonical_host_is_lower_case(self):
    """
    Mixed-case host names are folded to lower case in the canonical host.
    """
    mixed_case_uri = "http://MY.SerVice:99/endpoint"
    endpoint = AWSServiceEndpoint(uri=mixed_case_uri)
    self.assertEqual("my.service:99", endpoint.get_canonical_host())
def test_set_canonical_host(self):
    """
    Setting a canonical host without a port lower-cases the host and
    leaves the port unset.
    """
    endpoint = AWSServiceEndpoint()
    endpoint.set_canonical_host("My.Service")
    self.assertEqual("my.service", endpoint.host)
    self.assertIdentical(None, endpoint.port)
def test_set_canonical_host_with_port(self):
    """
    A C{host:port} canonical host is split into the host and an integer
    port.
    """
    endpoint = AWSServiceEndpoint()
    endpoint.set_canonical_host("my.service:99")
    self.assertEqual("my.service", endpoint.host)
    self.assertEqual(99, endpoint.port)
def test_set_canonical_host_with_empty_port(self):
    """
    A trailing colon with nothing after it is treated as "no port".
    """
    endpoint = AWSServiceEndpoint()
    endpoint.set_canonical_host("my.service:")
    self.assertEqual("my.service", endpoint.host)
    self.assertIdentical(None, endpoint.port)
def test_get_canonical_host(self):
    """
    Without an explicit port, the canonical host is just the host name.
    """
    endpoint = AWSServiceEndpoint(uri="http://my.service/endpoint")
    canonical = endpoint.get_canonical_host()
    self.assertEqual("my.service", canonical)
class AWSServiceEndpointTestCase(TXAWSTestCase):
    """
    Tests for creating, parsing and modifying an L{AWSServiceEndpoint}.
    """

    def setUp(self):
        # Every test starts from an HTTP endpoint with no explicit port.
        self.endpoint = AWSServiceEndpoint(uri="http://my.service/da_endpoint")

    def test_simple_creation(self):
        """
        An endpoint built without arguments uses HTTP, an empty host,
        port 80, the root path and the GET method.
        """
        default = AWSServiceEndpoint()
        self.assertEqual(default.scheme, "http")
        self.assertEqual(default.host, "")
        self.assertEqual(default.port, 80)
        self.assertEqual(default.path, "/")
        self.assertEqual(default.method, "GET")

    def test_custom_method(self):
        """
        The HTTP method can be chosen at construction time.
        """
        put_endpoint = AWSServiceEndpoint(
            uri="http://service/endpoint", method="PUT")
        self.assertEqual(put_endpoint.method, "PUT")

    def test_parse_uri(self):
        """
        The constructor URI is split into scheme, host, port and path.
        """
        parsed = self.endpoint
        self.assertEqual(parsed.scheme, "http")
        self.assertEqual(parsed.host, "my.service")
        self.assertEqual(parsed.port, 80)
        self.assertEqual(parsed.path, "/da_endpoint")

    def test_parse_uri_https_and_custom_port(self):
        """
        HTTPS URIs with an explicit port are parsed correctly.
        """
        secure = AWSServiceEndpoint(uri="https://my.service:8080/endpoint")
        self.assertEqual(secure.scheme, "https")
        self.assertEqual(secure.host, "my.service")
        self.assertEqual(secure.port, 8080)
        self.assertEqual(secure.path, "/endpoint")

    def test_get_uri(self):
        """
        get_uri reassembles the original URI.
        """
        rebuilt = self.endpoint.get_uri()
        self.assertEqual(rebuilt, "http://my.service/da_endpoint")

    def test_get_uri_custom_port(self):
        """
        get_uri keeps a non-default port.
        """
        original = "https://my.service:8080/endpoint"
        rebuilt = AWSServiceEndpoint(uri=original).get_uri()
        self.assertEqual(rebuilt, original)

    def test_set_host(self):
        """
        set_host replaces the host.
        """
        self.assertEqual(self.endpoint.host, "my.service")
        self.endpoint.set_host("newhost.com")
        self.assertEqual(self.endpoint.host, "newhost.com")

    def test_get_host(self):
        """
        get_host returns the host attribute.
        """
        self.assertEqual(self.endpoint.host, self.endpoint.get_host())

    def test_set_path(self):
        """
        set_path changes the path used by get_uri.
        """
        self.endpoint.set_path("/newpath")
        self.assertEqual(
            self.endpoint.get_uri(), "http://my.service/newpath")

    def test_set_method(self):
        """
        set_method replaces the default GET method.
        """
        self.assertEqual(self.endpoint.method, "GET")
        self.endpoint.set_method("PUT")
        self.assertEqual(self.endpoint.method, "PUT")
def __init__(self, bucket=None, object_name=None, data="",
             content_type=None, metadata=None, *args, **kwargs):
    """
    Create a query for submission to the S3 service.

    @param bucket: the bucket name, or C{None}.
    @param object_name: the object name, or C{None}.
    @param data: the request body.
    @param content_type: the C{Content-Type}, or C{None}.
    @param metadata: optional C{dict} of metadata; when omitted a new
        empty C{dict} is created for this query.
    Remaining arguments are forwarded to the base class constructor.
    """
    super(Query, self).__init__(*args, **kwargs)
    self.bucket = bucket
    self.object_name = object_name
    self.data = data
    self.content_type = content_type
    # A literal ``{}`` default is evaluated once and would be shared (and
    # mutated) by every Query created without metadata.
    self.metadata = {} if metadata is None else metadata
    self.date = datetimeToString()
    # Fall back to the S3 endpoint when no endpoint, or one without a
    # host, was supplied.
    if not self.endpoint or not self.endpoint.host:
        self.endpoint = AWSServiceEndpoint(S3_ENDPOINT)
    # The query action is used as the endpoint's HTTP method.
    # NOTE(review): this mutates a caller-supplied endpoint in place.
    self.endpoint.set_method(self.action)
def get_queue(self, owner_id, queue):
    """
    Build a L{Queue} for a queue whose owner and name are already known.

    Because the queue path is derived directly from C{owner_id} and
    C{queue}, no request is needed to look up the queue URL; use this to
    obtain a queue and operate on it.

    @param owner_id: required, C{str}.
    @param queue: required, C{str}.
    @return: a L{Queue} sharing this client's credentials.
    """
    queue_endpoint = AWSServiceEndpoint(uri=self.endpoint.get_uri())
    queue_endpoint.set_path('/{}/{}/'.format(owner_id, queue))
    factory = QuerysSignatureV4(
        self.creds, queue_endpoint, self.query_factory.agent)
    return Queue(self.creds, queue_endpoint, factory)
class AWSServiceEndpointTestCase(TXAWSTestCase):
    """
    Tests for L{AWSServiceEndpoint} construction, URI parsing and mutators.
    """

    def setUp(self):
        # Shared fixture: plain HTTP, no explicit port, custom path.
        uri = "http://my.service/da_endpoint"
        self.endpoint = AWSServiceEndpoint(uri=uri)

    def test_simple_creation(self):
        fresh = AWSServiceEndpoint()
        self.assertEqual(fresh.scheme, "http")
        self.assertEqual(fresh.host, "")
        self.assertEqual(fresh.port, 80)
        self.assertEqual(fresh.path, "/")
        self.assertEqual(fresh.method, "GET")

    def test_custom_method(self):
        custom = AWSServiceEndpoint(uri="http://service/endpoint", method="PUT")
        self.assertEqual(custom.method, "PUT")

    def test_parse_uri(self):
        ep = self.endpoint
        self.assertEqual(ep.scheme, "http")
        self.assertEqual(ep.host, "my.service")
        self.assertEqual(ep.port, 80)
        self.assertEqual(ep.path, "/da_endpoint")

    def test_parse_uri_https_and_custom_port(self):
        ep = AWSServiceEndpoint(uri="https://my.service:8080/endpoint")
        self.assertEqual(ep.scheme, "https")
        self.assertEqual(ep.host, "my.service")
        self.assertEqual(ep.port, 8080)
        self.assertEqual(ep.path, "/endpoint")

    def test_get_uri(self):
        self.assertEqual(self.endpoint.get_uri(), "http://my.service/da_endpoint")

    def test_get_uri_custom_port(self):
        source_uri = "https://my.service:8080/endpoint"
        ep = AWSServiceEndpoint(uri=source_uri)
        self.assertEqual(ep.get_uri(), source_uri)

    def test_set_host(self):
        self.assertEqual(self.endpoint.host, "my.service")
        self.endpoint.set_host("newhost.com")
        self.assertEqual(self.endpoint.host, "newhost.com")

    def test_get_host(self):
        self.assertEqual(self.endpoint.host, self.endpoint.get_host())

    def test_set_path(self):
        self.endpoint.set_path("/newpath")
        self.assertEqual(self.endpoint.get_uri(), "http://my.service/newpath")

    def test_set_method(self):
        self.assertEqual(self.endpoint.method, "GET")
        self.endpoint.set_method("PUT")
        self.assertEqual(self.endpoint.method, "PUT")
def test_ssl_verification_negative(self):
    """
    The L{VerifyingContextFactory} fails with an SSL error if the
    certificates can't be checked.
    """
    context_factory = WebDefaultOpenSSLContextFactory(BADPRIVKEY, BADPUBKEY)
    self.port = reactor.listenSSL(
        0, self.site, context_factory, interface="127.0.0.1")
    self.portno = self.port.getHost().port

    endpoint = AWSServiceEndpoint(ssl_hostname_verification=True)
    query = BaseQuery("an action", "creds", endpoint)
    d = query.get_page(self._get_url("file"))

    def fail(ignore):
        self.fail('Expected SSLError')

    def check_exception(why):
        # XXX kind of a mess here ... need to unwrap the
        # exception and check.
        # NOTE(review): assumes the failure value wraps a sequence whose
        # first element is itself a sequence of Failures - confirm.
        root_exc = why.value[0][0].value
        self.assertIsInstance(root_exc, SSLError)

    return d.addCallbacks(fail, check_exception)
def test_ssl_hostname_verification(self):
    """
    When the endpoint given to L{BaseQuery} has
    C{ssl_hostname_verification} set to C{True}, the agent created by
    C{get_page} receives a L{VerifyingContextFactory}.
    """
    created = []

    @implementer(IAgent)
    class FakeAgent(object):
        # Records its constructor arguments instead of connecting.
        def __init__(self, reactor, contextFactory, connectTimeout=None,
                     bindAddress=None, pool=None):
            created.append(
                (reactor, contextFactory, connectTimeout, bindAddress, pool))

        def request(self, method, uri, headers=None, bodyProducer=None):
            # Never fires: only the agent's construction matters here.
            return Deferred()

    verifyClass(IAgent, FakeAgent)

    certs = [makeCertificate(O="Test Certificate", CN="something")[1]]
    self.patch(base, "Agent", FakeAgent)
    self.patch(ssl, "_ca_certs", certs)

    endpoint = AWSServiceEndpoint(ssl_hostname_verification=True)
    query = BaseQuery("an action", "creds", endpoint, reactor="ignored")
    query.get_page("https://example.com/file")

    self.assertEqual(len(created), 1)
    [(_, context_factory, _, _, _)] = created
    self.assertIsInstance(context_factory, ssl.VerifyingContextFactory)
def test_handle_unicode_api_error(self):
    """
    An L{APIError} carrying a unicode message is reported to the client
    with its code and status, and nothing is logged.
    """
    creds = AWSCredentials("access", "secret")
    endpoint = AWSServiceEndpoint("http://uri")
    query = Query(action="SomeAction", creds=creds, endpoint=endpoint)
    query.sign()
    request = FakeRequest(query.params, endpoint)

    def fail_execute(call):
        raise APIError(400, code="LangError",
                       message=u"\N{HIRAGANA LETTER A}dvanced")

    def check(ignored):
        logged = self.flushLoggedErrors()
        self.assertEqual(0, len(logged))
        self.assertTrue(request.finished)
        self.assertTrue(request.response.startswith("LangError"))
        self.assertEqual(400, request.code)

    self.api.execute = fail_execute
    self.api.principal = TestPrincipal(creds)
    return self.api.handle(request).addCallback(check)
def test_handle_with_expired_signature(self):
    """
    A request whose C{Expires} time is earlier than the current time is
    rejected with a 400 C{RequestExpired} error.
    """
    creds = AWSCredentials("access", "secret")
    endpoint = AWSServiceEndpoint("http://uri")
    expired = {"Expires": "2010-01-01T12:00:00Z"}
    query = Query(action="SomeAction", creds=creds, endpoint=endpoint,
                  other_params=expired)
    query.sign()
    request = FakeRequest(query.params, endpoint)

    def check(ignored):
        self.assertEqual(0, len(self.flushLoggedErrors()))
        self.assertEqual(
            "RequestExpired - Request has expired. Expires date is"
            " 2010-01-01T12:00:00Z", request.response)
        self.assertEqual(400, request.code)

    # Pin "now" to one second after the Expires timestamp.
    now = datetime(2010, 1, 1, 12, 0, 1, tzinfo=tzutc())
    self.api.get_utc_time = lambda: now
    return self.api.handle(request).addCallback(check)
def test_handle_unicode_error(self):
    """
    A non-API exception with a unicode message is logged and reported to
    the client as a generic 500 C{Server error}.
    """
    creds = AWSCredentials("access", "secret")
    endpoint = AWSServiceEndpoint("http://uri")
    query = Query(action="SomeAction", creds=creds, endpoint=endpoint)
    query.sign()
    request = FakeRequest(query.params, endpoint)

    def fail_execute(call):
        raise ValueError(u"\N{HIRAGANA LETTER A}dvanced")

    def check(ignored):
        [logged] = self.flushLoggedErrors()
        self.assertIsInstance(logged.value, ValueError)
        self.assertTrue(request.finished)
        self.assertEqual("Server error", request.response)
        self.assertEqual(500, request.code)

    self.api.execute = fail_execute
    self.api.principal = TestPrincipal(creds)
    return self.api.handle(request).addCallback(check)
def __init__(self, access_key="", secret_key="", uri="",
             ec2_client_factory=None, keypairs=None, security_groups=None,
             instances=None, volumes=None, snapshots=None,
             availability_zones=None):
    """
    Store the fake region's credentials, URI and canned EC2 data, and
    create its S3 and Route53 stand-ins.

    @param ec2_client_factory: factory for EC2 clients; any falsy value
        selects L{FakeEC2Client}.
    """
    self.access_key = access_key
    self.secret_key = secret_key
    self.uri = uri
    self.ec2_client = None
    self.ec2_client_factory = ec2_client_factory or FakeEC2Client

    # Canned EC2 state passed through unchanged.
    self.keypairs = keypairs
    self.security_groups = security_groups
    self.instances = instances
    self.volumes = volumes
    self.snapshots = snapshots
    self.availability_zones = availability_zones

    self.s3 = MemoryS3()
    self._creds = AWSCredentials(access_key=self.access_key,
                                 secret_key=self.secret_key)
    self._endpoint = AWSServiceEndpoint(uri=self.uri)
    self._route53_controller = MemoryRoute53()
def test_handle_500_api_error(self):
    """
    An L{APIError} with a status code of 500 or more is logged on the
    server side, and its response is still returned to the client.
    """
    creds = AWSCredentials("access", "secret")
    endpoint = AWSServiceEndpoint("http://uri")
    query = Query(action="SomeAction", creds=creds, endpoint=endpoint)
    query.sign()
    request = FakeRequest(query.params, endpoint)

    def fail_execute(call):
        raise APIError(500, response="oops")

    def check(ignored):
        self.assertEqual(1, len(self.flushLoggedErrors()))
        self.assertTrue(request.finished)
        self.assertEqual("oops", request.response)
        self.assertEqual(500, request.code)

    self.api.execute = fail_execute
    self.api.principal = TestPrincipal(creds)
    return self.api.handle(request).addCallback(check)
def test_simple_creation(self):
    """
    A default endpoint has scheme C{http}, an empty host, no port, path
    C{/} and method C{GET}.
    """
    default = AWSServiceEndpoint()
    self.assertEqual(default.scheme, "http")
    self.assertEqual(default.host, "")
    self.assertEqual(default.port, None)
    self.assertEqual(default.path, "/")
    self.assertEqual(default.method, "GET")
def test_handle_pass_params_to_call(self):
    """
    L{QueryAPI.handle} builds the L{Call} with the request's action,
    version, id, principal and remaining raw parameters.
    """
    self.registry.add(TestMethod, "SomeAction", "1.2.3")
    creds = AWSCredentials("access", "secret")
    endpoint = AWSServiceEndpoint("http://uri")
    extra_params = {"Foo": "bar", "Version": "1.2.3"}
    query = Query(action="SomeAction", creds=creds, endpoint=endpoint,
                  other_params=extra_params)
    query.sign()
    request = FakeRequest(query.params, endpoint)

    def execute(call):
        # Version is reported through call.version, not the raw params.
        self.assertEqual({"Foo": "bar"}, call.get_raw_params())
        self.assertIdentical(self.api.principal, call.principal)
        self.assertEqual("SomeAction", call.action)
        self.assertEqual("1.2.3", call.version)
        self.assertEqual(request.id, call.id)
        return "ok"

    def check(ignored):
        self.assertEqual("ok", request.response)
        self.assertEqual(200, request.code)

    self.api.execute = execute
    self.api.principal = TestPrincipal(creds)
    return self.api.handle(request).addCallback(check)
def test_handle_custom_get_call_arguments(self):
    """
    L{QueryAPI.handle} obtains call arguments via
    L{QueryAPI.get_call_arguments}, so a subclass can accept an
    alternative wire format.
    """
    creds = AWSCredentials("access", "secret")
    endpoint = AWSServiceEndpoint("http://uri")
    api = AlternativeWireFormatQueryAPI(self.registry)
    params = {"foo": "bar", "access_key": creds.access_key}
    # Sign a copy, so adding the signature below does not change what was
    # signed.
    signer = Signature(creds, endpoint, params.copy(),
                       signature_method="Hmacsha256", signature_version=2)
    params["signature"] = signer.compute()
    request = FakeRequest(params, endpoint)

    def check(ignored):
        self.assertTrue(request.finished)
        self.assertEqual("data", request.response)
        self.assertEqual("4", request.headers["Content-Length"])
        self.assertEqual("text/plain", request.headers["Content-Type"])
        self.assertEqual(200, request.code)

    api.principal = TestPrincipal(creds)
    return api.handle(request).addCallback(check)
def test_get_uri_with_endpoint_bucket_and_object(self):
    """
    A URL context with a bucket and an object name appends both to the
    endpoint as path segments.
    """
    context = client.URLContext(
        AWSServiceEndpoint("http://localhost/"),
        bucket="mydocs", object_name="notes.txt")
    expected = "http://localhost/mydocs/notes.txt"
    self.assertEqual(context.get_url(), expected)
def test_handle_with_timestamp_and_expires(self):
    """
    A request that supplies both C{Timestamp} and C{Expires} is rejected
    with a 400 C{InvalidParameterCombination} error.
    """
    creds = AWSCredentials("access", "secret")
    endpoint = AWSServiceEndpoint("http://uri")
    both = {
        "Timestamp": "2010-01-01T12:00:00Z",
        "Expires": "2010-01-01T12:00:00Z",
    }
    query = Query(action="SomeAction", creds=creds, endpoint=endpoint,
                  other_params=both)
    query.sign()
    request = FakeRequest(query.params, endpoint)

    def check(ignored):
        self.assertEqual(0, len(self.flushLoggedErrors()))
        self.assertEqual(
            "InvalidParameterCombination - The parameter Timestamp"
            " cannot be used with the parameter Expires",
            request.response)
        self.assertEqual(400, request.code)

    return self.api.handle(request).addCallback(check)
def __init__(self, creds=None, endpoint=None, query_factory=None):
    """
    @param creds: credentials; a default L{AWSCredentials} when C{None}.
    @param endpoint: service endpoint; a default L{AWSServiceEndpoint}
        when C{None}.
    @param query_factory: stored as given.
    """
    self.creds = AWSCredentials() if creds is None else creds
    self.endpoint = AWSServiceEndpoint() if endpoint is None else endpoint
    self.query_factory = query_factory
class BucketURLContextTestCase(TXAWSTestCase):
    """
    Tests for L{client.BucketURLContext}.
    """

    # Built once, when the class is defined, and shared by every test.
    endpoint = AWSServiceEndpoint("https://s3.amazonaws.com/")

    def test_get_host_with_bucket(self):
        """
        With a bucket, the host stays the endpoint host and the bucket name
        becomes the path.
        """
        context = client.BucketURLContext(self.endpoint, "mystuff")
        self.assertEqual(context.get_host(), "s3.amazonaws.com")
        self.assertEqual(context.get_path(), "/mystuff")
def _validate_signature(self, request, principal, args, params):
    """
    Recompute the request signature from the principal's keys and compare
    it with the one the client sent.

    @raise APIError: 403 C{SignatureDoesNotMatch} when they differ.
    """
    creds = AWSCredentials(principal.access_key, principal.secret_key)

    # Rebuild the endpoint exactly as the client addressed it.
    endpoint = AWSServiceEndpoint()
    endpoint.set_method(request.method)
    endpoint.set_canonical_host(request.getHeader("Host"))
    request_path = request.path
    if self.path is not None:
        # Prefix the configured path, joined by exactly one slash.
        request_path = "%s/%s" % (self.path.rstrip("/"),
                                  request_path.lstrip("/"))
    endpoint.set_path(request_path)

    expected = Signature(
        creds,
        endpoint,
        params,
        signature_method=args["signature_method"],
        signature_version=args["signature_version"],
    ).compute()
    # NOTE(review): ``!=`` is not a constant-time comparison; consider
    # hmac.compare_digest once the str/unicode types of both sides are
    # confirmed.
    if expected != args["signature"]:
        raise APIError(
            403,
            "SignatureDoesNotMatch",
            "The request signature we calculated does not "
            "match the signature you provided. Check your "
            "key and signing method.",
        )
def get_s3_client(self, creds=None):
    """
    Return a client from C{self.s3.client}, remembering the client and
    its state on this object.

    @param creds: credentials to use; built from C{self.access_key} and
        C{self.secret_key} when C{None}.
    """
    if creds is None:
        creds = AWSCredentials(access_key=self.access_key,
                               secret_key=self.secret_key)
    endpoint = AWSServiceEndpoint(uri=self.uri)
    s3_client, s3_state = self.s3.client(creds, endpoint)
    self.s3_client = s3_client
    self.s3_state = s3_state
    return self.s3_client
def test_errors(self):
    """
    Fetching a URL the test server does not serve (C{not_there}) fails
    with L{TwistedWebError}.
    """
    endpoint = AWSServiceEndpoint("http://endpoint")
    query = BaseQuery("an action", "creds", endpoint)
    deferred = query.get_page(self._get_url("not_there"))
    self.assertFailure(deferred, TwistedWebError)
    return deferred
def get_route53_client(self, creds=None):
    """
    Return a client from the in-memory Route53 controller; the
    accompanying state is discarded.

    @param creds: credentials to use; built from C{self.access_key} and
        C{self.secret_key} when C{None}.
    """
    if creds is None:
        creds = AWSCredentials(access_key=self.access_key,
                               secret_key=self.secret_key)
    endpoint = AWSServiceEndpoint(uri=self.uri)
    route53_client, _ = self._route53_controller.client(creds, endpoint)
    return route53_client
def test_nondefault_endpoint(self):
    """
    A L{LicenseServiceClient} given an explicit endpoint for
    C{PRODUCTION_LICENSE_SERVICE_ENDPOINT} ends up with an endpoint whose
    attributes equal those of C{self.lsc}'s endpoint (presumably built
    with the default endpoint in setUp - confirm).
    """
    lsc = LicenseServiceClient(
        creds=AWSCredentials(access_key=FAKE_AWS_ACCESS_KEY_ID,
                             secret_key=FAKE_HMAC_KEY),
        endpoint=AWSServiceEndpoint(
            uri=PRODUCTION_LICENSE_SERVICE_ENDPOINT),
    )
    # assertEqual instead of the deprecated failUnlessEqual alias.
    self.assertEqual(vars(lsc._endpoint), vars(self.lsc._endpoint))
def test_get_page(self):
    """
    L{BaseQuery.get_page} fires with the body of the requested resource.
    """
    endpoint = AWSServiceEndpoint("http://endpoint")
    query = BaseQuery("an action", "creds", endpoint)
    result = query.get_page(self._get_url("file"))
    result.addCallback(self.assertEquals, "0123456789")
    return result
def __init__(self, creds=None, endpoint=None):
    """
    @param creds: C{None} or an L{AWSCredentials} instance.
    @param endpoint: an L{AWSServiceEndpoint}; defaults to one for
        C{PRODUCTION_LICENSE_SERVICE_ENDPOINT}.
    @raise AssertionError: if either argument has the wrong type.
    """
    if endpoint is None:
        endpoint = AWSServiceEndpoint(PRODUCTION_LICENSE_SERVICE_ENDPOINT)
    # Validate explicitly rather than with ``assert`` so the checks still
    # run under ``python -O``. AssertionError is kept so callers see the
    # same exception type as before; repr() replaces the Python-2-only
    # backtick syntax with identical output.
    if not (creds is None or isinstance(creds, AWSCredentials)):
        raise AssertionError(repr(creds))
    if not isinstance(endpoint, AWSServiceEndpoint):
        raise AssertionError(repr(endpoint))
    self._creds = creds
    self._endpoint = endpoint
def get_EC2_properties(ec2accesskeyid, ec2secretkey, endpoint_uri, parser,
                       *instance_ids):
    """
    Describe the given EC2 instances via an L{EC2Client} bound to the
    supplied credentials, endpoint URI and parser.

    Reference: http://docs.amazonwebservices.com/AWSEC2/latest/APIReference/index.html?ApiReference-query-DescribeInstances.html
    """
    ec2_client = EC2Client(
        creds=AWSCredentials(ec2accesskeyid, ec2secretkey),
        endpoint=AWSServiceEndpoint(uri=endpoint_uri),
        parser=parser,
    )
    return ec2_client.describe_instances(*instance_ids)
def __init__(self, bucket=None, object_name=None, data="",
             content_type=None, metadata=None, amz_headers=None,
             body_producer=None, *args, **kwargs):
    """
    Create a query for submission to the S3 service.

    @param metadata: optional C{dict}; a new empty C{dict} when omitted.
    @param amz_headers: optional C{dict}; a new empty C{dict} when omitted.
    @param body_producer: alternative request body source; may not be
        combined with non-empty C{data}.
    @raise ValueError: if both C{data} and C{body_producer} are given.
    """
    super(Query, self).__init__(*args, **kwargs)
    # data might be None or "", alas.
    if data and body_producer is not None:
        raise ValueError("data and body_producer are mutually exclusive.")
    self.bucket = bucket
    self.object_name = object_name
    self.data = data
    self.body_producer = body_producer
    self.content_type = content_type
    # Fresh dicts per instance: literal ``{}`` defaults are evaluated once
    # and would be shared by every Query created without these arguments.
    self.metadata = {} if metadata is None else metadata
    self.amz_headers = {} if amz_headers is None else amz_headers
    self._date = datetimeToString()
    # Fall back to the S3 endpoint when none (or a host-less one) is given.
    if not self.endpoint or not self.endpoint.host:
        self.endpoint = AWSServiceEndpoint(S3_ENDPOINT)
    # The query action is used as the endpoint's HTTP method.
    self.endpoint.set_method(self.action)
class Query(BaseQuery):
    """A query for submission to the S3 service."""

    def __init__(self, bucket=None, object_name=None, data="",
                 content_type=None, metadata=None, amz_headers=None,
                 *args, **kwargs):
        """
        @param metadata: optional C{dict}; a new empty C{dict} when omitted.
        @param amz_headers: optional C{dict}; a new empty C{dict} when
            omitted. Tuple values are sent as comma-separated lists.
        Remaining arguments are forwarded to L{BaseQuery}.
        """
        super(Query, self).__init__(*args, **kwargs)
        self.bucket = bucket
        self.object_name = object_name
        self.data = data
        self.content_type = content_type
        # Fresh dicts per instance: literal ``{}`` defaults are evaluated
        # once and would be shared by every Query created without them.
        self.metadata = {} if metadata is None else metadata
        self.amz_headers = {} if amz_headers is None else amz_headers
        self.date = datetimeToString()
        if not self.endpoint or not self.endpoint.host:
            self.endpoint = AWSServiceEndpoint(S3_ENDPOINT)
        # The action doubles as the HTTP method (see submit()).
        self.endpoint.set_method(self.action)

    def set_content_type(self):
        """
        Set the content type based on the file extension used in the object
        name.
        """
        if self.object_name and not self.content_type:
            # XXX nothing is currently done with the encoding... we may
            # need to in the future
            self.content_type, encoding = mimetypes.guess_type(
                self.object_name, strict=False)

    def get_headers(self):
        """
        Build the headers needed to perform an S3 operation, adding an
        C{Authorization} header when credentials are available.

        @return: a C{dict} mapping header names to values.
        """
        headers = {"Content-Length": len(self.data),
                   "Content-MD5": calculate_md5(self.data),
                   "Date": self.date}
        for key, value in self.metadata.iteritems():
            headers["x-amz-meta-" + key] = value
        for key, values in self.amz_headers.iteritems():
            if isinstance(values, tuple):
                headers["x-amz-" + key] = ",".join(values)
            else:
                headers["x-amz-" + key] = values
        # Before we check if the content type is set, let's see if we can set
        # it by guessing the mimetype.
        self.set_content_type()
        if self.content_type is not None:
            headers["Content-Type"] = self.content_type
        if self.creds is not None:
            signature = self.sign(headers)
            headers["Authorization"] = "AWS %s:%s" % (
                self.creds.access_key, signature)
        return headers

    def get_canonicalized_amz_headers(self, headers):
        """
        Get the headers defined by Amazon S3.

        Every C{x-amz-*} header is lower-cased, sorted, and emitted as one
        C{name:value} line.
        """
        headers = [
            (name.lower(), value) for name, value in headers.iteritems()
            if name.lower().startswith("x-amz-")]
        headers.sort()
        # XXX missing spec implementation:
        # txAWS doesn't currently unfold long headers

        def represent(n, vs):
            if isinstance(vs, tuple):
                # Multiple values for one header are combined into a single
                # comma-separated line, as the SigV2 spec requires and as
                # get_headers() already does. The previous code formatted
                # the whole tuple once per element.
                return "%s:%s\n" % (n, ",".join(vs))
            else:
                return "%s:%s\n" % (n, vs)

        return "".join([represent(name, value) for name, value in headers])

    def get_canonicalized_resource(self):
        """
        Get an S3 resource path.
        """
        # As <http://docs.amazonwebservices.com/AmazonS3/latest/dev/RESTAuthentication.html>
        # says, if there is a subresource (e.g. ?acl), it is included, but other query
        # parameters (e.g. ?prefix=... in a GET Bucket request) are not included.
        # Yes, that makes no sense in terms of either security or consistency.
        resource = self.object_name
        if resource:
            q = resource.find('?')
            if q >= 0:
                # There can be both a subresource and other parameters, for example
                # '?versions&prefix=foo'. "You are in a maze of twisty edge cases..."
                firstparam = resource[q:].partition('&')[0]  # includes the initial '?'
                resource = resource[:q]  # strip the query
                if '=' not in firstparam:
                    resource += firstparam  # add back '?subresource' if present
        path = "/"
        if self.bucket is not None:
            path += self.bucket
        if self.bucket is not None and resource:
            if not resource.startswith("/"):
                path += "/"
            path += resource
        elif self.bucket is not None and not path.endswith("/"):
            path += "/"
        return path

    def sign(self, headers):
        """Sign this query using its built in credentials."""
        text = (self.action + "\n" +
                headers.get("Content-MD5", "") + "\n" +
                headers.get("Content-Type", "") + "\n" +
                headers.get("Date", "") + "\n" +
                self.get_canonicalized_amz_headers(headers) +
                self.get_canonicalized_resource())
        return self.creds.sign(text, hash_type="sha1")

    def submit(self, url_context=None):
        """Submit this query.

        @return: A deferred from get_page
        """
        if not url_context:
            url_context = URLContext(
                self.endpoint, self.bucket, self.object_name)
        d = self.get_page(
            url_context.get_url(), method=self.action, postdata=self.data,
            headers=self.get_headers())
        return d.addErrback(s3_error_wrapper)
class Query(BaseQuery):
    """A query for submission to the S3 service."""

    def __init__(self, bucket=None, object_name=None, data="",
                 content_type=None, metadata=None, amz_headers=None,
                 *args, **kwargs):
        """
        @param metadata: optional C{dict}; a new empty C{dict} when omitted.
        @param amz_headers: optional C{dict}; a new empty C{dict} when
            omitted. Tuple values are sent as comma-separated lists.
        Remaining arguments are forwarded to L{BaseQuery}.
        """
        super(Query, self).__init__(*args, **kwargs)
        self.bucket = bucket
        self.object_name = object_name
        self.data = data
        self.content_type = content_type
        # Fresh dicts per instance: literal ``{}`` defaults are evaluated
        # once and would be shared by every Query created without them.
        self.metadata = {} if metadata is None else metadata
        self.amz_headers = {} if amz_headers is None else amz_headers
        self.date = datetimeToString()
        if not self.endpoint or not self.endpoint.host:
            self.endpoint = AWSServiceEndpoint(S3_ENDPOINT)
        # The action doubles as the HTTP method (see submit()).
        self.endpoint.set_method(self.action)

    def set_content_type(self):
        """
        Set the content type based on the file extension used in the object
        name.
        """
        if self.object_name and not self.content_type:
            # XXX nothing is currently done with the encoding... we may
            # need to in the future
            self.content_type, encoding = mimetypes.guess_type(
                self.object_name, strict=False)

    def get_headers(self):
        """
        Build the headers needed to perform an S3 operation, adding an
        C{Authorization} header when credentials are available.

        @return: a C{dict} mapping header names to values.
        """
        headers = {"Content-Length": len(self.data),
                   "Content-MD5": calculate_md5(self.data),
                   "Date": self.date}
        for key, value in self.metadata.iteritems():
            headers["x-amz-meta-" + key] = value
        for key, values in self.amz_headers.iteritems():
            if isinstance(values, tuple):
                headers["x-amz-" + key] = ",".join(values)
            else:
                headers["x-amz-" + key] = values
        # Before we check if the content type is set, let's see if we can set
        # it by guessing the mimetype.
        self.set_content_type()
        if self.content_type is not None:
            headers["Content-Type"] = self.content_type
        if self.creds is not None:
            signature = self.sign(headers)
            headers["Authorization"] = "AWS %s:%s" % (
                self.creds.access_key, signature)
        return headers

    def get_canonicalized_amz_headers(self, headers):
        """
        Get the headers defined by Amazon S3.

        Every C{x-amz-*} header is lower-cased, sorted, and emitted as one
        C{name:value} line.
        """
        headers = [
            (name.lower(), value) for name, value in headers.iteritems()
            if name.lower().startswith("x-amz-")]
        headers.sort()
        # XXX missing spec implementation:
        # txAWS doesn't currently unfold long headers

        def represent(n, vs):
            if isinstance(vs, tuple):
                # Multiple values for one header are combined into a single
                # comma-separated line, as the SigV2 spec requires and as
                # get_headers() already does. The previous code formatted
                # the whole tuple once per element.
                return "%s:%s\n" % (n, ",".join(vs))
            else:
                return "%s:%s\n" % (n, vs)

        return "".join([represent(name, value) for name, value in headers])

    def get_canonicalized_resource(self):
        """
        Get an S3 resource path: C{/}, then C{bucket/}, then the object
        name, when present.
        """
        path = "/"
        if self.bucket is not None:
            path += self.bucket
        if self.bucket is not None and self.object_name:
            if not self.object_name.startswith("/"):
                path += "/"
            path += self.object_name
        elif self.bucket is not None and not path.endswith("/"):
            path += "/"
        return path

    def sign(self, headers):
        """Sign this query using its built in credentials."""
        text = (self.action + "\n" +
                headers.get("Content-MD5", "") + "\n" +
                headers.get("Content-Type", "") + "\n" +
                headers.get("Date", "") + "\n" +
                self.get_canonicalized_amz_headers(headers) +
                self.get_canonicalized_resource())
        return self.creds.sign(text, hash_type="sha1")

    def submit(self, url_context=None):
        """Submit this query.

        @return: A deferred from get_page
        """
        if not url_context:
            url_context = URLContext(
                self.endpoint, self.bucket, self.object_name)
        d = self.get_page(
            url_context.get_url(), method=self.action, postdata=self.data,
            headers=self.get_headers())
        return d.addErrback(s3_error_wrapper)
def test_get_uri_custom_port(self):
    """
    get_uri round-trips an HTTPS URI that carries a non-default port.
    """
    original = "https://my.service:8080/endpoint"
    rebuilt = AWSServiceEndpoint(uri=original).get_uri()
    self.assertEqual(rebuilt, original)
def setUp(self):
    # Shared fixture: an HTTP endpoint with no explicit port.
    fixture_uri = "http://my.service/da_endpoint"
    self.endpoint = AWSServiceEndpoint(uri=fixture_uri)
class AWSServiceEndpointTestCase(TestCase):
    """
    Tests for L{AWSServiceEndpoint}: construction, URI parsing, canonical
    host handling and mutators.
    """

    def setUp(self):
        # Shared fixture: an HTTP endpoint with no explicit port.
        self.endpoint = AWSServiceEndpoint(uri="http://my.service/da_endpoint")

    def test_warning_when_verification_disabled(self):
        """
        Turning certificate verification off makes L{AWSServiceEndpoint}
        emit a L{UserWarning}.
        """
        self.assertWarns(
            UserWarning,
            "Operating with certificate verification disabled!",
            __file__,
            lambda: AWSServiceEndpoint(ssl_hostname_verification=False),
        )

    def test_simple_creation(self):
        default = AWSServiceEndpoint()
        self.assertEqual(default.scheme, "http")
        self.assertEqual(default.host, "")
        self.assertEqual(default.port, None)
        self.assertEqual(default.path, "/")
        self.assertEqual(default.method, "GET")

    def test_custom_method(self):
        put_endpoint = AWSServiceEndpoint(
            uri="http://service/endpoint", method="PUT")
        self.assertEqual(put_endpoint.method, "PUT")

    def test_parse_uri(self):
        parsed = self.endpoint
        self.assertEqual(parsed.scheme, "http")
        self.assertEqual(parsed.host, "my.service")
        self.assertIdentical(parsed.port, None)
        self.assertEqual(parsed.path, "/da_endpoint")

    def test_parse_uri_https_and_custom_port(self):
        secure = AWSServiceEndpoint(uri="https://my.service:8080/endpoint")
        self.assertEqual(secure.scheme, "https")
        self.assertEqual(secure.host, "my.service")
        self.assertEqual(secure.port, 8080)
        self.assertEqual(secure.path, "/endpoint")

    def test_get_uri(self):
        self.assertEqual(
            self.endpoint.get_uri(), "http://my.service/da_endpoint")

    def test_get_uri_custom_port(self):
        original = "https://my.service:8080/endpoint"
        self.assertEqual(AWSServiceEndpoint(uri=original).get_uri(), original)

    def test_set_host(self):
        self.assertEqual(self.endpoint.host, "my.service")
        self.endpoint.set_host("newhost.com")
        self.assertEqual(self.endpoint.host, "newhost.com")

    def test_get_host(self):
        self.assertEqual(self.endpoint.host, self.endpoint.get_host())

    def test_get_canonical_host(self):
        """
        Without an explicit port, the canonical host is the host name.
        """
        endpoint = AWSServiceEndpoint(uri="http://my.service/endpoint")
        self.assertEqual("my.service", endpoint.get_canonical_host())

    def test_get_canonical_host_with_non_default_port(self):
        """
        A non-default port is kept in the canonical host.
        """
        endpoint = AWSServiceEndpoint(uri="http://my.service:99/endpoint")
        self.assertEqual("my.service:99", endpoint.get_canonical_host())

    def test_get_canonical_host_is_lower_case(self):
        """
        The canonical host is always lower case.
        """
        endpoint = AWSServiceEndpoint(uri="http://MY.SerVice:99/endpoint")
        self.assertEqual("my.service:99", endpoint.get_canonical_host())

    def test_set_canonical_host(self):
        """
        Setting the canonical host lower-cases it.
        """
        endpoint = AWSServiceEndpoint()
        endpoint.set_canonical_host("My.Service")
        self.assertEqual("my.service", endpoint.host)
        self.assertIdentical(None, endpoint.port)

    def test_set_canonical_host_with_port(self):
        """
        A canonical host may carry a port.
        """
        endpoint = AWSServiceEndpoint()
        endpoint.set_canonical_host("my.service:99")
        self.assertEqual("my.service", endpoint.host)
        self.assertEqual(99, endpoint.port)

    def test_set_canonical_host_with_empty_port(self):
        """
        An empty port after the colon means no port.
        """
        endpoint = AWSServiceEndpoint()
        endpoint.set_canonical_host("my.service:")
        self.assertEqual("my.service", endpoint.host)
        self.assertIdentical(None, endpoint.port)

    def test_set_path(self):
        self.endpoint.set_path("/newpath")
        self.assertEqual(
            self.endpoint.get_uri(), "http://my.service/newpath")

    def test_set_method(self):
        self.assertEqual(self.endpoint.method, "GET")
        self.endpoint.set_method("PUT")
        self.assertEqual(self.endpoint.method, "PUT")
class Query(BaseQuery):
    """A query for submission to the S3 service."""

    def __init__(self, bucket=None, object_name=None, data="",
                 content_type=None, metadata=None, amz_headers=None,
                 body_producer=None, *args, **kwargs):
        """
        @param metadata: optional C{dict}; a new empty C{dict} when omitted.
        @param amz_headers: optional C{dict}; a new empty C{dict} when
            omitted.
        @param body_producer: alternative request body source; may not be
            combined with non-empty C{data}.
        @raise ValueError: if both C{data} and C{body_producer} are given.
        """
        super(Query, self).__init__(*args, **kwargs)
        # data might be None or "", alas.
        if data and body_producer is not None:
            raise ValueError("data and body_producer are mutually exclusive.")
        self.bucket = bucket
        self.object_name = object_name
        self.data = data
        self.body_producer = body_producer
        self.content_type = content_type
        # Fresh dicts per instance: literal ``{}`` defaults are evaluated
        # once and would be shared by every Query created without them.
        self.metadata = {} if metadata is None else metadata
        self.amz_headers = {} if amz_headers is None else amz_headers
        self._date = datetimeToString()
        if not self.endpoint or not self.endpoint.host:
            self.endpoint = AWSServiceEndpoint(S3_ENDPOINT)
        # The action doubles as the HTTP method (see submit()).
        self.endpoint.set_method(self.action)

    @property
    def date(self):
        """
        Return the date and emit a deprecation warning.
        """
        warnings.warn("txaws.s3.client.Query.date is a deprecated attribute",
                      DeprecationWarning,
                      stacklevel=2)
        return self._date

    @date.setter
    def date(self, value):
        """
        Set the date.

        @param value: The new date for this L{Query}.
        @type value: L{str}
        """
        self._date = value

    def set_content_type(self):
        """
        Set the content type based on the file extension used in the object
        name.
        """
        if self.object_name and not self.content_type:
            # XXX nothing is currently done with the encoding... we may
            # need to in the future
            self.content_type, encoding = mimetypes.guess_type(
                self.object_name, strict=False)

    def get_headers(self, instant):
        """
        Build the headers needed in order to perform S3 operations.

        @param instant: the time used for C{x-amz-date} and for signing.
        @return: a C{dict} of headers, including C{Authorization} when
            credentials are available.
        """
        headers = {'x-amz-date': _auth_v4.makeAMZDate(instant)}
        if self.body_producer is None:
            data = self.data
            if data is None:
                data = b""
            headers["x-amz-content-sha256"] = hashlib.sha256(data).hexdigest()
        else:
            # A streamed body cannot be hashed up front.
            data = None
            headers["x-amz-content-sha256"] = b"UNSIGNED-PAYLOAD"
        for key, value in self.metadata.iteritems():
            headers["x-amz-meta-" + key] = value
        for key, value in self.amz_headers.iteritems():
            headers["x-amz-" + key] = value
        # Before we check if the content type is set, let's see if we can set
        # it by guessing the mimetype.
        self.set_content_type()
        if self.content_type is not None:
            headers["Content-Type"] = self.content_type
        if self.creds is not None:
            headers["Authorization"] = self.sign(
                headers,
                data,
                s3_url_context(self.endpoint, self.bucket, self.object_name),
                instant,
                method=self.action)
        return headers

    def sign(self, headers, data, url_context, instant, method,
             region=REGION_US_EAST_1):
        """
        Sign this query using its built in credentials.

        Side effect: adds a C{host} entry to C{headers}. When C{data} is
        C{None} the canonical request is built without a payload hash
        (presumably matching the UNSIGNED-PAYLOAD case in get_headers).

        @return: the value for the C{Authorization} header.
        """
        headers["host"] = url_context.get_encoded_host()
        if data is None:
            request = _auth_v4._CanonicalRequest.from_request_components(
                method=method,
                url=url_context.get_encoded_path(),
                headers=headers,
                headers_to_sign=('host', 'x-amz-date'),
                payload_hash=None,
            )
        else:
            request = _auth_v4._CanonicalRequest.from_request_components_and_payload(
                method=method,
                url=url_context.get_encoded_path(),
                headers=headers,
                headers_to_sign=('host', 'x-amz-date'),
                payload=data,
            )
        return _auth_v4._make_authorization_header(
            region=region,
            service="s3",
            canonical_request=request,
            credentials=self.creds,
            instant=instant)

    def submit(self, url_context=None, utcnow=datetime.datetime.utcnow):
        """Submit this query.

        @return: A deferred from get_page
        """
        if not url_context:
            url_context = s3_url_context(
                self.endpoint, self.bucket, self.object_name)
        # NOTE(review): self.body_producer is not passed to get_page here;
        # confirm whether streamed bodies are sent elsewhere.
        d = self.get_page(
            url_context.get_encoded_url(),
            method=self.action,
            postdata=self.data or b"",
            headers=self.get_headers(utcnow()),
        )
        return d.addErrback(s3_error_wrapper)