Example #1
0
 def test_get_canonical_host_with_non_default_port(self):
     """
     When the URI names a non-default port, the canonical host keeps the
     C{:port} suffix.
     """
     endpoint = AWSServiceEndpoint(uri="http://my.service:99/endpoint")
     canonical = endpoint.get_canonical_host()
     self.assertEqual("my.service:99", canonical)
Example #2
0
 def test_get_canonical_host_is_lower_case(self):
     """
     A mixed-case host in the URI is normalized: the canonical host is
     always returned in lower case.
     """
     mixed_case_uri = "http://MY.SerVice:99/endpoint"
     endpoint = AWSServiceEndpoint(uri=mixed_case_uri)
     self.assertEqual("my.service:99", endpoint.get_canonical_host())
Example #3
0
 def test_set_canonical_host(self):
     """
     Setting a canonical host without a port stores the host lower-cased
     and leaves the port as C{None}.
     """
     endpoint = AWSServiceEndpoint()
     endpoint.set_canonical_host("My.Service")
     self.assertEqual("my.service", endpoint.host)
     self.assertIdentical(None, endpoint.port)
Example #4
0
 def test_set_canonical_host_with_port(self):
     """
     A C{host:port} canonical host is split into its host and an integer
     port.
     """
     endpoint = AWSServiceEndpoint()
     endpoint.set_canonical_host("my.service:99")
     host, port = endpoint.host, endpoint.port
     self.assertEqual("my.service", host)
     self.assertEqual(99, port)
Example #5
0
 def test_set_canonical_host_with_empty_port(self):
     """
     A trailing colon with nothing after it is treated as "no port": the
     port stays C{None}.
     """
     endpoint = AWSServiceEndpoint()
     endpoint.set_canonical_host("my.service:")
     self.assertEqual("my.service", endpoint.host)
     self.assertIdentical(None, endpoint.port)
Example #6
0
 def test_get_canonical_host(self):
     """
     Without an explicit port in the URI, the canonical host is just the
     host name.
     """
     endpoint = AWSServiceEndpoint(uri="http://my.service/endpoint")
     canonical = endpoint.get_canonical_host()
     self.assertEqual("my.service", canonical)
Example #7
0
class AWSServiceEndpointTestCase(TXAWSTestCase):
    """Tests for parsing, rendering and mutating an L{AWSServiceEndpoint}."""

    def setUp(self):
        self.endpoint = AWSServiceEndpoint(uri="http://my.service/da_endpoint")

    def _assert_uri_parts(self, endpoint, scheme, host, port, path):
        """Check each parsed URI component of C{endpoint}, in order."""
        self.assertEqual(endpoint.scheme, scheme)
        self.assertEqual(endpoint.host, host)
        self.assertEqual(endpoint.port, port)
        self.assertEqual(endpoint.path, path)

    def test_simple_creation(self):
        endpoint = AWSServiceEndpoint()
        self._assert_uri_parts(endpoint, "http", "", 80, "/")
        self.assertEqual(endpoint.method, "GET")

    def test_custom_method(self):
        endpoint = AWSServiceEndpoint(uri="http://service/endpoint",
                                      method="PUT")
        self.assertEqual(endpoint.method, "PUT")

    def test_parse_uri(self):
        self._assert_uri_parts(
            self.endpoint, "http", "my.service", 80, "/da_endpoint")

    def test_parse_uri_https_and_custom_port(self):
        endpoint = AWSServiceEndpoint(uri="https://my.service:8080/endpoint")
        self._assert_uri_parts(
            endpoint, "https", "my.service", 8080, "/endpoint")

    def test_get_uri(self):
        rendered = self.endpoint.get_uri()
        self.assertEqual(rendered, "http://my.service/da_endpoint")

    def test_get_uri_custom_port(self):
        original = "https://my.service:8080/endpoint"
        round_tripped = AWSServiceEndpoint(uri=original).get_uri()
        self.assertEqual(round_tripped, original)

    def test_set_host(self):
        self.assertEqual(self.endpoint.host, "my.service")
        self.endpoint.set_host("newhost.com")
        self.assertEqual(self.endpoint.host, "newhost.com")

    def test_get_host(self):
        self.assertEqual(self.endpoint.host, self.endpoint.get_host())

    def test_set_path(self):
        self.endpoint.set_path("/newpath")
        expected = "http://my.service/newpath"
        self.assertEqual(self.endpoint.get_uri(), expected)

    def test_set_method(self):
        self.assertEqual(self.endpoint.method, "GET")
        self.endpoint.set_method("PUT")
        self.assertEqual(self.endpoint.method, "PUT")
Example #8
0
 def __init__(self, bucket=None, object_name=None, data="",
              content_type=None, metadata=None, *args, **kwargs):
     """
     Build an S3 query.

     @param bucket: the bucket name, or C{None}.
     @param object_name: the object key, or C{None}.
     @param data: the payload for the query (presumably the request body).
     @param content_type: an explicit content type, or C{None}.
     @param metadata: a mapping of user metadata; C{None} means empty.

     Any remaining arguments are forwarded to the base query class, which
     presumably sets C{self.endpoint} and C{self.action} (read below).
     """
     super(Query, self).__init__(*args, **kwargs)
     self.bucket = bucket
     self.object_name = object_name
     self.data = data
     self.content_type = content_type
     # Give every query its own dict. A shared ``{}`` default would leak
     # metadata between queries as soon as anyone mutated self.metadata.
     self.metadata = {} if metadata is None else metadata
     self.date = datetimeToString()
     # Fall back to the default S3 endpoint when none (or no host) is given.
     if not self.endpoint or not self.endpoint.host:
         self.endpoint = AWSServiceEndpoint(S3_ENDPOINT)
     self.endpoint.set_method(self.action)
Example #9
0
 def get_queue(self, owner_id, queue):
     """
     Return a L{Queue} for a queue whose owner id and name are already
     known, without requesting the queue URL from the service.

     @param owner_id: required, C{str}; first path segment of the queue URL.
     @param queue: required, C{str}; the queue name, second path segment.
     @return: a L{Queue} whose endpoint is a copy of this client's endpoint
         with its path set to C{/<owner_id>/<queue>/}.
     """
     queue_endpoint = AWSServiceEndpoint(uri=self.endpoint.get_uri())
     queue_path = '/{}/{}/'.format(owner_id, queue)
     queue_endpoint.set_path(queue_path)
     agent = self.query_factory.agent
     signer = QuerysSignatureV4(self.creds, queue_endpoint, agent)
     return Queue(self.creds, queue_endpoint, signer)
Example #10
0
 def get_queue(self, owner_id, queue):
     """
     Build a L{Queue} directly from a known owner id and queue name.

     No request is made to look up the queue URL. The client's endpoint
     URI is copied and its path replaced with C{/<owner_id>/<queue>/}.

     @param owner_id: required, C{str}.
     @param queue: required, C{str}.
     @return: a L{Queue} using that endpoint and a V4-signing query factory
         that shares this client's agent.
     """
     endpoint_copy = AWSServiceEndpoint(uri=self.endpoint.get_uri())
     endpoint_copy.set_path('/{}/{}/'.format(owner_id, queue))
     factory = QuerysSignatureV4(
         self.creds,
         endpoint_copy,
         self.query_factory.agent,
     )
     return Queue(self.creds, endpoint_copy, factory)
Example #11
0
class AWSServiceEndpointTestCase(TXAWSTestCase):
    """
    Unit tests for L{AWSServiceEndpoint}: URI parsing, URI rendering and the
    host/path/method setters.
    """

    def setUp(self):
        self.endpoint = AWSServiceEndpoint(uri="http://my.service/da_endpoint")

    def test_simple_creation(self):
        endpoint = AWSServiceEndpoint()
        expectations = [
            ("scheme", "http"),
            ("host", ""),
            ("port", 80),
            ("path", "/"),
            ("method", "GET"),
        ]
        for attribute, expected in expectations:
            self.assertEqual(getattr(endpoint, attribute), expected)

    def test_custom_method(self):
        endpoint = AWSServiceEndpoint(
            uri="http://service/endpoint", method="PUT")
        self.assertEqual(endpoint.method, "PUT")

    def test_parse_uri(self):
        expectations = [
            ("scheme", "http"),
            ("host", "my.service"),
            ("port", 80),
            ("path", "/da_endpoint"),
        ]
        for attribute, expected in expectations:
            self.assertEqual(getattr(self.endpoint, attribute), expected)

    def test_parse_uri_https_and_custom_port(self):
        endpoint = AWSServiceEndpoint(uri="https://my.service:8080/endpoint")
        expectations = [
            ("scheme", "https"),
            ("host", "my.service"),
            ("port", 8080),
            ("path", "/endpoint"),
        ]
        for attribute, expected in expectations:
            self.assertEqual(getattr(endpoint, attribute), expected)

    def test_get_uri(self):
        self.assertEqual(self.endpoint.get_uri(),
                         "http://my.service/da_endpoint")

    def test_get_uri_custom_port(self):
        source_uri = "https://my.service:8080/endpoint"
        self.assertEqual(AWSServiceEndpoint(uri=source_uri).get_uri(),
                         source_uri)

    def test_set_host(self):
        self.assertEqual(self.endpoint.host, "my.service")
        self.endpoint.set_host("newhost.com")
        self.assertEqual(self.endpoint.host, "newhost.com")

    def test_get_host(self):
        host_via_getter = self.endpoint.get_host()
        self.assertEqual(self.endpoint.host, host_via_getter)

    def test_set_path(self):
        self.endpoint.set_path("/newpath")
        self.assertEqual(self.endpoint.get_uri(), "http://my.service/newpath")

    def test_set_method(self):
        self.assertEqual(self.endpoint.method, "GET")
        self.endpoint.set_method("PUT")
        self.assertEqual(self.endpoint.method, "PUT")
Example #12
0
    def test_ssl_verification_negative(self):
        """
        The L{VerifyingContextFactory} fails with an SSL error if the
        server's certificates can't be checked.
        """
        context_factory = WebDefaultOpenSSLContextFactory(
            BADPRIVKEY, BADPUBKEY)
        self.port = reactor.listenSSL(0,
                                      self.site,
                                      context_factory,
                                      interface="127.0.0.1")
        self.portno = self.port.getHost().port

        endpoint = AWSServiceEndpoint(ssl_hostname_verification=True)
        query = BaseQuery("an action", "creds", endpoint)
        d = query.get_page(self._get_url("file"))

        def fail(ignore):
            self.fail('Expected SSLError')

        def check_exception(why):
            # XXX kind of a mess here ... need to unwrap the exception and
            # check. The failure's first exception argument is a sequence
            # whose first element carries the root exception in ``.value``.
            # ``.args[0]`` is what ``why.value[0]`` meant on Python 2 and
            # also works on Python 3, where exceptions are not indexable.
            root_exc = why.value.args[0][0].value
            self.assertIsInstance(root_exc, SSLError)

        return d.addCallbacks(fail, check_exception)
Example #13
0
    def test_ssl_hostname_verification(self):
        """
        When the endpoint given to L{BaseQuery} has
        C{ssl_hostname_verification} set to C{True}, the agent created by
        C{get_page} receives an L{ssl.VerifyingContextFactory}.
        """
        created = []

        @implementer(IAgent)
        class FakeAgent(object):
            """Records its constructor arguments; requests never fire."""

            def __init__(self, reactor, contextFactory, connectTimeout=None,
                         bindAddress=None, pool=None):
                created.append((reactor, contextFactory, connectTimeout,
                                bindAddress, pool))

            def request(self, method, uri, headers=None, bodyProducer=None):
                return Deferred()

        verifyClass(IAgent, FakeAgent)

        certificate = makeCertificate(O="Test Certificate", CN="something")[1]
        self.patch(base, "Agent", FakeAgent)
        self.patch(ssl, "_ca_certs", [certificate])

        endpoint = AWSServiceEndpoint(ssl_hostname_verification=True)
        query = BaseQuery("an action", "creds", endpoint, reactor="ignored")
        query.get_page("https://example.com/file")

        self.assertEqual(len(created), 1)
        _, context_factory, _, _, _ = created[0]
        self.assertIsInstance(context_factory, ssl.VerifyingContextFactory)
Example #14
0
    def test_handle_unicode_api_error(self):
        """
        An L{APIError} carrying a unicode message is turned into a normal 400
        response by L{QueryAPI}, and nothing is logged.
        """
        creds = AWSCredentials("access", "secret")
        endpoint = AWSServiceEndpoint("http://uri")
        query = Query(action="SomeAction", creds=creds, endpoint=endpoint)
        query.sign()
        request = FakeRequest(query.params, endpoint)

        def raise_unicode_api_error(call):
            raise APIError(400, code="LangError",
                           message=u"\N{HIRAGANA LETTER A}dvanced")

        def check(ignored):
            self.assertEqual(0, len(self.flushLoggedErrors()))
            self.assertTrue(request.finished)
            self.assertTrue(request.response.startswith("LangError"))
            self.assertEqual(400, request.code)

        self.api.execute = raise_unicode_api_error
        self.api.principal = TestPrincipal(creds)
        return self.api.handle(request).addCallback(check)
Example #15
0
    def test_handle_with_expired_signature(self):
        """
        A request whose Expires parameter is earlier than the time reported
        by C{get_utc_time} is rejected with a 400 RequestExpired error, and
        nothing is logged.
        """
        creds = AWSCredentials("access", "secret")
        endpoint = AWSServiceEndpoint("http://uri")
        query = Query(action="SomeAction", creds=creds, endpoint=endpoint,
                      other_params={"Expires": "2010-01-01T12:00:00Z"})
        query.sign()
        request = FakeRequest(query.params, endpoint)

        # Pretend "now" is one second after the expiry time.
        one_second_late = datetime(2010, 1, 1, 12, 0, 1, tzinfo=tzutc())
        self.api.get_utc_time = lambda: one_second_late

        def check(ignored):
            self.assertEqual(0, len(self.flushLoggedErrors()))
            self.assertEqual(
                "RequestExpired - Request has expired. Expires date is"
                " 2010-01-01T12:00:00Z", request.response)
            self.assertEqual(400, request.code)

        return self.api.handle(request).addCallback(check)
Example #16
0
    def test_handle_unicode_error(self):
        """
        An arbitrary (non-API) error with a unicode message raised by an API
        method is logged, and the client gets a generic 500 response.
        """
        creds = AWSCredentials("access", "secret")
        endpoint = AWSServiceEndpoint("http://uri")
        query = Query(action="SomeAction", creds=creds, endpoint=endpoint)
        query.sign()
        request = FakeRequest(query.params, endpoint)

        def raise_unicode_value_error(call):
            raise ValueError(u"\N{HIRAGANA LETTER A}dvanced")

        def check(ignored):
            logged = self.flushLoggedErrors()
            [error] = logged
            self.assertIsInstance(error.value, ValueError)
            self.assertTrue(request.finished)
            self.assertEqual("Server error", request.response)
            self.assertEqual(500, request.code)

        self.api.execute = raise_unicode_value_error
        self.api.principal = TestPrincipal(creds)
        return self.api.handle(request).addCallback(check)
Example #17
0
    def __init__(self,
                 access_key="",
                 secret_key="",
                 uri="",
                 ec2_client_factory=None,
                 keypairs=None,
                 security_groups=None,
                 instances=None,
                 volumes=None,
                 snapshots=None,
                 availability_zones=None):
        """
        Store the fake-service configuration and create the in-memory S3 and
        Route53 backends.

        @param ec2_client_factory: factory for the EC2 client; any falsy value
            selects C{FakeEC2Client}.

        The keypairs/security_groups/instances/volumes/snapshots/
        availability_zones arguments are stored unchanged for later use.
        """
        self.access_key = access_key
        self.secret_key = secret_key
        self.uri = uri
        self.ec2_client = None
        self.ec2_client_factory = ec2_client_factory or FakeEC2Client

        self.keypairs = keypairs
        self.security_groups = security_groups
        self.instances = instances
        self.volumes = volumes
        self.snapshots = snapshots
        self.availability_zones = availability_zones

        self.s3 = MemoryS3()
        self._creds = AWSCredentials(access_key=access_key,
                                     secret_key=secret_key)
        self._endpoint = AWSServiceEndpoint(uri=uri)
        self._route53_controller = MemoryRoute53()
Example #18
0
    def test_handle_500_api_error(self):
        """
        An L{APIError} with a status code of 500 or more is logged on the
        server side, and its response body and code are sent back.
        """
        creds = AWSCredentials("access", "secret")
        endpoint = AWSServiceEndpoint("http://uri")
        query = Query(action="SomeAction", creds=creds, endpoint=endpoint)
        query.sign()
        request = FakeRequest(query.params, endpoint)

        def raise_server_api_error(call):
            raise APIError(500, response="oops")

        def check(ignored):
            self.assertEqual(1, len(self.flushLoggedErrors()))
            self.assertTrue(request.finished)
            self.assertEqual("oops", request.response)
            self.assertEqual(500, request.code)

        self.api.execute = raise_server_api_error
        self.api.principal = TestPrincipal(creds)
        return self.api.handle(request).addCallback(check)
Example #19
0
 def test_simple_creation(self):
     """
     A default-constructed endpoint uses http, an empty host, no port, the
     root path and the GET method.
     """
     endpoint = AWSServiceEndpoint()
     expectations = [
         ("scheme", "http"),
         ("host", ""),
         ("port", None),
         ("path", "/"),
         ("method", "GET"),
     ]
     for attribute, expected in expectations:
         self.assertEqual(getattr(endpoint, attribute), expected)
Example #20
0
    def test_handle_pass_params_to_call(self):
        """
        L{QueryAPI.handle} builds the L{Call} handed to C{execute} with the
        API-specific raw parameters, the principal, the action, the version
        and the request id.
        """
        self.registry.add(TestMethod, "SomeAction", "1.2.3")
        creds = AWSCredentials("access", "secret")
        endpoint = AWSServiceEndpoint("http://uri")
        query = Query(action="SomeAction", creds=creds, endpoint=endpoint,
                      other_params={"Foo": "bar", "Version": "1.2.3"})
        query.sign()
        request = FakeRequest(query.params, endpoint)

        def execute(call):
            # Only "Foo" remains a raw parameter; Version ends up in
            # call.version instead.
            self.assertEqual({"Foo": "bar"}, call.get_raw_params())
            self.assertIdentical(self.api.principal, call.principal)
            self.assertEqual("SomeAction", call.action)
            self.assertEqual("1.2.3", call.version)
            self.assertEqual(request.id, call.id)
            return "ok"

        def check(ignored):
            self.assertEqual("ok", request.response)
            self.assertEqual(200, request.code)

        self.api.execute = execute
        self.api.principal = TestPrincipal(creds)
        return self.api.handle(request).addCallback(check)
Example #21
0
    def test_handle_custom_get_call_arguments(self):
        """
        L{QueryAPI.handle} obtains call arguments through
        L{QueryAPI.get_call_arguments}. A subclass that reads a different
        wire format (here lower-case C{access_key}/C{signature} keys) is
        therefore handled successfully.
        """
        creds = AWSCredentials("access", "secret")
        endpoint = AWSServiceEndpoint("http://uri")
        api = AlternativeWireFormatQueryAPI(self.registry)
        params = {"foo": "bar", "access_key": creds.access_key}
        # Signature gets its own copy, so whatever it does with its dict
        # cannot affect the params sent in the request.
        params["signature"] = Signature(
            creds,
            endpoint,
            params.copy(),
            signature_method="Hmacsha256",
            signature_version=2,
        ).compute()
        request = FakeRequest(params, endpoint)
        api.principal = TestPrincipal(creds)

        def check(ignored):
            self.assertTrue(request.finished)
            self.assertEqual("data", request.response)
            self.assertEqual("4", request.headers["Content-Length"])
            self.assertEqual("text/plain", request.headers["Content-Type"])
            self.assertEqual(200, request.code)

        return api.handle(request).addCallback(check)
Example #22
0
 def test_get_uri_with_endpoint_bucket_and_object(self):
     """
     L{client.URLContext.get_url} appends the bucket and object name to the
     endpoint as path segments.
     """
     context = client.URLContext(
         AWSServiceEndpoint("http://localhost/"),
         bucket="mydocs",
         object_name="notes.txt")
     expected_url = "http://localhost/mydocs/notes.txt"
     self.assertEqual(context.get_url(), expected_url)
Example #23
0
    def test_handle_with_timestamp_and_expires(self):
        """
        Supplying both Timestamp and Expires is rejected with a 400
        InvalidParameterCombination error, and nothing is logged.
        """
        creds = AWSCredentials("access", "secret")
        endpoint = AWSServiceEndpoint("http://uri")
        moment = "2010-01-01T12:00:00Z"
        query = Query(action="SomeAction", creds=creds, endpoint=endpoint,
                      other_params={"Timestamp": moment, "Expires": moment})
        query.sign()
        request = FakeRequest(query.params, endpoint)

        def check(ignored):
            self.assertEqual(0, len(self.flushLoggedErrors()))
            self.assertEqual(
                "InvalidParameterCombination - The parameter Timestamp"
                " cannot be used with the parameter Expires", request.response)
            self.assertEqual(400, request.code)

        return self.api.handle(request).addCallback(check)
Example #24
0
File: base.py Project: lzimm/360io
 def __init__(self, creds=None, endpoint=None, query_factory=None):
     """
     Store the credentials, endpoint and query factory for this client.

     A missing C{creds} or C{endpoint} is replaced by a default-constructed
     L{AWSCredentials} or L{AWSServiceEndpoint}. C{query_factory} is stored
     as given, including C{None}.
     """
     self.creds = AWSCredentials() if creds is None else creds
     self.endpoint = AWSServiceEndpoint() if endpoint is None else endpoint
     self.query_factory = query_factory
Example #25
0
class BucketURLContextTestCase(TXAWSTestCase):
    """Tests for L{client.BucketURLContext}."""

    endpoint = AWSServiceEndpoint("https://s3.amazonaws.com/")

    def test_get_host_with_bucket(self):
        """
        With a bucket, the host stays the endpoint's host and the bucket
        name becomes the path.
        """
        context = client.BucketURLContext(self.endpoint, "mystuff")
        host, path = context.get_host(), context.get_path()
        self.assertEqual(host, "s3.amazonaws.com")
        self.assertEqual(path, "/mystuff")
Example #26
0
 def _validate_signature(self, request, principal, args, params):
     """
     Validate the signature of an incoming request.

     The signature is recomputed from:
       - the principal's credentials,
       - the request method and Host header,
       - the request path (prefixed with C{self.path} when it is set),
       - C{params}.
     It is then compared with the client-supplied C{args["signature"]}.

     @raises APIError: 403 C{SignatureDoesNotMatch} when the signatures
         differ.
     """
     import hmac

     creds = AWSCredentials(principal.access_key, principal.secret_key)
     endpoint = AWSServiceEndpoint()
     endpoint.set_method(request.method)
     endpoint.set_canonical_host(request.getHeader("Host"))
     path = request.path
     if self.path is not None:
         path = "%s/%s" % (self.path.rstrip("/"), path.lstrip("/"))
     endpoint.set_path(path)
     signature = Signature(
         creds,
         endpoint,
         params,
         signature_method=args["signature_method"],
         signature_version=args["signature_version"],
     )
     expected = signature.compute()
     provided = args["signature"]
     # Compare in constant time. A plain ``!=`` can short-circuit on the
     # first differing character, letting response timing leak how much of
     # a forged signature was correct. hmac.compare_digest needs two values
     # of the same type, so text is normalized to UTF-8 bytes first.
     # Anything that is not a string then simply counts as a mismatch.
     text_type = type(u"")
     if isinstance(expected, text_type):
         expected = expected.encode("utf-8")
     if isinstance(provided, text_type):
         provided = provided.encode("utf-8")
     matches = (isinstance(expected, bytes) and
                isinstance(provided, bytes) and
                hmac.compare_digest(expected, provided))
     if not matches:
         raise APIError(
             403,
             "SignatureDoesNotMatch",
             "The request signature we calculated does not "
             "match the signature you provided. Check your "
             "key and signing method.",
         )
Example #27
0
 def get_s3_client(self, creds=None):
     """
     Create an in-memory S3 client for C{self.uri}.

     @param creds: credentials to use; when C{None}, they are built from
         C{self.access_key} and C{self.secret_key}.
     @return: the new client. It is also stored, with its state object, as
         C{self.s3_client} / C{self.s3_state}.
     """
     if creds is None:
         creds = AWSCredentials(access_key=self.access_key,
                                secret_key=self.secret_key)
     endpoint = AWSServiceEndpoint(uri=self.uri)
     client_and_state = self.s3.client(creds, endpoint)
     self.s3_client, self.s3_state = client_and_state
     return self.s3_client
Example #28
0
 def test_errors(self):
     """
     Requesting a path the test site does not serve fails the Deferred
     with L{TwistedWebError}.
     """
     endpoint = AWSServiceEndpoint("http://endpoint")
     query = BaseQuery("an action", "creds", endpoint)
     missing_url = self._get_url("not_there")
     deferred = query.get_page(missing_url)
     self.assertFailure(deferred, TwistedWebError)
     return deferred
Example #29
0
 def get_route53_client(self, creds=None):
     """
     Create an in-memory Route53 client for C{self.uri}.

     @param creds: credentials to use; when C{None}, they are built from
         C{self.access_key} and C{self.secret_key}.
     @return: the client only; the accompanying state object is discarded.
     """
     if creds is None:
         creds = AWSCredentials(access_key=self.access_key,
                                secret_key=self.secret_key)
     endpoint = AWSServiceEndpoint(uri=self.uri)
     route53_client, _state = self._route53_controller.client(creds, endpoint)
     return route53_client
Example #30
0
    def test_nondefault_endpoint(self):
        """
        A client constructed with an explicit endpoint for
        C{PRODUCTION_LICENSE_SERVICE_ENDPOINT} ends up with endpoint state
        equal to that of C{self.lsc}.
        """
        creds = AWSCredentials(access_key=FAKE_AWS_ACCESS_KEY_ID,
                               secret_key=FAKE_HMAC_KEY)
        endpoint = AWSServiceEndpoint(uri=PRODUCTION_LICENSE_SERVICE_ENDPOINT)
        lsc = LicenseServiceClient(creds=creds, endpoint=endpoint)

        self.assertEqual(vars(lsc._endpoint), vars(self.lsc._endpoint))
Example #31
0
 def test_get_page(self):
     """
     L{BaseQuery.get_page} fires with the body served at the C{file} URL.
     """
     endpoint = AWSServiceEndpoint("http://endpoint")
     query = BaseQuery("an action", "creds", endpoint)
     deferred = query.get_page(self._get_url("file"))
     return deferred.addCallback(self.assertEqual, "0123456789")
    def __init__(self, creds=None, endpoint=None):
        """
        @param creds: an L{AWSCredentials} instance, or C{None}.
        @param endpoint: an L{AWSServiceEndpoint}. When C{None}, one is built
            from C{PRODUCTION_LICENSE_SERVICE_ENDPOINT}.
        """
        if endpoint is None:
            endpoint = AWSServiceEndpoint(PRODUCTION_LICENSE_SERVICE_ENDPOINT)

        # repr() yields the same assertion message as the Python-2-only
        # backtick syntax, and keeps this module parseable on Python 3.
        assert creds is None or isinstance(creds, AWSCredentials), repr(creds)
        assert isinstance(endpoint, AWSServiceEndpoint), repr(endpoint)

        self._creds = creds
        self._endpoint = endpoint
Example #33
0
def get_EC2_properties(ec2accesskeyid, ec2secretkey, endpoint_uri, parser,
                       *instance_ids):
    """
    Describe EC2 instances through an L{EC2Client}.

    Reference: http://docs.amazonwebservices.com/AWSEC2/latest/APIReference/index.html?ApiReference-query-DescribeInstances.html

    @param ec2accesskeyid: access key used to build the L{AWSCredentials}.
    @param ec2secretkey: secret key used to build the L{AWSCredentials}.
    @param endpoint_uri: URI of the EC2 service endpoint.
    @param parser: parser handed to the L{EC2Client}.
    @param instance_ids: instance ids forwarded to C{describe_instances}.
    @return: whatever C{EC2Client.describe_instances} returns.
    """
    client = EC2Client(
        creds=AWSCredentials(ec2accesskeyid, ec2secretkey),
        endpoint=AWSServiceEndpoint(uri=endpoint_uri),
        parser=parser,
    )
    return client.describe_instances(*instance_ids)
Example #34
0
    def __init__(self, bucket=None, object_name=None, data="",
                 content_type=None, metadata=None, amz_headers=None,
                 body_producer=None, *args, **kwargs):
        """
        Build an S3 query.

        @param bucket: the bucket name, or C{None}.
        @param object_name: the object key, or C{None}.
        @param data: in-memory payload; may be C{None} or C{""}.
        @param content_type: an explicit content type, or C{None}.
        @param metadata: a mapping of user metadata; C{None} means empty.
        @param amz_headers: a mapping of extra amz headers; C{None} means
            empty.
        @param body_producer: an alternative payload source.

        Any remaining arguments are forwarded to the base query class.

        @raises ValueError: if both a non-empty C{data} and a
            C{body_producer} are given.
        """
        super(Query, self).__init__(*args, **kwargs)

        # data might be None or "", alas.
        if data and body_producer is not None:
            raise ValueError("data and body_producer are mutually exclusive.")

        self.bucket = bucket
        self.object_name = object_name
        self.data = data
        self.body_producer = body_producer
        self.content_type = content_type
        # Fresh dicts per query. Shared ``{}`` defaults would leak headers
        # between queries as soon as anyone mutated them.
        self.metadata = {} if metadata is None else metadata
        self.amz_headers = {} if amz_headers is None else amz_headers
        self._date = datetimeToString()
        # Fall back to the default S3 endpoint when none (or no host) is given.
        if not self.endpoint or not self.endpoint.host:
            self.endpoint = AWSServiceEndpoint(S3_ENDPOINT)
        self.endpoint.set_method(self.action)
Example #35
0
class Query(BaseQuery):
    """A query for submission to the S3 service."""

    def __init__(self, bucket=None, object_name=None, data="",
                 content_type=None, metadata=None, amz_headers=None, *args,
                 **kwargs):
        """
        @param bucket: the bucket name, or C{None}.
        @param object_name: the object key, optionally followed by a query
            string such as C{"?acl"}; or C{None}.
        @param data: the request body sent by L{submit}.
        @param content_type: an explicit Content-Type, or C{None} to guess
            it from C{object_name}.
        @param metadata: mapping sent as C{x-amz-meta-<key>} headers;
            C{None} means empty.
        @param amz_headers: mapping sent as C{x-amz-<key>} headers; tuple
            values are joined with commas. C{None} means empty.

        Any remaining arguments are forwarded to L{BaseQuery}.
        """
        super(Query, self).__init__(*args, **kwargs)
        self.bucket = bucket
        self.object_name = object_name
        self.data = data
        self.content_type = content_type
        # Fresh dicts per query. Shared ``{}`` defaults would leak headers
        # between queries as soon as anyone mutated them.
        self.metadata = {} if metadata is None else metadata
        self.amz_headers = {} if amz_headers is None else amz_headers
        self.date = datetimeToString()
        # Fall back to the default S3 endpoint when none (or no host) given.
        if not self.endpoint or not self.endpoint.host:
            self.endpoint = AWSServiceEndpoint(S3_ENDPOINT)
        # The query's action doubles as the HTTP method (see submit()).
        self.endpoint.set_method(self.action)

    def set_content_type(self):
        """
        Set the content type based on the file extension used in the object
        name, unless a content type was already given.
        """
        if self.object_name and not self.content_type:
            # XXX nothing is currently done with the encoding... we may
            # need to in the future
            self.content_type, _encoding = mimetypes.guess_type(
                self.object_name, strict=False)

    def get_headers(self):
        """
        Build the headers needed in order to perform S3 operations.

        The result contains:
          - Content-Length, Content-MD5 and Date for C{self.data};
          - the C{x-amz-meta-*} and C{x-amz-*} headers;
          - an explicit or guessed Content-Type;
          - an Authorization header, when credentials are present.
        """
        headers = {"Content-Length": len(self.data),
                   "Content-MD5": calculate_md5(self.data),
                   "Date": self.date}
        for key, value in self.metadata.items():
            headers["x-amz-meta-" + key] = value
        for key, values in self.amz_headers.items():
            if isinstance(values, tuple):
                # Multi-valued headers are folded into one comma-separated
                # value.
                headers["x-amz-" + key] = ",".join(values)
            else:
                headers["x-amz-" + key] = values
        # Before we check if the content type is set, let's see if we can set
        # it by guessing the mimetype.
        self.set_content_type()
        if self.content_type is not None:
            headers["Content-Type"] = self.content_type
        if self.creds is not None:
            signature = self.sign(headers)
            headers["Authorization"] = "AWS %s:%s" % (
                self.creds.access_key, signature)
        return headers

    def get_canonicalized_amz_headers(self, headers):
        """
        Build the CanonicalizedAmzHeaders part of the string to sign.

        Headers are included only when their lower-cased name starts with
        C{x-amz-}. They are sorted by lower-cased name, one C{name:value\\n}
        line each. Tuple values are folded into a single comma-separated
        value, matching how L{get_headers} sends them.
        """
        amz_headers = sorted(
            (name.lower(), value) for name, value in headers.items()
            if name.lower().startswith("x-amz-"))
        # XXX missing spec implementation:
        # txAWS doesn't currently unfold long headers

        def represent(name, value):
            # Previously every tuple element produced a line containing the
            # repr of the *whole* tuple. Fold the values instead.
            if isinstance(value, tuple):
                value = ",".join(value)
            return "%s:%s\n" % (name, value)

        return "".join(represent(name, value) for name, value in amz_headers)

    def get_canonicalized_resource(self):
        """
        Get an S3 resource path.

        The path is C{/bucket/object}, plus a leading C{?subresource} if the
        object name carries one. Ordinary C{key=value} query parameters are
        dropped.
        """

        # As <http://docs.amazonwebservices.com/AmazonS3/latest/dev/RESTAuthentication.html>
        # says, if there is a subresource (e.g. ?acl), it is included, but other query
        # parameters (e.g. ?prefix=... in a GET Bucket request) are not included.
        # Yes, that makes no sense in terms of either security or consistency.
        resource = self.object_name
        if resource:
            q = resource.find('?')
            if q >= 0:
                # There can be both a subresource and other parameters, for example
                # '?versions&prefix=foo'. "You are in a maze of twisty edge cases..."
                firstparam = resource[q:].partition('&')[0]  # includes the initial '?'
                resource = resource[:q]                      # strip the query
                if '=' not in firstparam:
                    resource += firstparam                   # add back '?subresource' if present

        path = "/"
        if self.bucket is not None:
            path += self.bucket
        if self.bucket is not None and resource:
            if not resource.startswith("/"):
                path += "/"
            path += resource
        elif self.bucket is not None and not path.endswith("/"):
            path += "/"
        return path

    def sign(self, headers):
        """Sign this query using its built in credentials."""
        text = (self.action + "\n" +
                headers.get("Content-MD5", "") + "\n" +
                headers.get("Content-Type", "") + "\n" +
                headers.get("Date", "") + "\n" +
                self.get_canonicalized_amz_headers(headers) +
                self.get_canonicalized_resource())
        return self.creds.sign(text, hash_type="sha1")

    def submit(self, url_context=None):
        """Submit this query.

        @param url_context: the URL context to use; built from the endpoint,
            bucket and object name when not given.
        @return: A deferred from get_page, with S3 errors wrapped by
            C{s3_error_wrapper}.
        """
        if not url_context:
            url_context = URLContext(
                self.endpoint, self.bucket, self.object_name)
        d = self.get_page(
            url_context.get_url(), method=self.action, postdata=self.data,
            headers=self.get_headers())
        return d.addErrback(s3_error_wrapper)
Example #36
0
class Query(BaseQuery):
    """A query for submission to the S3 service."""

    def __init__(self, bucket=None, object_name=None, data="",
                 content_type=None, metadata=None, amz_headers=None, *args,
                 **kwargs):
        super(Query, self).__init__(*args, **kwargs)
        self.bucket = bucket
        self.object_name = object_name
        self.data = data
        self.content_type = content_type
        # Use a fresh dict per instance. A shared mutable default would let
        # mutations made through one query leak into every later query.
        self.metadata = {} if metadata is None else metadata
        self.amz_headers = {} if amz_headers is None else amz_headers
        self.date = datetimeToString()
        if not self.endpoint or not self.endpoint.host:
            self.endpoint = AWSServiceEndpoint(S3_ENDPOINT)
        self.endpoint.set_method(self.action)

    def set_content_type(self):
        """
        Set the content type based on the file extension used in the object
        name.

        Does nothing if there is no object name or a content type is already
        set.
        """
        if self.object_name and not self.content_type:
            # XXX nothing is currently done with the encoding... we may
            # need to in the future
            self.content_type, encoding = mimetypes.guess_type(
                self.object_name, strict=False)

    def get_headers(self):
        """
        Build the dict of headers needed in order to perform S3 operations.

        Includes Content-Length, Content-MD5 and Date. Each metadata entry
        becomes an C{x-amz-meta-*} header and each amz header an C{x-amz-*}
        header, with tuple values comma-joined. Content-Type is added when
        known or guessable. If credentials are set, an Authorization header
        is added as well.
        """
        headers = {"Content-Length": len(self.data),
                   "Content-MD5": calculate_md5(self.data),
                   "Date": self.date}
        for key, value in self.metadata.iteritems():
            headers["x-amz-meta-" + key] = value
        for key, values in self.amz_headers.iteritems():
            if isinstance(values, tuple):
                headers["x-amz-" + key] = ",".join(values)
            else:
                headers["x-amz-" + key] = values
        # Before we check if the content type is set, let's see if we can set
        # it by guessing the mimetype.
        self.set_content_type()
        if self.content_type is not None:
            headers["Content-Type"] = self.content_type
        if self.creds is not None:
            signature = self.sign(headers)
            headers["Authorization"] = "AWS %s:%s" % (
                self.creds.access_key, signature)
        return headers

    def get_canonicalized_amz_headers(self, headers):
        """
        Get the headers defined by Amazon S3, canonicalized for signing.

        Only headers whose lower-cased name starts with C{x-amz-} are kept.
        They are emitted sorted by lower-cased name, one newline-terminated
        C{name:value} line per header. A tuple value is comma-joined, the
        same way L{get_headers} serializes it, so the signed text matches the
        header actually sent.

        @param headers: Mapping of header names to values.
        @return: The canonicalized header block as a string.
        """
        amz_headers = sorted(
            (name.lower(), value) for name, value in headers.iteritems()
            if name.lower().startswith("x-amz-"))
        # XXX missing spec implementation:
        # txAWS doesn't currently unfold long headers
        lines = []
        for name, value in amz_headers:
            if isinstance(value, tuple):
                value = ",".join(value)
            lines.append("%s:%s\n" % (name, value))
        return "".join(lines)

    def get_canonicalized_resource(self):
        """
        Get an S3 resource path for signing.

        A query string embedded in C{object_name} is dropped, because ordinary
        parameters (e.g. C{?prefix=...}) are not part of the signed resource.
        The exception is a leading valueless parameter (a sub-resource such as
        C{?acl} or C{?versions}), which is kept.
        """
        resource = self.object_name
        if resource:
            q = resource.find('?')
            if q >= 0:
                # There can be both a subresource and other parameters, for
                # example '?versions&prefix=foo'.
                firstparam = resource[q:].partition('&')[0]  # keeps the '?'
                resource = resource[:q]                      # strip the query
                if '=' not in firstparam:
                    resource += firstparam  # add back '?subresource'

        path = "/"
        if self.bucket is not None:
            path += self.bucket
        if self.bucket is not None and resource:
            if not resource.startswith("/"):
                path += "/"
            path += resource
        elif self.bucket is not None and not path.endswith("/"):
            path += "/"
        return path

    def sign(self, headers):
        """Sign this query using its built in credentials."""
        text = (self.action + "\n" +
                headers.get("Content-MD5", "") + "\n" +
                headers.get("Content-Type", "") + "\n" +
                headers.get("Date", "") + "\n" +
                self.get_canonicalized_amz_headers(headers) +
                self.get_canonicalized_resource())
        return self.creds.sign(text, hash_type="sha1")

    def submit(self, url_context=None):
        """Submit this query.

        @return: A deferred from get_page
        """
        if not url_context:
            url_context = URLContext(
                self.endpoint, self.bucket, self.object_name)
        d = self.get_page(
            url_context.get_url(), method=self.action, postdata=self.data,
            headers=self.get_headers())
        return d.addErrback(s3_error_wrapper)
Example #37
0
 def test_get_uri_custom_port(self):
     """
     An https URI with an explicit port comes back unchanged from get_uri.
     """
     original = "https://my.service:8080/endpoint"
     round_tripped = AWSServiceEndpoint(uri=original).get_uri()
     self.assertEquals(round_tripped, original)
Example #38
0
 def setUp(self):
     """Create the endpoint shared by the tests in this case."""
     endpoint_uri = "http://my.service/da_endpoint"
     self.endpoint = AWSServiceEndpoint(uri=endpoint_uri)
Example #39
0
class AWSServiceEndpointTestCase(TestCase):
    """Tests for L{AWSServiceEndpoint}."""

    def setUp(self):
        self.endpoint = AWSServiceEndpoint(uri="http://my.service/da_endpoint")

    def test_warning_when_verification_disabled(self):
        """
        L{AWSServiceEndpoint} emits a warning when told not to perform
        certificate verification.
        """
        self.assertWarns(
            UserWarning,
            "Operating with certificate verification disabled!",
            __file__,
            lambda: AWSServiceEndpoint(ssl_hostname_verification=False),
        )

    def test_simple_creation(self):
        """
        A default endpoint has scheme http, an empty host, no port, path
        "/" and method GET.
        """
        endpoint = AWSServiceEndpoint()
        self.assertEquals(endpoint.scheme, "http")
        self.assertEquals(endpoint.host, "")
        # Check identity with None, as the other port assertions here do.
        self.assertIdentical(endpoint.port, None)
        self.assertEquals(endpoint.path, "/")
        self.assertEquals(endpoint.method, "GET")

    def test_custom_method(self):
        """
        A method passed to the constructor is stored on the endpoint.
        """
        endpoint = AWSServiceEndpoint(
            uri="http://service/endpoint", method="PUT")
        self.assertEquals(endpoint.method, "PUT")

    def test_parse_uri(self):
        """
        The URI is split into scheme, host, port and path. The port is None
        when it is not given.
        """
        self.assertEquals(self.endpoint.scheme, "http")
        self.assertEquals(self.endpoint.host, "my.service")
        self.assertIdentical(self.endpoint.port, None)
        self.assertEquals(self.endpoint.path, "/da_endpoint")

    def test_parse_uri_https_and_custom_port(self):
        """
        An https URI with an explicit port is parsed into its parts, with
        the port as an integer.
        """
        endpoint = AWSServiceEndpoint(uri="https://my.service:8080/endpoint")
        self.assertEquals(endpoint.scheme, "https")
        self.assertEquals(endpoint.host, "my.service")
        self.assertEquals(endpoint.port, 8080)
        self.assertEquals(endpoint.path, "/endpoint")

    def test_get_uri(self):
        """
        get_uri reassembles the URI the endpoint was created from.
        """
        uri = self.endpoint.get_uri()
        self.assertEquals(uri, "http://my.service/da_endpoint")

    def test_get_uri_custom_port(self):
        """
        get_uri preserves an https scheme and a non-default port.
        """
        uri = "https://my.service:8080/endpoint"
        endpoint = AWSServiceEndpoint(uri=uri)
        new_uri = endpoint.get_uri()
        self.assertEquals(new_uri, uri)

    def test_set_host(self):
        """
        set_host replaces the endpoint's host.
        """
        self.assertEquals(self.endpoint.host, "my.service")
        self.endpoint.set_host("newhost.com")
        self.assertEquals(self.endpoint.host, "newhost.com")

    def test_get_host(self):
        """
        get_host returns the endpoint's host attribute.
        """
        self.assertEquals(self.endpoint.host, self.endpoint.get_host())

    def test_get_canonical_host(self):
        """
        If the port is not specified the canonical host is the same as
        the host.
        """
        uri = "http://my.service/endpoint"
        endpoint = AWSServiceEndpoint(uri=uri)
        self.assertEquals("my.service", endpoint.get_canonical_host())

    def test_get_canonical_host_with_non_default_port(self):
        """
        If the port is not the default, the canonical host includes it.
        """
        uri = "http://my.service:99/endpoint"
        endpoint = AWSServiceEndpoint(uri=uri)
        self.assertEquals("my.service:99", endpoint.get_canonical_host())

    def test_get_canonical_host_is_lower_case(self):
        """
        The canonical host is guaranteed to be lower case.
        """
        uri = "http://MY.SerVice:99/endpoint"
        endpoint = AWSServiceEndpoint(uri=uri)
        self.assertEquals("my.service:99", endpoint.get_canonical_host())

    def test_set_canonical_host(self):
        """
        The canonical host is converted to lower case.
        """
        endpoint = AWSServiceEndpoint()
        endpoint.set_canonical_host("My.Service")
        self.assertEquals("my.service", endpoint.host)
        self.assertIdentical(None, endpoint.port)

    def test_set_canonical_host_with_port(self):
        """
        The canonical host can optionally have a port.
        """
        endpoint = AWSServiceEndpoint()
        endpoint.set_canonical_host("my.service:99")
        self.assertEquals("my.service", endpoint.host)
        self.assertEquals(99, endpoint.port)

    def test_set_canonical_host_with_empty_port(self):
        """
        The canonical host can also have no port.
        """
        endpoint = AWSServiceEndpoint()
        endpoint.set_canonical_host("my.service:")
        self.assertEquals("my.service", endpoint.host)
        self.assertIdentical(None, endpoint.port)

    def test_set_path(self):
        """
        set_path changes the path reflected in get_uri.
        """
        self.endpoint.set_path("/newpath")
        self.assertEquals(
            self.endpoint.get_uri(),
            "http://my.service/newpath")

    def test_set_method(self):
        """
        set_method replaces the default GET method.
        """
        self.assertEquals(self.endpoint.method, "GET")
        self.endpoint.set_method("PUT")
        self.assertEquals(self.endpoint.method, "PUT")
Example #40
0
class Query(BaseQuery):
    """A query for submission to the S3 service."""

    def __init__(self, bucket=None, object_name=None, data="",
                 content_type=None, metadata=None, amz_headers=None,
                 body_producer=None, *args, **kwargs):
        super(Query, self).__init__(*args, **kwargs)

        # data might be None or "", alas.
        if data and body_producer is not None:
            raise ValueError("data and body_producer are mutually exclusive.")

        self.bucket = bucket
        self.object_name = object_name
        self.data = data
        # NOTE(review): body_producer is not passed to get_page in submit();
        # presumably BaseQuery reads self.body_producer itself - confirm.
        self.body_producer = body_producer
        self.content_type = content_type
        # Use a fresh dict per instance. A shared mutable default would let
        # mutations made through one query leak into every later query.
        self.metadata = {} if metadata is None else metadata
        self.amz_headers = {} if amz_headers is None else amz_headers
        self._date = datetimeToString()
        if not self.endpoint or not self.endpoint.host:
            self.endpoint = AWSServiceEndpoint(S3_ENDPOINT)
        self.endpoint.set_method(self.action)

    @property
    def date(self):
        """
        Return the date and emit a deprecation warning.
        """
        warnings.warn("txaws.s3.client.Query.date is a deprecated attribute",
                      DeprecationWarning,
                      stacklevel=2)
        return self._date

    @date.setter
    def date(self, value):
        """
        Set the date.

        @param value: The new date for this L{Query}.
        @type value: L{str}
        """
        self._date = value

    def set_content_type(self):
        """
        Set the content type based on the file extension used in the object
        name.

        Does nothing if there is no object name or a content type is already
        set.
        """
        if self.object_name and not self.content_type:
            # XXX nothing is currently done with the encoding... we may
            # need to in the future
            self.content_type, encoding = mimetypes.guess_type(
                self.object_name, strict=False)

    def get_headers(self, instant):
        """
        Build the dict of headers needed in order to perform S3 operations.

        @param instant: The time of the request, used for C{x-amz-date} and
            for signing.
        @return: A dict of header names to values. If credentials are set, it
            also contains C{host} (added by L{sign}) and C{Authorization}.
        """
        headers = {'x-amz-date': _auth_v4.makeAMZDate(instant)}
        if self.body_producer is None:
            data = self.data
            if data is None:
                data = b""
            headers["x-amz-content-sha256"] = hashlib.sha256(data).hexdigest()
        else:
            # A streamed body cannot be hashed up front.
            data = None
            headers["x-amz-content-sha256"] = b"UNSIGNED-PAYLOAD"
        for key, value in self.metadata.iteritems():
            headers["x-amz-meta-" + key] = value
        for key, value in self.amz_headers.iteritems():
            headers["x-amz-" + key] = value

        # Before we check if the content type is set, let's see if we can set
        # it by guessing the mimetype.
        self.set_content_type()
        if self.content_type is not None:
            headers["Content-Type"] = self.content_type
        if self.creds is not None:
            headers["Authorization"] = self.sign(
                headers,
                data,
                s3_url_context(self.endpoint, self.bucket, self.object_name),
                instant,
                method=self.action)
        return headers

    def sign(self, headers, data, url_context, instant, method,
             region=REGION_US_EAST_1):
        """
        Sign this query using its built in credentials.

        Note that this sets C{headers["host"]} in place.

        @param data: The payload to hash, or C{None} for an unsigned payload.
        @return: The value for the Authorization header.
        """
        headers["host"] = url_context.get_encoded_host()

        if data is None:
            request = _auth_v4._CanonicalRequest.from_request_components(
                method=method,
                url=url_context.get_encoded_path(),
                headers=headers,
                headers_to_sign=('host', 'x-amz-date'),
                payload_hash=None,
            )
        else:
            request = _auth_v4._CanonicalRequest.from_request_components_and_payload(
                method=method,
                url=url_context.get_encoded_path(),
                headers=headers,
                headers_to_sign=('host', 'x-amz-date'),
                payload=data,
            )

        return _auth_v4._make_authorization_header(
            region=region,
            service="s3",
            canonical_request=request,
            credentials=self.creds,
            instant=instant)

    def submit(self, url_context=None, utcnow=datetime.datetime.utcnow):
        """Submit this query.

        @param url_context: Optional URL context. When falsy, one is built
            from the endpoint, bucket and object name.
        @param utcnow: Callable returning the current time, used for signing.
        @return: A deferred from get_page
        """
        if not url_context:
            url_context = s3_url_context(
                self.endpoint, self.bucket, self.object_name)
        d = self.get_page(
            url_context.get_encoded_url(),
            method=self.action,
            postdata=self.data or b"",
            headers=self.get_headers(utcnow()),
        )

        return d.addErrback(s3_error_wrapper)