def test_ssl_hostname_verification(self):
    """
    If the endpoint passed to L{BaseQuery} has C{ssl_hostname_verification}
    set to C{True}, a L{VerifyingContextFactory} is passed to C{connectSSL}.
    """

    class FakeReactor(object):
        # Minimal reactor stand-in that only records connectSSL calls.

        def __init__(self):
            self.connects = []

        def connectSSL(self, host, port, client, factory):
            self.connects.append((host, port, client, factory))

    # Patch the global CA store with a one-element *list* of certificates
    # (presumably the type the real store holds). Patching it with a bare
    # certificate made the final assertNotEqual([]) pass trivially, whatever
    # the factory actually received. Patching also keeps the test
    # independent of the CA bundle on the machine running it.
    certs = [makeCertificate(O="Test Certificate", CN="something")[1]]
    self.patch(ssl, "_ca_certs", certs)

    fake_reactor = FakeReactor()
    endpoint = AWSServiceEndpoint(ssl_hostname_verification=True)
    query = BaseQuery("an action", "creds", endpoint, fake_reactor)
    query.get_page("https://example.com/file")

    [(host, port, client, factory)] = fake_reactor.connects
    self.assertEqual("example.com", host)
    self.assertEqual(443, port)
    self.assertIsInstance(factory, ssl.VerifyingContextFactory)
    self.assertEqual("example.com", factory.host)
    self.assertNotEqual([], factory.caCerts)
def test_errors(self):
    """
    Requesting the C{not_there} resource (presumably absent from the test
    site) makes C{get_page} fail with L{TwistedWebError}.
    """
    endpoint = AWSServiceEndpoint("http://endpoint")
    query = BaseQuery("an action", "creds", endpoint)
    missing_url = self._get_url("not_there")
    return self.assertFailure(query.get_page(missing_url), TwistedWebError)
def test_ssl_hostname_verification(self):
    """
    If the endpoint passed to L{BaseQuery} has C{ssl_hostname_verification}
    set to C{True}, the L{Agent} created by C{get_page} receives a
    L{VerifyingContextFactory} as its C{contextFactory}.
    """
    created_agents = []

    @implementer(IAgent)
    class FakeAgent(object):
        # Records its constructor arguments; requests never complete.

        def __init__(self, reactor, contextFactory, connectTimeout=None,
                     bindAddress=None, pool=None):
            created_agents.append(
                (reactor, contextFactory, connectTimeout, bindAddress, pool))

        def request(self, method, uri, headers=None, bodyProducer=None):
            # A Deferred that never fires: only construction is inspected.
            return Deferred()

    verifyClass(IAgent, FakeAgent)

    fake_ca_certs = [makeCertificate(O="Test Certificate", CN="something")[1]]
    self.patch(base, "Agent", FakeAgent)
    self.patch(ssl, "_ca_certs", fake_ca_certs)

    endpoint = AWSServiceEndpoint(ssl_hostname_verification=True)
    query = BaseQuery("an action", "creds", endpoint, reactor="ignored")
    query.get_page("https://example.com/file")

    self.assertEqual(len(created_agents), 1)
    [(_, context_factory, _, _, _)] = created_agents
    self.assertIsInstance(context_factory, ssl.VerifyingContextFactory)
def test_ssl_verification_negative(self):
    """
    The L{VerifyingContextFactory} fails with an SSL error if the
    certificates can't be checked.
    """
    context_factory = WebDefaultOpenSSLContextFactory(
        BADPRIVKEY, BADPUBKEY)
    self.port = reactor.listenSSL(0, self.site, context_factory,
                                  interface="127.0.0.1")
    self.portno = self.port.getHost().port

    endpoint = AWSServiceEndpoint(ssl_hostname_verification=True)
    query = BaseQuery("an action", "creds", endpoint)
    d = query.get_page(self._get_url("file"))

    def fail(ignore):
        self.fail('Expected SSLError')

    def check_exception(why):
        # The client wraps the root error: the failure's exception carries,
        # as its first argument, a list of Failures whose first entry holds
        # the SSL error. Read it via ``.args`` rather than indexing the
        # exception itself: ``exc[i]`` only worked on Python 2 and raises
        # TypeError on Python 3.
        root_exc = why.value.args[0][0].value
        self.assertIsInstance(root_exc, SSLError)

    return d.addCallbacks(fail, check_exception)
def test_get_page(self):
    """
    C{get_page} fires with the body of the requested C{file} resource.
    """
    query = BaseQuery("an action", "creds",
                      AWSServiceEndpoint("http://endpoint"))
    page_url = self._get_url("file")
    return query.get_page(page_url).addCallback(
        self.assertEquals, "0123456789")
def test_get_request_headers_with_client(self):
    """
    After a request made without explicit headers completes,
    C{get_request_headers} reports no headers.
    """
    def check_results(results):
        # Materialize the key/value views before comparing. On Python 3,
        # ``dict.keys()`` and ``dict.values()`` return view objects that
        # never compare equal to a list, so comparing them to ``[]``
        # directly fails even when the mapping is empty.
        self.assertEqual(list(results.keys()), [])
        self.assertEqual(list(results.values()), [])

    query = BaseQuery("an action", "creds", "http://endpoint")
    d = query.get_page(self._get_url("file"))
    d.addCallback(query.get_request_headers)
    return d.addCallback(check_results)
def test_custom_body_producer(self):
    """
    A C{body_producer} given to L{BaseQuery} writes the request body of a
    C{PUT}.
    """
    producer = StringBodyProducer('test data')

    def check_producer_was_used(ignored):
        self.assertEqual(producer.written, 'test data')

    query = BaseQuery("an action", "creds", "http://endpoint",
                      body_producer=producer)
    target_url = self._get_url("thing_to_put")
    return query.get_page(target_url, method='PUT').addCallback(
        check_producer_was_used)
def test_get_response_headers_with_client(self):
    """
    After a completed request, C{get_response_headers} exposes six
    headers keyed by lower-case name.
    """
    expected_names = [
        "accept-ranges", "content-length", "content-type", "date",
        "last-modified", "server"]

    def check_results(results):
        self.assertEquals(sorted(results.keys()), expected_names)
        self.assertEquals(len(results.values()), 6)

    query = BaseQuery("an action", "creds", "http://endpoint")
    d = query.get_page(self._get_url("file"))
    d.addCallback(query.get_response_headers)
    return d.addCallback(check_results)
def test_get_response_headers_with_client(self):
    """
    C{get_response_headers}, called once the page has been fetched,
    returns the server's six response headers under lower-case names.
    """
    def check_results(headers):
        header_names = sorted(headers.keys())
        self.assertEquals(
            header_names,
            ["accept-ranges", "content-length", "content-type", "date",
             "last-modified", "server"],
        )
        self.assertEquals(len(headers.values()), 6)

    query = BaseQuery("an action", "creds", "http://endpoint")
    return query.get_page(self._get_url("file")).addCallback(
        query.get_response_headers).addCallback(check_results)
def test_ssl_verification_negative(self):
    """
    The L{VerifyingContextFactory} fails with an SSL error if the
    certificates can't be checked.
    """
    server_context = DefaultOpenSSLContextFactory(BADPRIVKEY, BADPUBKEY)
    self.port = reactor.listenSSL(
        0, self.site, server_context, interface="127.0.0.1")
    self.portno = self.port.getHost().port

    verifying_endpoint = AWSServiceEndpoint(ssl_hostname_verification=True)
    query = BaseQuery("an action", "creds", verifying_endpoint)
    pending = query.get_page(self._get_url("file"))
    return self.assertFailure(pending, SSLError)
def test_ssl_verification_positive(self):
    """
    The L{VerifyingContextFactory} allows connecting to the endpoint when
    the certificates match.
    """
    server_context = DefaultOpenSSLContextFactory(PRIVKEY, PUBKEY)
    self.port = reactor.listenSSL(
        0, self.site, server_context, interface="127.0.0.1")
    self.portno = self.port.getHost().port

    query = BaseQuery(
        "an action", "creds",
        AWSServiceEndpoint(ssl_hostname_verification=True))
    page = query.get_page(self._get_url("file"))
    return page.addCallback(self.assertEquals, "0123456789")
def test_custom_body_producer(self):
    """
    A C{body_producer} given to L{BaseQuery} writes the request body of a
    C{PUT}.
    """
    producer = StringBodyProducer('test data')

    def check_producer_was_used(ignored):
        self.assertEqual(producer.written, 'test data')

    endpoint = AWSServiceEndpoint("http://endpoint")
    query = BaseQuery("an action", "creds", endpoint,
                      body_producer=producer)
    put_result = query.get_page(self._get_url("thing_to_put"), method='PUT')
    return put_result.addCallback(check_producer_was_used)
def test_ssl_subject_alt_name(self):
    """
    L{VerifyingContextFactory} accepts a server whose certificate names
    the host in C{subjectAltName}, when that extension is present.
    """
    san_context = DefaultOpenSSLContextFactory(PRIVSANKEY, PUBSANKEY)
    self.port = reactor.listenSSL(
        0, self.site, san_context, interface="127.0.0.1")
    self.portno = self.port.getHost().port

    query = BaseQuery(
        "an action", "creds",
        AWSServiceEndpoint(ssl_hostname_verification=True))
    url = "https://127.0.0.1:%d/file" % (self.portno,)
    return query.get_page(url).addCallback(self.assertEquals, "0123456789")
def test_ssl_verification_bypassed(self):
    """
    L{BaseQuery} doesn't use L{VerifyingContextFactory} when
    C{ssl_hostname_verification} is C{False}, so it can connect to an
    endpoint whose certificate would not verify.
    """
    bad_context = DefaultOpenSSLContextFactory(BADPRIVKEY, BADPUBKEY)
    self.port = reactor.listenSSL(
        0, self.site, bad_context, interface="127.0.0.1")
    self.portno = self.port.getHost().port

    lenient_endpoint = AWSServiceEndpoint(ssl_hostname_verification=False)
    query = BaseQuery("an action", "creds", lenient_endpoint)
    page = query.get_page(self._get_url("file"))
    return page.addCallback(self.assertEquals, "0123456789")
def test_ssl_subject_alt_name(self):
    """
    When the server certificate carries a C{subjectAltName} extension,
    L{VerifyingContextFactory} checks the host against it and lets the
    request through.
    """
    context_factory = DefaultOpenSSLContextFactory(PRIVSANKEY, PUBSANKEY)
    self.port = reactor.listenSSL(0, self.site, context_factory,
                                  interface="127.0.0.1")
    self.portno = self.port.getHost().port

    endpoint = AWSServiceEndpoint(ssl_hostname_verification=True)
    query = BaseQuery("an action", "creds", endpoint)
    d = query.get_page("https://127.0.0.1:%d/file" % (self.portno, ))
    d.addCallback(self.assertEquals, "0123456789")
    return d
def test_ssl_verification_positive(self):
    """
    The L{VerifyingContextFactory} allows connecting to the endpoint when
    the certificates match.
    """
    server_context = WebDefaultOpenSSLContextFactory(PRIVKEY, PUBKEY)
    self.port = reactor.listenSSL(
        0, self.site, server_context, interface="127.0.0.1")
    self.portno = self.port.getHost().port

    query = BaseQuery(
        "an action", "creds",
        AWSServiceEndpoint(ssl_hostname_verification=True))
    page = query.get_page(self._get_url("file"))
    return page.addCallback(self.assertEquals, "0123456789")
def test_get_response_headers_with_client(self):
    """
    After a completed request, C{get_response_headers} exposes the
    response headers keyed by their capitalized names.
    """
    # NOTE(review): an earlier revision expected lower-case names plus a
    # "content-length" entry. The client used here apparently omits
    # content-length and capitalizes names. Whether callers rely on the
    # old lower-case keys (backwards compatibility) is still open.
    expected_names = [
        "Accept-Ranges", "Content-Type", "Date", "Last-Modified", "Server"]

    def check_results(results):
        self.assertEquals(sorted(results.keys()), expected_names)
        self.assertEquals(len(results.values()), 5)

    query = BaseQuery("an action", "creds", "http://endpoint")
    d = query.get_page(self._get_url("file"))
    d.addCallback(query.get_response_headers)
    return d.addCallback(check_results)
def test_custom_receiver_factory(self):
    """
    A C{receiver_factory} given to L{BaseQuery} is instantiated during
    C{get_page}, and the page body is still delivered intact.
    """
    class TestReceiverProtocol(StreamingBodyReceiver):
        # Set to True on the class once any instance is constructed.
        used = False

        def __init__(self):
            StreamingBodyReceiver.__init__(self)
            TestReceiverProtocol.used = True

    def check_used(ignored):
        self.assertTrue(TestReceiverProtocol.used)

    query = BaseQuery("an action", "creds", "http://endpoint",
                      receiver_factory=TestReceiverProtocol)
    page = query.get_page(self._get_url("file"))
    page.addCallback(self.assertEquals, "0123456789")
    return page.addCallback(check_used)
def test_custom_receiver_factory(self):
    """
    A C{receiver_factory} given to L{BaseQuery} is instantiated during
    C{get_page}, and the page body is still delivered intact.
    """
    class TestReceiverProtocol(StreamingBodyReceiver):
        # Set to True on the class once any instance is constructed.
        used = False

        def __init__(self):
            StreamingBodyReceiver.__init__(self)
            TestReceiverProtocol.used = True

    def check_used(ignored):
        self.assertTrue(TestReceiverProtocol.used)

    endpoint = AWSServiceEndpoint("http://endpoint")
    query = BaseQuery("an action", "creds", endpoint,
                      receiver_factory=TestReceiverProtocol)
    page = query.get_page(self._get_url("file"))
    page.addCallback(self.assertEquals, "0123456789")
    return page.addCallback(check_used)
def test_ssl_verification_negative(self):
    """
    The L{VerifyingContextFactory} fails with an SSL error if the
    certificates can't be checked.
    """
    context_factory = WebDefaultOpenSSLContextFactory(BADPRIVKEY, BADPUBKEY)
    self.port = reactor.listenSSL(
        0, self.site, context_factory, interface="127.0.0.1")
    self.portno = self.port.getHost().port

    endpoint = AWSServiceEndpoint(ssl_hostname_verification=True)
    query = BaseQuery("an action", "creds", endpoint)
    d = query.get_page(self._get_url("file"))

    def fail(ignore):
        self.fail('Expected SSLError')

    def check_exception(why):
        # The client wraps the root error: the failure's exception carries,
        # as its first argument, a list of Failures whose first entry holds
        # the SSL error. Read it via ``.args`` rather than indexing the
        # exception itself: ``exc[i]`` only worked on Python 2 and raises
        # TypeError on Python 3.
        root_exc = why.value.args[0][0].value
        self.assertIsInstance(root_exc, SSLError)

    return d.addCallbacks(fail, check_exception)
def test_get_response_headers_with_client(self):
    """
    After a completed request, C{get_response_headers} exposes the
    response headers keyed by their capitalized names.
    """
    # NOTE(review): an earlier revision expected lower-case names plus a
    # "content-length" entry. The client used here apparently omits
    # content-length and capitalizes names. Whether callers rely on the
    # old lower-case keys (backwards compatibility) is still open.
    expected_names = [
        "Accept-Ranges", "Content-Type", "Date", "Last-Modified", "Server"]

    def check_results(results):
        self.assertEquals(sorted(results.keys()), expected_names)
        self.assertEquals(len(results.values()), 5)

    endpoint = AWSServiceEndpoint("http://endpoint")
    query = BaseQuery("an action", "creds", endpoint)
    d = query.get_page(self._get_url("file"))
    d.addCallback(query.get_response_headers)
    return d.addCallback(check_results)
def test_ssl_hostname_verification(self):
    """
    If the endpoint passed to L{BaseQuery} has C{ssl_hostname_verification}
    set to C{True}, a L{VerifyingContextFactory} is passed to C{connectSSL}.
    """
    class FakeReactor(object):
        # Records every connectSSL call instead of opening a connection.

        def __init__(self):
            self.connects = []

        def connectSSL(self, host, port, client, factory):
            self.connects.append((host, port, client, factory))

    reactor_stub = FakeReactor()
    endpoint = AWSServiceEndpoint(ssl_hostname_verification=True)
    query = BaseQuery("an action", "creds", endpoint, reactor_stub)
    query.get_page("https://example.com/file")

    [(host, port, _client, factory)] = reactor_stub.connects
    self.assertEqual("example.com", host)
    self.assertEqual(443, port)
    self.assertIsInstance(factory, VerifyingContextFactory)
    self.assertEqual("example.com", factory.host)
    # NOTE(review): the CA store is not patched here, so this presumably
    # depends on CA certificates being available on the test machine.
    self.assertNotEqual([], factory.caCerts)
def test_creation(self):
    """
    L{BaseQuery} keeps its action, credentials and endpoint as attributes.
    """
    query = BaseQuery("an action", "creds", "http://endpoint")
    expectations = [
        ("action", "an action"),
        ("creds", "creds"),
        ("endpoint", "http://endpoint"),
    ]
    for attribute, expected in expectations:
        self.assertEquals(getattr(query, attribute), expected)
def test_get_response_headers_no_client(self):
    """
    Before any page is fetched, C{get_response_headers} returns C{None}.
    """
    fresh_query = BaseQuery("an action", "creds", "http://endpoint")
    headers = fresh_query.get_response_headers()
    self.assertEquals(headers, None)